diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 56bfcec6..f4d6cbc5 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -21,6 +21,7 @@ from common.utils.aes_crypto import simple_aes_decrypt from common.utils.utils import SQLBotLogUtil, equals_ignore_case, get_domain_list, string_to_numeric_hash from common.core.deps import Trans +from common.core.response_middleware import ResponseMiddleware @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id") @@ -87,13 +88,20 @@ def init_dynamic_cors(app: FastAPI): seen.add(domain) unique_domains.append(domain) cors_middleware = None + response_middleware = None for middleware in app.user_middleware: - if middleware.cls == CORSMiddleware: + if not cors_middleware and middleware.cls == CORSMiddleware: cors_middleware = middleware + if not response_middleware and middleware.cls == ResponseMiddleware: + response_middleware = middleware + if cors_middleware and response_middleware: break + + updated_origins = list(set(settings.all_cors_origins + unique_domains)) if cors_middleware: - updated_origins = list(set(settings.all_cors_origins + unique_domains)) cors_middleware.kwargs['allow_origins'] = updated_origins + if response_middleware: + response_middleware.kwargs['allow_origins'] = updated_origins except Exception as e: return False, e diff --git a/backend/common/core/response_middleware.py b/backend/common/core/response_middleware.py index dfd9f1dc..c2eeb43f 100644 --- a/backend/common/core/response_middleware.py +++ b/backend/common/core/response_middleware.py @@ -1,5 +1,6 @@ import json +from redis import typing from starlette.exceptions import HTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request @@ -11,6 +12,7 @@ class ResponseMiddleware(BaseHTTPMiddleware): def __init__(self, app): + self.allow_origins = ["'self'"] super().__init__(app) async def dispatch(self, request, call_next): @@ -76,7 +78,13 @@ async def dispatch(self, request, call_next): if k.lower() not in ("content-length", "content-type") } ) - + content_type = response.headers.get("content-type", "") + static_content_types = ["text/html", "javascript", "typescript", "css"] + if any(ct in content_type for ct in static_content_types): + if self.allow_origins: + frame_ancestors_value = " ".join(self.allow_origins) + response.headers["Content-Security-Policy"] = f"frame-ancestors {frame_ancestors_value};" + return response