diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 6 |
3 files changed, 10 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index f5b3c18c..a0f46fa9 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Coroutine, Dict, Generator, List, Union, Optional +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable from ...core.main import ChatMessage from . import LLM @@ -23,14 +23,13 @@ class MaybeProxyOpenAI(LLM): def context_length(self): return self.llm.context_length - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]): if api_key is None or api_key.strip() == "": - self.llm = ProxyServer( - unique_id="", model=self.model, write_log=kwargs["write_log"]) + self.llm = ProxyServer(model=self.model) else: - self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"]) + self.llm = OpenAI(model=self.model) - await self.llm.start(api_key=api_key, **kwargs) + await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id) async def stop(self): await self.llm.stop() diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index deb6df4c..16428d4e 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -36,7 +36,8 @@ class OpenAI(LLM): write_log: Optional[Callable[[str], None]] = None api_key: str = None - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs): + self.write_log = write_log self.api_key = api_key openai.api_key = self.api_key diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 44734b1c..5ee8ad90 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -34,11 +34,11 @@ class ProxyServer(LLM): requires_unique_id = True requires_write_log = True - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs): self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl_context=ssl_context)) - self.write_log = kwargs["write_log"] - self.unique_id = kwargs["unique_id"] + self.write_log = write_log + self.unique_id = unique_id async def stop(self): await self._client_session.close() |