diff --git a/.gitignore b/.gitignore index fe90143..9814158 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,7 @@ test.py test-script.py .coverage coverage.xml - +examples/mcd-poc +IMPLEMENTATION_NOTES.md +examples/MCD_DEVELOPER_GUIDE.md +DESIGN_DOC_REVISED.md \ No newline at end of file diff --git a/examples/MultipleCustomDomains.md b/examples/MultipleCustomDomains.md new file mode 100644 index 0000000..85d188d --- /dev/null +++ b/examples/MultipleCustomDomains.md @@ -0,0 +1,139 @@ +# Multiple Custom Domains (MCD) Guide + +This guide explains how to implement Multiple Custom Domain (MCD) support using the Auth0 Python SDKs. + +## What is MCD? + +Multiple Custom Domains (MCD) allows your application to serve different organizations or tenants from different hostnames, each mapping to a different Auth0 tenant/domain. + +**Example:** +- `https://acme.yourapp.com` → Auth0 tenant: `acme.auth0.com` +- `https://globex.yourapp.com` → Auth0 tenant: `globex.auth0.com` + +Each tenant gets its own branded login experience while using a single application codebase. + +## Configuration Methods + +### Method 1: Static Domain (Single Tenant) + +For applications with a single Auth0 domain: + +```python +from auth0_server_python import ServerClient + +client = ServerClient( + domain="your-tenant.auth0.com", # Static string + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +### Method 2: Dynamic Domain Resolver (MCD) + +For MCD support, provide a domain resolver function that receives a `DomainResolverContext`: + +```python +from auth0_server_python import ServerClient +from auth0_server_python.auth_types import DomainResolverContext + +# Map your app hostnames to Auth0 domains +DOMAIN_MAP = { + "acme.yourapp.com": "acme.auth0.com", + "globex.yourapp.com": "globex.auth0.com", +} +DEFAULT_DOMAIN = "default.auth0.com" + +async def domain_resolver(context: DomainResolverContext) -> str: + """ + Resolve Auth0 domain based on request hostname. + + Args: + context: Contains request_url and request_headers + + Returns: + Auth0 domain string (e.g., "acme.auth0.com") + """ + # Extract hostname from request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + host = context.request_headers.get('host', DEFAULT_DOMAIN) + host_without_port = host.split(':')[0] + + # Look up Auth0 domain + return DOMAIN_MAP.get(host_without_port, DEFAULT_DOMAIN) + +client = ServerClient( + domain=domain_resolver, # Callable function + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +## DomainResolverContext + +The `DomainResolverContext` object provides request information to your resolver: + +| Property | Type | Description | +|----------|------|-------------| +| `request_url` | `Optional[str]` | Full request URL (e.g., "https://acme.yourapp.com/auth/login") | +| `request_headers` | `Optional[dict[str, str]]` | Request headers dictionary | + +**Common headers:** +- `host`: Request hostname (e.g., "acme.yourapp.com") +- `x-forwarded-host`: Original host when behind proxy/load balancer + +**Example usage:** + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + # Check if we have request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + # Use x-forwarded-host if behind proxy, otherwise use host + host = (context.request_headers.get('x-forwarded-host') or + context.request_headers.get('host', '')) + + # Remove port number if present + hostname = host.split(':')[0].lower() + + # Look up in mapping + return DOMAIN_MAP.get(hostname, DEFAULT_DOMAIN) +``` + +## Error Handling + +### DomainResolverError + +The domain resolver should return a valid Auth0 domain string. Invalid returns will raise `DomainResolverError`: + +```python +from auth0_server_python.error import DomainResolverError + +async def domain_resolver(context: DomainResolverContext) -> str: + try: + domain = lookup_domain_from_db(context) + + if not domain: + # Return default instead of None + return DEFAULT_DOMAIN + + return domain # Must be a non-empty string + + except Exception as e: + # Log error and return default + logger.error(f"Domain resolution failed: {e}") + return DEFAULT_DOMAIN +``` + +**Invalid return values that raise `DomainResolverError`:** +- `None` +- Empty string `""` +- Non-string types (int, list, dict, etc.) + +**Exceptions raised by your resolver:** +- Automatically wrapped in `DomainResolverError` +- Original exception accessible via `.original_error` \ No newline at end of file diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 1e62aa5..627549e 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -6,7 +6,7 @@ import asyncio import json import time -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Callable, Generic, Optional, TypeVar, Union from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx @@ -38,8 +38,10 @@ AccessTokenForConnectionErrorCode, ApiError, BackchannelLogoutError, + ConfigurationError, CustomTokenExchangeError, CustomTokenExchangeErrorCode, + DomainResolverError, InvalidArgumentError, MissingRequiredArgumentError, MissingTransactionError, @@ -47,13 +49,17 @@ StartLinkUserError, ) from auth0_server_python.utils import PKCE, URL, State +from auth0_server_python.utils.helpers import ( + build_domain_resolver_context, + validate_resolved_domain_value, +) from authlib.integrations.base_client.errors import OAuthError from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError # Generic type for store options TStoreOptions = TypeVar('TStoreOptions') -INTERNAL_AUTHORIZE_PARAMS = ["client_id", "redirect_uri", "response_type", +INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", "code_challenge", "code_challenge_method", "state", "nonce", "scope"] @@ -70,9 +76,9 @@ class ServerClient(Generic[TStoreOptions]): def __init__( self, - domain: str, - client_id: str, - client_secret: str, + domain: Union[str, Callable[[Optional[dict[str, Any]]], str]] = None, + client_id: str = None, + client_secret: str = None, redirect_uri: Optional[str] = None, secret: str = None, transaction_store=None, @@ -80,13 +86,13 @@ def __init__( transaction_identifier: str = "_a0_tx", state_identifier: str = "_a0_session", authorization_params: Optional[dict[str, Any]] = None, - pushed_authorization_requests: bool = False + pushed_authorization_requests: bool = False, ): """ Initialize the Auth0 server client. Args: - domain: Auth0 domain (e.g., 'your-tenant.auth0.com') + domain: Auth0 domain - either a static string (e.g., 'tenant.auth0.com') or a callable that resolves domain dynamically. client_id: Auth0 client ID client_secret: Auth0 client secret redirect_uri: Default redirect URI for authentication @@ -96,12 +102,35 @@ def __init__( transaction_identifier: Identifier for transaction data state_identifier: Identifier for state data authorization_params: Default parameters for authorization requests + pushed_authorization_requests: Whether to use Pushed Authorization Requests """ if not secret: raise MissingRequiredArgumentError("secret") - # Store configuration - self._domain = domain + if domain is None: + raise ConfigurationError( + "Domain is required" + ) + + # Validate domain type + if not isinstance(domain, str) and not callable(domain): + raise ConfigurationError( + f"Domain must be either a string or a callable function. " + f"Got {type(domain).__name__} instead." + ) + + # Determine if domain is static string or dynamic callable + if callable(domain): + self._domain = None + self._domain_resolver = domain + else: + # Validate static domain string + domain_str = str(domain) + if not domain_str or domain_str.strip() == "": + raise ConfigurationError("Domain cannot be empty.") + self._domain = domain_str + self._domain_resolver = None + self._client_id = client_id self._client_secret = client_secret self._redirect_uri = redirect_uri @@ -122,14 +151,162 @@ def __init__( self._my_account_client = MyAccountClient(domain=domain) + # Cache for OIDC metadata and JWKS (Requirement 3: MCD Support) + self._metadata_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} + self._jwks_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} + self._cache_ttl = 3600 # 1 hour TTL + self._cache_max_entries = 100 # Max 100 domains to prevent memory bloat + + def _normalize_domain(self, domain: str) -> str: + """ + Normalize domain for comparison and URL construction. + Handles cases with/without https:// scheme. + """ + if domain.startswith('https://'): + return domain + elif domain.startswith('http://'): + return domain.replace('http://', 'https://') + else: + return f'https://{domain}' + + def _normalize_issuer(self, issuer: str) -> str: + """ + Normalize issuer URL for comparison. + + Args: + issuer: The issuer URL to normalize + + Returns: + Normalized issuer URL (lowercase) + """ + if not issuer: + return issuer + + # Lowercase first for case-insensitive comparison and scheme detection + issuer = issuer.lower() + + # Ensure https:// prefix + if issuer.startswith('http://'): + issuer = issuer.replace('http://', 'https://', 1) + elif not issuer.startswith('https://'): + issuer = f'https://{issuer}' + + # Remove trailing slash + return issuer.rstrip('/') + async def _fetch_oidc_metadata(self, domain: str) -> dict: - """Fetch OpenID Connect discovery metadata from the Auth0 domain.""" - metadata_url = f"https://{domain}/.well-known/openid-configuration" + """Fetch OIDC metadata from domain.""" + normalized_domain = self._normalize_domain(domain) + metadata_url = f"{normalized_domain}/.well-known/openid-configuration" async with httpx.AsyncClient() as client: response = await client.get(metadata_url) response.raise_for_status() return response.json() + async def _get_oidc_metadata_cached(self, domain: str) -> dict: + """ + Get OIDC metadata with caching. + + Args: + domain: Auth0 domain + + Returns: + OIDC metadata document + """ + now = time.time() + + # Check cache + if domain in self._metadata_cache: + cached = self._metadata_cache[domain] + if cached["expires_at"] > now: + return cached["data"] + + # Cache miss/expired - fetch fresh + metadata = await self._fetch_oidc_metadata(domain) + + # Enforce cache size limit (FIFO eviction) + if len(self._metadata_cache) >= self._cache_max_entries: + oldest_key = next(iter(self._metadata_cache)) + del self._metadata_cache[oldest_key] + + # Store in cache + self._metadata_cache[domain] = { + "data": metadata, + "expires_at": now + self._cache_ttl + } + + return metadata + + async def _fetch_jwks(self, jwks_uri: str) -> dict: + """ + Fetch JWKS (JSON Web Key Set) from jwks_uri. + + Args: + jwks_uri: The JWKS endpoint URL + + Returns: + JWKS document containing public keys + + Raises: + ApiError: If JWKS fetch fails + """ + try: + async with httpx.AsyncClient() as client: + response = await client.get(jwks_uri) + response.raise_for_status() + return response.json() + except Exception as e: + raise ApiError("jwks_fetch_error", f"Failed to fetch JWKS from {jwks_uri}", e) + + async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: + """ + Get JWKS with caching usingOIDC discovery. + + Args: + domain: Auth0 domain + metadata: Optional OIDC metadata (if already fetched) + + Returns: + JWKS document + + Raises: + ApiError: If JWKS fetch fails or jwks_uri missing from metadata + """ + now = time.time() + + # Check cache + if domain in self._jwks_cache: + cached = self._jwks_cache[domain] + if cached["expires_at"] > now: + return cached["data"] + + # Get jwks_uri from OIDC metadata + if not metadata: + metadata = await self._get_oidc_metadata_cached(domain) + + jwks_uri = metadata.get('jwks_uri') + if not jwks_uri: + raise ApiError( + "missing_jwks_uri", + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." + ) + + # Fetch JWKS + jwks = await self._fetch_jwks(jwks_uri) + + # Enforce cache size limit (FIFO eviction) + if len(self._jwks_cache) >= self._cache_max_entries: + oldest_key = next(iter(self._jwks_cache)) + del self._jwks_cache[oldest_key] + + # Store in cache + self._jwks_cache[domain] = { + "data": jwks, + "expires_at": now + self._cache_ttl + } + + return jwks + # ============================================================================ # INTERACTIVE LOGIN FLOW # Handles browser-based authentication using the Authorization Code flow @@ -146,12 +323,38 @@ async def start_interactive_login( Args: options: Configuration options for the login process + store_options: Store options containing request/response Returns: Authorization URL to redirect the user to """ options = options or StartInteractiveLoginOptions() + # Resolve domain (static or dynamic) + if self._domain_resolver: + # Build context and call developer's resolver + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + origin_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + else: + origin_domain = self._domain + + # Fetch OIDC metadata from resolved domain + try: + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + except Exception as e: + raise ApiError("metadata_error", + "Failed to fetch OIDC metadata", e) + # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: @@ -180,17 +383,20 @@ async def start_interactive_login( state = PKCE.generate_random_string(32) auth_params["state"] = state - #merge any requested scope with defaults + # Merge any requested scope with defaults requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope - # Build the transaction data to store + # Build the transaction data to store with origin domain and issuer transaction_data = TransactionData( code_verifier=code_verifier, app_state=options.app_state, audience=audience, + origin_domain=origin_domain, + origin_issuer=origin_issuer, + redirect_uri=auth_params.get("redirect_uri"), ) # Store the transaction data @@ -199,11 +405,9 @@ async def start_interactive_login( transaction_data, options=store_options ) - try: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) - except Exception as e: - raise ApiError("metadata_error", - "Failed to fetch OIDC metadata", e) + + # Set metadata for OAuth client + self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: par_endpoint = self._oauth.metadata.get( @@ -294,34 +498,105 @@ async def complete_interactive_login( if not code: raise MissingRequiredArgumentError("code") - if not self._oauth.metadata or "token_endpoint" not in self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Get origin domain and issuer from transaction + origin_domain = transaction_data.origin_domain + origin_issuer = transaction_data.origin_issuer + + # Fetch metadata from the origin domain + metadata = await self._get_oidc_metadata_cached(origin_domain) + self._oauth.metadata = metadata # Exchange the code for tokens + # Use redirect_uri from transaction if available, otherwise fall back to default + token_redirect_uri = transaction_data.redirect_uri or self._redirect_uri try: token_endpoint = self._oauth.metadata["token_endpoint"] token_response = await self._oauth.fetch_token( token_endpoint, code=code, code_verifier=transaction_data.code_verifier, - redirect_uri=self._redirect_uri, + redirect_uri=token_redirect_uri, ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) raise ApiError( "token_error", f"Token exchange failed: {str(e)}", e) + print(f"Token Response : {token_response}") - # Use the userinfo field from the token_response for user claims + # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") user_claims = None + id_token = token_response.get("id_token") + if user_info: user_claims = UserClaims.parse_obj(user_info) - else: - id_token = token_response.get("id_token") - if id_token: - claims = jwt.decode(id_token, options={ - "verify_signature": False}) + elif id_token: + # Fetch JWKS for signature verification (Requirement 3) + jwks = await self._get_jwks_cached(origin_domain, metadata) + + # Decode and verify ID token with signature verification enabled + try: + # Get the signing key from JWKS + unverified_header = jwt.get_unverified_header(id_token) + kid = unverified_header.get("kid") + + # Find the key with matching kid + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise ApiError( + "jwks_key_not_found", + f"No matching key found in JWKS for kid: {kid}" + ) + + # Decode with signature + claims = jwt.decode( + id_token, + signing_key.key, + algorithms=["RS256"], + audience=self._client_id, + options={"verify_signature": True, "verify_iss": False} + ) + + # Custom normalized issuer validation + token_issuer = claims.get("iss", "") + if self._normalize_issuer(token_issuer) != self._normalize_issuer(origin_issuer): + raise ApiError( + "invalid_issuer", + f"ID token issuer mismatch. Token issuer: {token_issuer}, Expected: {origin_issuer}. " + f"Ensure your Auth0 domain is configured correctly." + ) + user_claims = UserClaims.parse_obj(claims) + except jwt.InvalidSignatureError as e: + raise ApiError( + "invalid_signature", + f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", + e + ) + except jwt.InvalidAudienceError as e: + raise ApiError( + "invalid_audience", + f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", + e + ) + except jwt.ExpiredSignatureError as e: + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) + except jwt.InvalidTokenError as e: + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) + # Build a token set using the token response data token_set = TokenSet( @@ -343,6 +618,7 @@ async def complete_interactive_login( # might be None if not provided refresh_token=token_response.get("refresh_token"), token_sets=[token_set], + domain=origin_domain, internal={ "sid": sid, "created_at": int(time.time()) @@ -384,6 +660,25 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Validate session domain matches current request domain + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + return None + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() return state_data.get("user") @@ -402,6 +697,25 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Validate session domain matches current request domain + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + return None + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_data = {k: v for k, v in state_data.items() @@ -414,24 +728,30 @@ async def logout( options: Optional[LogoutOptions] = None, store_options: Optional[dict[str, Any]] = None ) -> str: - """ - Logs the user out and returns the Auth0 logout URL. - - Args: - options: Logout options including return_to URL. - store_options: Optional options used to pass to the State Store. - - Returns: - The Auth0 logout URL to redirect the user to. - """ options = options or LogoutOptions() # Delete the session from the state store await self._state_store.delete(self._state_identifier, store_options) + # Resolve domain dynamically for MCD support + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + else: + domain = self._domain + # Use the URL helper to create the logout URL. logout_url = URL.create_logout_url( - self._domain, self._client_id, options.return_to) + domain, self._client_id, options.return_to) return logout_url @@ -441,7 +761,7 @@ async def handle_backchannel_logout( store_options: Optional[dict[str, Any]] = None ) -> None: """ - Handles backchannel logout requests (OIDC Back-Channel Logout specification). + Handles backchannel logout requests. Args: logout_token: The logout token sent by Auth0 @@ -451,9 +771,50 @@ async def handle_backchannel_logout( raise BackchannelLogoutError("Missing logout token") try: - # Decode the token without verification - claims = jwt.decode(logout_token, options={ - "verify_signature": False}) + # Fetch JWKS for signature verification (Requirement 3) + jwks = await self._get_jwks_cached(self._domain) + + # Decode and verify logout token with signature verification enabled + try: + # Get the signing key from JWKS + unverified_header = jwt.get_unverified_header(logout_token) + kid = unverified_header.get("kid") + + # Find the key with matching kid + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise BackchannelLogoutError( + f"No matching key found in JWKS for kid: {kid}" + ) + + claims = jwt.decode( + logout_token, + signing_key.key, + algorithms=["RS256"], + options={"verify_signature": True, "verify_iss": False} + ) + + # Normalized issuer validation + token_issuer = claims.get("iss", "") + expected_issuer = self._normalize_domain(self._domain) + if self._normalize_issuer(token_issuer) != self._normalize_issuer(expected_issuer): + raise BackchannelLogoutError( + f"Logout token issuer mismatch. Token issuer: {token_issuer}, Expected: {expected_issuer}. " + f"Ensure your Auth0 domain is configured correctly." + ) + except jwt.InvalidSignatureError as e: + raise BackchannelLogoutError( + f"Logout token signature verification failed: {str(e)}" + ) + except jwt.InvalidTokenError as e: + raise BackchannelLogoutError( + f"Logout token verification failed: {str(e)}" + ) # Validate the token is a logout token events = claims.get("events", {}) @@ -469,7 +830,7 @@ async def handle_backchannel_logout( await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) - except (jwt.JoseError, ValidationError) as e: + except (jwt.PyJWTError, ValidationError) as e: raise BackchannelLogoutError( f"Error processing logout token: {str(e)}") @@ -500,6 +861,28 @@ async def get_access_token( """ state_data = await self._state_store.get(self._state_identifier, store_options) + # Validate session domain matches current request domain + if state_data and self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + raise AccessTokenError( + AccessTokenErrorCode.MISSING_REFRESH_TOKEN, + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + auth_params = self._default_authorization_params or {} # Get audience passed in on options or use defaults @@ -531,7 +914,12 @@ async def get_access_token( # Get new token with refresh token try: - get_refresh_token_options = {"refresh_token": state_data_dict["refresh_token"]} + # Use session's domain for token refresh + session_domain = state_data_dict.get("domain") or self._domain + get_refresh_token_options = { + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain + } if audience: get_refresh_token_options["audience"] = audience @@ -557,6 +945,7 @@ async def get_access_token( f"Failed to get token with refresh token: {str(e)}" ) + async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -576,9 +965,12 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise MissingRequiredArgumentError("refresh_token") try: - # Ensure we have the OIDC metadata + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Ensure we have the OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -858,7 +1250,7 @@ async def initiate_backchannel_authentication( try: # Fetch OpenID Connect metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get the issuer from metadata issuer = self._oauth_metadata.get( @@ -953,7 +1345,7 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, try: # Ensure we have the OIDC metadata if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(self._domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -1173,13 +1565,13 @@ async def _build_link_user_url( connection_scope: Optional[str] = None, authorization_params: Optional[dict[str, Any]] = None ) -> str: - """Helper: Builds the authorization URL for linking user accounts.""" + """Build a URL for linking user accounts""" # Generate code challenge from verifier code_challenge = PKCE.generate_code_challenge(code_verifier) # Get metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get authorization endpoint auth_endpoint = self._oauth_metadata.get("authorization_endpoint", @@ -1216,13 +1608,13 @@ async def _build_unlink_user_url( state: str, authorization_params: Optional[dict[str, Any]] = None ) -> str: - """Helper: Builds the authorization URL for unlinking user accounts.""" + """Build a URL for unlinking user accounts""" # Generate code challenge from verifier code_challenge = PKCE.generate_code_challenge(code_verifier) # Get metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get authorization endpoint auth_endpoint = self._oauth_metadata.get("authorization_endpoint", @@ -1302,10 +1694,13 @@ async def get_access_token_for_connection( "A refresh token was not found but is required to be able to retrieve an access token for a connection." ) # Get new token for connection + # Use session's domain for token exchange + session_domain = state_data_dict.get("domain") or self._domain token_endpoint_response = await self.get_token_for_connection({ "connection": options.get("connection"), "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"] + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain }) # Update state data with new token @@ -1337,9 +1732,12 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" try: - # Ensure we have OIDC metadata + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Ensure we have OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 4b36ca3..bbd77d4 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -66,6 +66,7 @@ class SessionData(BaseModel): refresh_token: Optional[str] = None token_sets: list[TokenSet] = Field(default_factory=list) connection_token_sets: list[ConnectionTokenSet] = Field(default_factory=list) + domain: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -89,6 +90,8 @@ class TransactionData(BaseModel): app_state: Optional[Any] = None auth_session: Optional[str] = None redirect_uri: Optional[str] = None + origin_domain: Optional[str] = None + origin_issuer: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -213,6 +216,29 @@ class StartLinkUserOptions(BaseModel): authorization_params: Optional[dict[str, Any]] = None app_state: Optional[Any] = None +# ============================================================================= +# Multiple Custom Domain +# ============================================================================= + +class DomainResolverContext(BaseModel): + """ + Context passed to domain resolver function for MCD support. + + Contains request information needed to determine the correct Auth0 domain + based on the incoming request's hostname or headers. + + Attributes: + request_url: The full request URL (e.g., "https://a.my-app.com/auth/login") + request_headers: Dictionary of request headers (e.g., {"host": "a.my-app.com", "x-forwarded-host": "..."}) + + Example: + async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get('host', '').split(':')[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) + """ + request_url: Optional[str] = None + request_headers: Optional[dict[str, str]] = None + # ============================================================================= # Custom Token Exchange Types # ============================================================================= diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index c593368..59f1c66 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -101,6 +101,17 @@ def __init__(self, argument: str): self.argument = argument +class ConfigurationError(Auth0Error): + """ + Error raised when SDK configuration is invalid. + This includes invalid combinations of parameters or incorrect configuration values. + """ + code = "configuration_error" + + def __init__(self, message: str): + super().__init__(message) + self.name = "ConfigurationError" + class InvalidArgumentError(Auth0Error): """ Error raised when a given argument is an invalid value. @@ -125,6 +136,21 @@ def __init__(self, message: str): self.name = "BackchannelLogoutError" +class DomainResolverError(Auth0Error): + """ + Error raised when domain resolver function fails or returns invalid value. + + This error indicates an issue with the custom domain resolver function + provided for MCD (Multiple Custom Domains) support. + """ + code = "domain_resolver_error" + + def __init__(self, message: str, original_error: Exception = None): + super().__init__(message) + self.name = "DomainResolverError" + self.original_error = original_error + + class AccessTokenForConnectionError(Auth0Error): """Error when retrieving access tokens for a specific connection fails.""" diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 4f1b90b..b41cb6a 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1,8 +1,9 @@ import json import time -from unittest.mock import ANY, AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock, patch from urllib.parse import parse_qs, urlparse +import jwt import pytest from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_server.server_client import ServerClient @@ -15,18 +16,22 @@ ConnectedAccountConnection, ConnectParams, CustomTokenExchangeOptions, + DomainResolverContext, ListConnectedAccountConnectionsResponse, ListConnectedAccountsResponse, LoginWithCustomTokenExchangeOptions, LogoutOptions, + StateData, TransactionData, ) from auth0_server_python.error import ( AccessTokenForConnectionError, ApiError, BackchannelLogoutError, + ConfigurationError, CustomTokenExchangeError, CustomTokenExchangeErrorCode, + DomainResolverError, InvalidArgumentError, MissingRequiredArgumentError, MissingTransactionError, @@ -52,7 +57,7 @@ async def test_init_no_secret_raises(): @pytest.mark.asyncio -async def test_start_interactive_login_no_redirect_uri(): +async def test_start_interactive_login_no_redirect_uri(mocker): client = ServerClient( domain="auth0.local", client_id="", @@ -61,6 +66,14 @@ async def test_start_interactive_login_no_redirect_uri(): transaction_store=AsyncMock(), secret="some-secret" ) + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "authorization_endpoint": "https://auth0.local/authorize"} + ) + with pytest.raises(MissingRequiredArgumentError) as exc: await client.start_interactive_login() # Check the error message @@ -84,7 +97,7 @@ async def test_start_interactive_login_builds_auth_url(mocker): # Mock out HTTP calls or the internal methods that create the auth URL mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) mock_oauth = mocker.patch.object( @@ -125,8 +138,13 @@ async def test_complete_interactive_login_no_transaction(): @pytest.mark.asyncio async def test_complete_interactive_login_returns_app_state(mocker): mock_tx_store = AsyncMock() - # The stored transaction includes an appState - mock_tx_store.get.return_value = TransactionData(code_verifier="123", app_state={"foo": "bar"}) + # The stored transaction includes an appState with origin_domain and origin_issuer + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + app_state={"foo": "bar"}, + origin_domain="auth0.local", + origin_issuer="https://auth0.local/" + ) mock_state_store = AsyncMock() @@ -139,6 +157,13 @@ async def test_complete_interactive_login_returns_app_state(mocker): secret="some-secret", ) + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "token_endpoint": "https://auth0.local/token"} + ) + # Patch token exchange mocker.patch.object(client._oauth, "metadata", {"token_endpoint": "https://auth0.local/token"}) @@ -214,7 +239,7 @@ async def test_complete_link_user_returns_app_state(mocker): ) # Patch token exchange - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={"token_endpoint": "https://auth0.local/token"}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={"token_endpoint": "https://auth0.local/token"}) async_fetch_token = AsyncMock() async_fetch_token.return_value = { "access_token": "token123", @@ -410,7 +435,8 @@ async def test_get_access_token_refresh_expired(mocker): assert token == "new_token" mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ - "refresh_token": "refresh_xyz" + "refresh_token": "refresh_xyz", + "domain": "auth0.local" }) @pytest.mark.asyncio @@ -451,6 +477,7 @@ async def test_get_access_token_refresh_merging_default_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "default", "scope": "openid profile email foo:bar" }) @@ -492,6 +519,7 @@ async def test_get_access_token_refresh_with_auth_params_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "scope": "openid profile email" }) @@ -532,6 +560,7 @@ async def test_get_access_token_refresh_with_auth_params_audience(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "my_audience" }) @@ -578,6 +607,7 @@ async def test_get_access_token_mrrt(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -631,6 +661,7 @@ async def test_get_access_token_mrrt_with_auth_params_scope(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -858,8 +889,21 @@ async def test_handle_backchannel_logout_ok(mocker): secret="some-secret" ) + # Mock JWKS fetch to prevent network call + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) mocker.patch("jwt.decode", return_value={ "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "iss": "https://auth0.local", "sub": "user_sub", "sid": "session_id_123" }) @@ -884,7 +928,7 @@ async def test_build_link_user_url_success(mocker): # Patch _fetch_oidc_metadata to return an authorization_endpoint mock_fetch = mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -942,7 +986,7 @@ async def test_build_link_user_url_fallback_authorize(mocker): # Patch _fetch_oidc_metadata to NOT have an authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} # empty dict, triggers fallback ) @@ -979,7 +1023,7 @@ async def test_build_unlink_user_url_success(mocker): # Patch out metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1012,7 +1056,7 @@ async def test_build_unlink_user_url_fallback_authorize(mocker): ) # No 'authorization_endpoint' - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={}) result_url = await client._build_unlink_user_url( connection="", @@ -1043,7 +1087,7 @@ async def test_build_unlink_user_url_with_metadata(mocker): # Patch the metadata fetch to include a valid authorization endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1096,7 +1140,7 @@ async def test_build_unlink_user_url_no_authorization_endpoint(mocker): # Patch _fetch_oidc_metadata to return no authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} ) result_url = await client._build_unlink_user_url( @@ -1127,7 +1171,7 @@ async def test_backchannel_auth_with_audience_and_binding_message(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1176,7 +1220,7 @@ async def test_backchannel_auth_rar(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1227,7 +1271,7 @@ async def test_backchannel_auth_token_exchange_failed(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1277,7 +1321,7 @@ async def test_initiate_backchannel_authentication_success(mocker): # Mock OIDC metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -1325,7 +1369,7 @@ async def test_initiate_backchannel_authentication_error_response(mocker): ) mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -2150,7 +2194,6 @@ async def test_list_connected_account_connections_with_invalid_take_param(mocker assert "The 'take' parameter must be a positive integer." in str(exc.value) mock_my_account_client.list_connected_account_connections.assert_not_awaited() - # ============================================================================= # Custom Token Exchange Tests # ============================================================================= @@ -2177,7 +2220,6 @@ async def test_custom_token_exchange_success(mocker): "_fetch_oidc_metadata", return_value={"token_endpoint": "https://auth0.local/oauth/token"} ) - # Mock httpx response mock_response = MagicMock() mock_response.status_code = 200 @@ -2824,3 +2866,1000 @@ async def test_login_with_custom_token_exchange_failure_propagates(mocker): ) ) assert exc.value.code == "unauthorized" + +# ============================================================================= +# OIDC Metadata and JWKS Fetching Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_fetch_jwks_success(): + """Test successful JWKS fetch from URI.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_jwks = { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "kid": "test-key-id", + "n": "test-modulus", + "e": "AQAB" + } + ] + } + + # Mock httpx client + mock_response = MagicMock() + mock_response.json.return_value = mock_jwks + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = mock_response + + with patch('httpx.AsyncClient', return_value=mock_client): + jwks = await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + assert jwks == mock_jwks + assert "keys" in jwks + mock_client.get.assert_awaited_once_with("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_fetch_jwks_failure(): + """Test JWKS fetch failure raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock httpx client to raise exception + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.side_effect = Exception("Network error") + + with patch('httpx.AsyncClient', return_value=mock_client): + with pytest.raises(ApiError, match="Failed to fetch JWKS"): + await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_oidc_metadata_caching(): + """Test OIDC metadata is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize", + "token_endpoint": "https://tenant.auth0.com/oauth/token", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + # Mock _fetch_oidc_metadata to track calls + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call - should fetch + result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result1 == mock_metadata + assert fetch_count == 1 + first_fetch_count = fetch_count + + # Second call - should use cache + result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result2 == mock_metadata + assert fetch_count == first_fetch_count # Should NOT increment + + # Verify cache contains data + assert "tenant.auth0.com" in client._metadata_cache + assert client._metadata_cache["tenant.auth0.com"]["data"] == mock_metadata + + +@pytest.mark.asyncio +async def test_oidc_metadata_cache_expiration(): + """Test OIDC metadata cache expires after TTL.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set short TTL for testing + client._cache_ttl = 1 # 1 second + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 1 + + # Wait for cache to expire + time.sleep(1.1) + + # Second call after expiration - should fetch again + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 2 + + +@pytest.mark.asyncio +async def test_jwks_caching(): + """Test JWKS is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + mock_jwks = { + "keys": [{"kty": "RSA", "kid": "key1"}] + } + + # Mock the fetch methods + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata) + + fetch_count = 0 + async def mock_fetch_jwks(uri): + nonlocal fetch_count + fetch_count += 1 + return mock_jwks + + client._fetch_jwks = mock_fetch_jwks + + # First call - should fetch + result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result1 == mock_jwks + assert fetch_count == 1 + first_fetch_count = fetch_count + + # Second call - should use cache + result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result2 == mock_jwks + assert fetch_count == first_fetch_count # Should NOT increment + + +@pytest.mark.asyncio +async def test_jwks_cache_size_limit(): + """Test JWKS cache enforces max size limit with FIFO eviction.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set small cache size for testing + client._cache_max_entries = 3 + + mock_jwks = {"keys": [{"kty": "RSA"}]} + + # Mock methods + async def mock_fetch_metadata(domain): + return {"jwks_uri": f"https://{domain}/.well-known/jwks.json"} + + async def mock_fetch_jwks(uri): + return mock_jwks + + client._fetch_oidc_metadata = mock_fetch_metadata + client._fetch_jwks = mock_fetch_jwks + + # Fill cache to limit + await client._get_jwks_cached("domain1.auth0.com") + await client._get_jwks_cached("domain2.auth0.com") + await client._get_jwks_cached("domain3.auth0.com") + + assert len(client._jwks_cache) == 3 + assert "domain1.auth0.com" in client._jwks_cache + + # Add one more - should evict oldest (domain1) + await client._get_jwks_cached("domain4.auth0.com") + + assert len(client._jwks_cache) == 3 + assert "domain1.auth0.com" not in client._jwks_cache # Evicted + assert "domain4.auth0.com" in client._jwks_cache + + +@pytest.mark.asyncio +async def test_jwks_missing_uri_raises_error(): + """Test that missing jwks_uri in metadata raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Metadata WITHOUT jwks_uri + mock_metadata_no_jwks_uri = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize" + # No jwks_uri + } + + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata_no_jwks_uri) + + # Should raise ApiError when jwks_uri is missing + with pytest.raises(ApiError) as exc_info: + await client._get_jwks_cached("tenant.auth0.com") + + assert exc_info.value.code == "missing_jwks_uri" + assert "non-RFC-compliant" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_metadata_cache_size_limit(): + """Test OIDC metadata cache enforces max size limit.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + client._cache_max_entries = 2 + + async def mock_fetch(domain): + return {"issuer": f"https://{domain}/"} + + client._fetch_oidc_metadata = mock_fetch + + # Fill cache + await client._get_oidc_metadata_cached("domain1.auth0.com") + await client._get_oidc_metadata_cached("domain2.auth0.com") + + assert len(client._metadata_cache) == 2 + + # Add third - should evict first + await client._get_oidc_metadata_cached("domain3.auth0.com") + + assert len(client._metadata_cache) == 2 + assert "domain1.auth0.com" not in client._metadata_cache + assert "domain3.auth0.com" in client._metadata_cache + + +# ============================================================================= +# Issuer Validation Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_complete_login_issuer_validation_success(mocker): + """Test complete login with valid issuer in ID token.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode with valid issuer + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", # Matches origin_issuer + "aud": "test_client" + }) + + # Should succeed without raising error + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert result is not None + assert "state_data" in result + + +@pytest.mark.asyncio +async def test_complete_login_issuer_mismatch_raises_error(mocker): + """Test that issuer mismatch in ID token raises ApiError.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode to return claims with a WRONG issuer + # Our custom normalized issuer validation should catch this mismatch + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://wrong-issuer.auth0.com/", # Different from expected: https://tenant.auth0.com/ + "aud": "test_client", + "exp": 9999999999 + }) + + # Should raise ApiError with invalid_issuer code + with pytest.raises(ApiError) as exc_info: + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert exc_info.value.code == "invalid_issuer" + assert "issuer mismatch" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_normalize_domain_handles_different_schemes(): + """Test that _normalize_domain handles various URL schemes correctly.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Test domain without scheme + assert client._normalize_domain("auth0.com") == "https://auth0.com" + + # Test domain with https scheme (should remain unchanged) + assert client._normalize_domain("https://auth0.com") == "https://auth0.com" + + # Test domain with http scheme (should convert to https) + assert client._normalize_domain("http://auth0.com") == "https://auth0.com" + + # Test domain with trailing slash + assert client._normalize_domain("https://auth0.com/") == "https://auth0.com/" + + +@pytest.mark.asyncio +async def test_normalize_issuer_handles_edge_cases(): + """Test that _normalize_issuer handles edge cases for robust issuer comparison. + + This test documents the edge cases that could cause issuer validation failures + with PyJWT's strict string comparison: + - Trailing slash differences + - Case sensitivity + - HTTP vs HTTPS schemes + - Missing scheme + """ + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Test trailing slash normalization + assert client._normalize_issuer("https://auth0.com/") == "https://auth0.com" + assert client._normalize_issuer("https://auth0.com") == "https://auth0.com" + assert client._normalize_issuer("https://auth0.com/") == client._normalize_issuer("https://auth0.com") + + # Test case insensitivity + assert client._normalize_issuer("HTTPS://AUTH0.COM/") == "https://auth0.com" + assert client._normalize_issuer("Https://Auth0.Com") == "https://auth0.com" + assert client._normalize_issuer("HTTPS://AUTH0.COM/") == client._normalize_issuer("https://auth0.com") + + # Test HTTP to HTTPS conversion + assert client._normalize_issuer("http://auth0.com") == "https://auth0.com" + assert client._normalize_issuer("HTTP://AUTH0.COM/") == "https://auth0.com" + + # Test missing scheme + assert client._normalize_issuer("auth0.com") == "https://auth0.com" + assert client._normalize_issuer("AUTH0.COM/") == "https://auth0.com" + + # Test empty/None handling + assert client._normalize_issuer("") == "" + assert client._normalize_issuer(None) is None + + +# ============================================================================= +# MCD Tests : Multiple Issuer Configuration Methods Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_as_static_string(): + """Test Method 1: Static domain string configuration.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain == "tenant.auth0.com" + assert client._domain_resolver is None + + +@pytest.mark.asyncio +async def test_domain_as_callable_function(): + """Test Method 2: Domain resolver function configuration.""" + async def domain_resolver(store_options): + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain is None + assert client._domain_resolver == domain_resolver + + +@pytest.mark.asyncio +async def test_missing_domain_raises_configuration_error(): + """Test that missing domain parameter raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain is required"): + ServerClient( + domain=None, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_invalid_domain_type_list(): + """Test that list domain raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="must be either a string or a callable"): + ServerClient( + domain=["tenant.auth0.com"], + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_empty_domain_string(): + """Test that empty domain string raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain cannot be empty"): + ServerClient( + domain="", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +# ============================================================================= +# MCD Tests : Domain Resolver Context Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_resolver_receives_context(mocker): + """Test that domain resolver receives DomainResolverContext with request data.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock request with headers + mock_request = MagicMock() + mock_request.url = "https://a.my-app.com/auth/login" + mock_request.headers = {"host": "a.my-app.com", "x-forwarded-host": "a.my-app.com"} + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options={"request": mock_request}) + except Exception: # noqa: S110 + pass # We only care about context being passed + + assert received_context is not None + assert isinstance(received_context, DomainResolverContext) + assert received_context.request_url == "https://a.my-app.com/auth/login" + assert received_context.request_headers.get("host") == "a.my-app.com" + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_none(): + """Test that domain resolver returning None raises DomainResolverError.""" + async def bad_resolver(context): + return None + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="returned None"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_empty_string(): + """Test that domain resolver returning empty string raises DomainResolverError.""" + async def bad_resolver(context): + return "" + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="empty string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_exception(): + """Test that domain resolver exceptions are wrapped in DomainResolverError.""" + async def bad_resolver(context): + raise ValueError("Something went wrong") + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="raised an exception"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_with_no_request(mocker): + """Test that domain resolver works with empty context when no request.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options=None) + except Exception: # noqa: S110 + pass # We only care about context being passed + assert received_context is not None + assert received_context.request_url is None + assert received_context.request_headers is None + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_non_string_type(): + """Test that domain resolver returning non-string raises DomainResolverError.""" + async def bad_resolver(context): + return 12345 + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="must return a string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + +# ============================================================================= +# MCD Tests : Domain-specific Session Management Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_session_stores_origin_domain(mocker): + """Test that session stores origin domain from login (Requirement 5).""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant1.auth0.com", + origin_issuer="https://tenant1.auth0.com/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain="tenant1.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "token_endpoint": "https://tenant1.auth0.com/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": "https://tenant1.auth0.com/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify session has domain field set + assert captured_state.domain == "tenant1.auth0.com" + + +@pytest.mark.asyncio +async def test_cross_domain_session_rejected(): + """Test that session from domain1 cannot be used with domain2 (Requirement 5).""" + # Create session with domain1 + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns domain2 (different from session) + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # get_user should return None (session rejected) + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_logout_uses_current_domain(mocker): + """Test that logout uses current resolved domain (Requirement 7).""" + current_domain = "tenant2.auth0.com" + + async def domain_resolver(context): + return current_domain + + mock_state_store = AsyncMock() + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + logout_url = await client.logout(store_options={"request": {}}) + + # Verify logout URL uses current domain + assert current_domain in logout_url + assert logout_url.startswith(f"https://{current_domain}") + + +@pytest.mark.asyncio +async def test_logout_clears_session_for_current_domain(): + """Test that logout clears session (Requirement 7).""" + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + await client.logout() + + # Verify session was deleted + mock_state_store.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_domain_migration_old_sessions_remain_valid(): + """Test that old sessions remain valid with old domain requests (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns old domain + async def domain_resolver(context): + return old_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should successfully retrieve user + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + +@pytest.mark.asyncio +async def test_domain_migration_new_sessions_use_new_domain(mocker): + """Test that new logins create sessions with new domain (Requirement 8).""" + new_domain = "new-tenant.auth0.com" + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain=new_domain, + origin_issuer=f"https://{new_domain}/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain=new_domain, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": f"https://{new_domain}/", + "token_endpoint": f"https://{new_domain}/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": f"https://{new_domain}/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify new session has new domain + assert captured_state.domain == new_domain + + +@pytest.mark.asyncio +async def test_domain_migration_sessions_isolated(): + """Test that old domain sessions cannot be used with new domain (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + new_domain = "new-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns NEW domain (migration happened) + async def domain_resolver(context): + return new_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should reject old session + user = await client.get_user(store_options={"request": {}}) + assert user is None diff --git a/src/auth0_server_python/utils/helpers.py b/src/auth0_server_python/utils/helpers.py index c57ab18..05cb0f8 100644 --- a/src/auth0_server_python/utils/helpers.py +++ b/src/auth0_server_python/utils/helpers.py @@ -6,6 +6,9 @@ from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse +from auth0_server_python.auth_types import DomainResolverContext +from auth0_server_python.error import DomainResolverError + class PKCE: @classmethod @@ -224,3 +227,69 @@ def create_logout_url(domain: str, client_id: str, return_to: Optional[str] = No if return_to: params["returnTo"] = return_to return URL.build_url(base_url, params) + + +# ============================================================================= +# Domain Resolver Utilities +# ============================================================================= + +def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'DomainResolverContext': + """ + Build DomainResolverContext from store_options. + + Extracts request information in a framework-agnostic way using duck typing. + + Args: + store_options: Dictionary containing 'request' and 'response' objects + + Returns: + DomainResolverContext with extracted request data + """ + + if not store_options: + return DomainResolverContext() + + request = store_options.get('request') + if not request: + return DomainResolverContext() + + # Framework-agnostic extraction using duck typing + request_url = str(request.url) if hasattr(request, 'url') else None + request_headers = dict(request.headers) if hasattr(request, 'headers') else None + + return DomainResolverContext( + request_url=request_url, + request_headers=request_headers + ) + + +def validate_resolved_domain_value(domain_value: Any) -> str: + """ + Validate the value returned by domain resolver. + + Args: + domain_value: The value returned by the domain resolver + + Returns: + The validated domain string + + Raises: + DomainResolverError: If the returned value is invalid + """ + + if domain_value is None: + raise DomainResolverError( + "Domain resolver returned None. Must return a valid domain string." + ) + + if not isinstance(domain_value, str): + raise DomainResolverError( + f"Domain resolver must return a string. Got {type(domain_value).__name__} instead." + ) + + if not domain_value.strip(): + raise DomainResolverError( + "Domain resolver returned an empty string. Must return a valid domain." + ) + + return domain_value