"""JWKS (JSON Web Key Set) manager for Axioms authentication.
This module provides both sync and async JWKS managers:
- JWKSManager: Thread-based manager for Flask, Django (WSGI), and other sync frameworks
- AsyncJWKSManager: Asyncio-based manager for FastAPI, Django (ASGI), and other async frameworks
"""
import asyncio
import atexit
import logging
import threading
import time
from typing import TYPE_CHECKING, Optional
from urllib.parse import urlparse
import httpx
from .errors import AxiomsError
if TYPE_CHECKING:
from .config import AxiomsConfig
logger = logging.getLogger(__name__)
[docs]
class JWKSManager:
"""Thread-safe JWKS manager with background refresh support.
This manager handles JWKS fetching with:
- HTTP requests using httpx (sync mode for framework compatibility)
- Periodic background refresh using threading
- In-memory caching with TTL
- Thread-safe access
- Framework-agnostic (works with FastAPI, Django, Flask)
The manager can be initialized on application startup or will lazy-initialize
on first use with a blocking fetch followed by background refreshes.
"""
_instance = None
_lock = threading.Lock()
[docs]
def __new__(cls):
"""Singleton pattern to ensure only one manager instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
[docs]
def __init__(self):
"""Initialize the JWKS manager."""
if self._initialized:
return
self._jwks_cache = {} # url -> (data, timestamp)
self._cache_lock = threading.RLock()
self._client = None
self._refresh_thread = None
self._refresh_interval = 3600 # 1 hour default
self._cache_ttl = (
7200 # 2 hours default (matches AxiomsConfig and AsyncJWKSManager)
)
self._running = False
self._stop_event = threading.Event()
self._initialized = True
# Register cleanup on exit
atexit.register(self.shutdown)
[docs]
def initialize(
self,
jwks_url: str,
refresh_interval: int = 3600,
cache_ttl: int = 600,
prefetch: bool = True,
):
"""Initialize the manager and start background refresh.
Args:
jwks_url: JWKS URL to fetch.
refresh_interval: Interval in seconds between refresh attempts (default: 3600).
cache_ttl: Cache TTL in seconds (default: 600).
prefetch: If True, pre-fetch JWKS before starting background refresh.
"""
with self._lock:
if self._running:
logger.debug("JWKS manager already running")
return
self._refresh_interval = refresh_interval
self._cache_ttl = cache_ttl
# Create httpx sync client with timeout
self._client = httpx.Client(
timeout=httpx.Timeout(10.0),
follow_redirects=True,
verify=True, # SSL verification enabled
)
# Pre-fetch JWKS if requested
if prefetch:
try:
self._fetch_jwks(jwks_url)
logger.info(f"Successfully pre-fetched JWKS from {jwks_url}")
except Exception as e:
logger.warning(f"Failed to pre-fetch JWKS: {e}")
# Don't raise - allow fallback to on-demand fetch
# Start background refresh thread
self._running = True
self._stop_event.clear()
self._refresh_thread = threading.Thread(
target=self._refresh_loop,
args=(jwks_url,),
daemon=True,
name="JWKSRefreshThread",
)
self._refresh_thread.start()
logger.info(
f"Started JWKS background refresh (interval: {refresh_interval}s)"
)
[docs]
def shutdown(self):
"""Shutdown the manager and cleanup resources."""
with self._lock:
if not self._running:
return
self._running = False
self._stop_event.set()
# Wait for refresh thread to finish
if self._refresh_thread and self._refresh_thread.is_alive():
self._refresh_thread.join(timeout=5.0)
# Close httpx client
if self._client:
self._client.close()
self._client = None
logger.info("JWKS manager shutdown complete")
def _refresh_loop(self, jwks_url: str):
"""Background thread to periodically refresh JWKS.
Args:
jwks_url: JWKS URL to refresh.
"""
while self._running:
# Wait for refresh interval or stop event
if self._stop_event.wait(timeout=self._refresh_interval):
# Stop event was set
break
if not self._running:
break
try:
self._fetch_jwks(jwks_url)
logger.debug(f"Refreshed JWKS from {jwks_url}")
except Exception as e:
logger.error(f"Error refreshing JWKS: {e}")
def _fetch_jwks(self, url: str) -> bytes:
"""Fetch JWKS data from URL using httpx.
Args:
url: JWKS URL to fetch.
Returns:
bytes: JWKS data.
Raises:
AxiomsError: If URL scheme is invalid.
Exception: If fetch fails.
"""
# Validate URL scheme
parsed_url = urlparse(url)
if parsed_url.scheme not in ("http", "https"):
logger.error(f"Invalid URL scheme: {parsed_url.scheme}. URL: {url}")
raise AxiomsError(
{
"error": "server_error",
"error_description": (
"Invalid JWKS URL configuration. "
"Only http and https schemes are allowed."
),
},
500,
)
# Use httpx client if initialized, otherwise create temporary one
if self._client is not None:
response = self._client.get(url)
else:
with httpx.Client(
timeout=httpx.Timeout(10.0), follow_redirects=True, verify=True
) as client:
response = client.get(url)
response.raise_for_status()
data = response.content
# Update cache with thread safety
with self._cache_lock:
timestamp = time.time()
self._jwks_cache[url] = (data, timestamp)
return data
[docs]
def get_jwks(self, url: str) -> bytes:
"""Get JWKS data from cache or fetch if needed.
This method is thread-safe and can be called from any context.
Args:
url: JWKS URL.
Returns:
bytes: JWKS data.
"""
# Check cache first
with self._cache_lock:
if url in self._jwks_cache:
data, timestamp = self._jwks_cache[url]
age = time.time() - timestamp
if age < self._cache_ttl:
logger.debug(f"JWKS cache hit for {url} (age: {age:.1f}s)")
return data
# Cache miss or expired - fetch new data
logger.debug(f"JWKS cache miss for {url}, fetching...")
return self._fetch_jwks(url)
[docs]
class AsyncJWKSManager:
"""Async JWKS manager with background refresh support for async applications.
This manager handles JWKS fetching with:
- HTTP requests using httpx.AsyncClient
- Periodic background refresh using asyncio
- In-memory caching with TTL
- Async-safe access
- For FastAPI, Django (ASGI), and other async frameworks
The manager should be initialized on application startup.
"""
_instance = None
_lock = asyncio.Lock()
[docs]
def __new__(cls):
"""Singleton pattern to ensure only one manager instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
[docs]
def __init__(self):
"""Initialize the async JWKS manager."""
if self._initialized:
return
self._jwks_cache = {} # url -> (data, timestamp)
self._cache_lock = asyncio.Lock()
self._client = None
self._refresh_task = None
self._refresh_interval = 3600 # 1 hour default
self._cache_ttl = 7200 # 2 hours default
self._running = False
self._stop_event = asyncio.Event()
self._initialized = True
[docs]
async def initialize(
self,
jwks_url: str,
refresh_interval: int = 3600,
cache_ttl: int = 7200,
prefetch: bool = True,
):
"""Initialize the manager and start background refresh.
Args:
jwks_url: JWKS URL to fetch.
refresh_interval: Interval in seconds between refresh attempts (default: 3600).
cache_ttl: Cache TTL in seconds (default: 7200).
prefetch: If True, pre-fetch JWKS before starting background refresh.
"""
if self._running:
logger.debug("Async JWKS manager already running")
return
self._refresh_interval = refresh_interval
self._cache_ttl = cache_ttl
# Create httpx async client with timeout
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(10.0),
follow_redirects=True,
verify=True, # SSL verification enabled
)
# Pre-fetch JWKS if requested
if prefetch:
try:
await self._fetch_jwks(jwks_url)
logger.info(f"Successfully pre-fetched JWKS from {jwks_url}")
except Exception as e:
logger.warning(f"Failed to pre-fetch JWKS: {e}")
# Don't raise - allow fallback to on-demand fetch
# Start background refresh task
self._running = True
self._stop_event.clear()
self._refresh_task = asyncio.create_task(self._refresh_loop(jwks_url))
logger.info(
f"Started async JWKS background refresh (interval: {refresh_interval}s)"
)
[docs]
async def shutdown(self):
"""Shutdown the manager and cleanup resources."""
if not self._running:
return
self._running = False
self._stop_event.set()
# Cancel refresh task
if self._refresh_task and not self._refresh_task.done():
self._refresh_task.cancel()
try:
await self._refresh_task
except asyncio.CancelledError:
pass
# Close httpx client
if self._client:
await self._client.aclose()
self._client = None
logger.info("Async JWKS manager shutdown complete")
async def _refresh_loop(self, jwks_url: str):
"""Background task to periodically refresh JWKS.
Args:
jwks_url: JWKS URL to refresh.
"""
while self._running:
try:
# Wait for refresh interval or stop event
await asyncio.wait_for(
self._stop_event.wait(), timeout=self._refresh_interval
)
# Stop event was set
break
except asyncio.TimeoutError:
# Timeout reached, time to refresh
pass
if not self._running:
break
try:
await self._fetch_jwks(jwks_url)
logger.debug(f"Refreshed JWKS from {jwks_url}")
except Exception as e:
logger.error(f"Error refreshing JWKS: {e}")
async def _fetch_jwks(self, url: str) -> bytes:
"""Fetch JWKS data from URL using httpx.
Args:
url: JWKS URL to fetch.
Returns:
bytes: JWKS data.
Raises:
AxiomsError: If URL scheme is invalid.
Exception: If fetch fails.
"""
# Validate URL scheme
parsed_url = urlparse(url)
if parsed_url.scheme not in ("http", "https"):
logger.error(f"Invalid URL scheme: {parsed_url.scheme}. URL: {url}")
raise AxiomsError(
{
"error": "server_error",
"error_description": (
"Invalid JWKS URL configuration. "
"Only http and https schemes are allowed."
),
},
500,
)
# Use httpx async client if initialized, otherwise create temporary one
if self._client is not None:
response = await self._client.get(url)
else:
async with httpx.AsyncClient(
timeout=httpx.Timeout(10.0), follow_redirects=True, verify=True
) as client:
response = await client.get(url)
response.raise_for_status()
data = response.content
# Update cache with async safety
async with self._cache_lock:
timestamp = time.time()
self._jwks_cache[url] = (data, timestamp)
return data
[docs]
async def get_jwks(self, url: str) -> bytes:
"""Get JWKS data from cache or fetch if needed.
This method is async-safe and can be called from async contexts.
Args:
url: JWKS URL.
Returns:
bytes: JWKS data.
"""
# Check cache first
async with self._cache_lock:
if url in self._jwks_cache:
data, timestamp = self._jwks_cache[url]
age = time.time() - timestamp
if age < self._cache_ttl:
logger.debug(f"JWKS cache hit for {url} (age: {age:.1f}s)")
return data
# Cache miss or expired - fetch new data
logger.debug(f"JWKS cache miss for {url}, fetching...")
return await self._fetch_jwks(url)
# Global JWKS manager instances
_jwks_manager = JWKSManager()
_async_jwks_manager = AsyncJWKSManager()
[docs]
def initialize_jwks_manager(
config: Optional["AxiomsConfig"] = None,
jwks_url: Optional[str] = None,
refresh_interval: int = 3600,
cache_ttl: int = 7200,
prefetch: bool = True,
):
"""Initialize the global JWKS manager for sync/threading-based frameworks.
This should be called during application startup to enable background JWKS
refresh and avoid blocking requests. Uses threading for background refresh.
Use this for:
- Flask (WSGI mode - the default, even with async route handlers)
- Django WSGI
- Any framework using threading/WSGI
For truly async frameworks using asyncio event loops:
- FastAPI: Use initialize_async_jwks_manager
- Django ASGI: Use initialize_async_jwks_manager
- Flask (ASGI mode with Hypercorn/Uvicorn): Use initialize_async_jwks_manager
Note: Flask 2.0+ supports async/await in route handlers but still runs on WSGI
by default (using thread pools). Only use initialize_async_jwks_manager if running
Flask on an ASGI server like Hypercorn.
Args:
config: AxiomsConfig object (recommended). If provided, uses config values by default.
jwks_url: JWKS URL to fetch. If None and config provided, uses config.AXIOMS_JWKS_URL.
refresh_interval: Interval in seconds between refresh attempts (default: 3600).
cache_ttl: Cache TTL in seconds (default: 7200, must be >= 2x refresh_interval).
prefetch: If True, pre-fetch JWKS before starting background refresh.
Example:
Flask-WSGI::
from axioms_core import AxiomsConfig, initialize_jwks_manager, shutdown_jwks_manager
from flask import Flask
app = Flask(__name__)
config = AxiomsConfig(
AXIOMS_JWKS_URL="https://auth.example.com/.well-known/jwks.json",
AXIOMS_JWKS_REFRESH_INTERVAL=1800,
AXIOMS_JWKS_CACHE_TTL=3600,
)
@app.before_first_request
def startup():
initialize_jwks_manager(config=config)
# Shutdown automatically called via atexit, or manually:
@app.teardown_appcontext
def shutdown(exception=None):
shutdown_jwks_manager()
Example:
Django-WSGI::
# In apps.py
from django.apps import AppConfig
from axioms_core import initialize_jwks_manager
class MyAppConfig(AppConfig):
def ready(self):
initialize_jwks_manager(
jwks_url="https://auth.example.com/.well-known/jwks.json",
refresh_interval=1800,
cache_ttl=3600
)
"""
# Use config values if provided
if config is not None:
jwks_url = jwks_url or config.AXIOMS_JWKS_URL
refresh_interval = config.AXIOMS_JWKS_REFRESH_INTERVAL
cache_ttl = config.AXIOMS_JWKS_CACHE_TTL
prefetch = config.AXIOMS_JWKS_PREFETCH
if jwks_url is None:
raise ValueError(
"jwks_url must be provided either directly or via config.AXIOMS_JWKS_URL"
)
_jwks_manager.initialize(jwks_url, refresh_interval, cache_ttl, prefetch)
[docs]
def shutdown_jwks_manager():
"""Shutdown the global JWKS manager.
This should be called during application shutdown to cleanup resources.
It's also automatically called via atexit registration.
"""
_jwks_manager.shutdown()
[docs]
async def initialize_async_jwks_manager(
config: Optional["AxiomsConfig"] = None,
jwks_url: Optional[str] = None,
refresh_interval: int = 3600,
cache_ttl: int = 7200,
prefetch: bool = True,
):
"""Initialize the global async JWKS manager for asyncio-based frameworks.
This should be called during application startup to enable background JWKS
refresh and avoid blocking requests. Uses asyncio for background refresh.
Use this for frameworks running on asyncio event loops:
- FastAPI (always uses asyncio)
- Django ASGI (async mode)
- Flask on ASGI servers (Hypercorn, Uvicorn, etc.)
- Any ASGI application
For threading-based frameworks, use initialize_jwks_manager instead:
- Flask (WSGI mode - the default)
- Django WSGI
Note: Flask 2.0+ supports async/await syntax but runs on WSGI by default.
Only use this function if running Flask on an ASGI server like Hypercorn.
Args:
config: AxiomsConfig object (recommended). If provided, uses config values by default.
jwks_url: JWKS URL to fetch. If None and config provided, uses config.AXIOMS_JWKS_URL.
refresh_interval: Interval in seconds between refresh attempts (default: 3600).
cache_ttl: Cache TTL in seconds (default: 7200, must be >= 2x refresh_interval).
prefetch: If True, pre-fetch JWKS before starting background refresh.
Example:
with FastAPI (lifespan context manager)::
from contextlib import asynccontextmanager
from fastapi import FastAPI
from axioms_core import (
AxiomsConfig,
initialize_async_jwks_manager,
shutdown_async_jwks_manager
)
config = AxiomsConfig(
AXIOMS_JWKS_URL="https://auth.example.com/.well-known/jwks.json",
AXIOMS_JWKS_REFRESH_INTERVAL=1800,
AXIOMS_JWKS_CACHE_TTL=3600,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
await initialize_async_jwks_manager(config=config)
yield
# Shutdown
await shutdown_async_jwks_manager()
app = FastAPI(lifespan=lifespan)
Example:
FastAPI with startup/shutdown events::
from fastapi import FastAPI
from axioms_core import initialize_async_jwks_manager, shutdown_async_jwks_manager
app = FastAPI()
@app.on_event("startup")
async def startup():
await initialize_async_jwks_manager(
jwks_url="https://auth.example.com/.well-known/jwks.json"
)
@app.on_event("shutdown")
async def shutdown():
await shutdown_async_jwks_manager()
Example:
Django-ASGI::
# In asgi.py
from django.core.asgi import get_asgi_application
from axioms_core import initialize_async_jwks_manager
import asyncio
# Initialize JWKS manager before application starts
asyncio.run(initialize_async_jwks_manager(
jwks_url="https://auth.example.com/.well-known/jwks.json",
refresh_interval=1800,
cache_ttl=3600
))
application = get_asgi_application()
"""
# Use config values if provided
if config is not None:
jwks_url = jwks_url or config.AXIOMS_JWKS_URL
refresh_interval = config.AXIOMS_JWKS_REFRESH_INTERVAL
cache_ttl = config.AXIOMS_JWKS_CACHE_TTL
prefetch = config.AXIOMS_JWKS_PREFETCH
if jwks_url is None:
raise ValueError(
"jwks_url must be provided either directly or via config.AXIOMS_JWKS_URL"
)
await _async_jwks_manager.initialize(
jwks_url, refresh_interval, cache_ttl, prefetch
)
[docs]
async def shutdown_async_jwks_manager():
"""Shutdown the global async JWKS manager.
This should be called during application shutdown to cleanup resources.
"""
await _async_jwks_manager.shutdown()