diff --git a/backend/apps/ai_model/openai/llm.py b/backend/apps/ai_model/openai/llm.py
index 3867e7b6..a03693c6 100644
--- a/backend/apps/ai_model/openai/llm.py
+++ b/backend/apps/ai_model/openai/llm.py
@@ -1,5 +1,5 @@
 from collections.abc import Iterator, Mapping
-from typing import Any, cast
+from typing import Any, Optional, cast
 
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import (
@@ -84,6 +84,47 @@ def _convert_delta_to_message_chunk(
 
 
 class BaseChatOpenAI(ChatOpenAI):
+    @property
+    def _default_params(self) -> dict[str, Any]:
+        """Return the parent's default params with ``max_tokens`` set.
+
+        A truthy ``self.max_tokens`` is written under the ``"max_tokens"``
+        key, replacing any value the parent stored there.
+        NOTE(review): presumably the parent reports the limit under another
+        key than some OpenAI-compatible backends expect -- confirm.
+        """
+        max_tokens = self.max_tokens
+        params = super()._default_params
+        if max_tokens:
+            params["max_tokens"] = max_tokens
+        return params
+
+    def _get_request_payload(
+        self,
+        input_: LanguageModelInput,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Build the request payload and make sure ``max_tokens`` is set.
+
+        Args:
+            input_: Model input, forwarded unchanged to the parent.
+            stop: Optional stop sequences, forwarded to the parent.
+            **kwargs: Extra request options, forwarded to the parent.
+
+        Returns:
+            The parent's payload with ``"max_tokens"`` set when a truthy
+            limit is available.
+        """
+        # A limit passed for this call wins over the instance default, so a
+        # per-call override is not silently replaced.
+        max_tokens = kwargs.get("max_tokens") or self.max_tokens
+        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
+        if max_tokens:
+            payload["max_tokens"] = max_tokens
+        return payload
+
     usage_metadata: dict = {}
     # custom_get_token_ids = custom_get_token_ids
 