summary refs log tree commit diff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 11
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py             |  3
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py       |  6
3 files changed, 10 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index f5b3c18c..a0f46fa9 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -1,4 +1,4 @@
-from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable
from ...core.main import ChatMessage
from . import LLM
@@ -23,14 +23,13 @@ class MaybeProxyOpenAI(LLM):
def context_length(self):
return self.llm.context_length
- async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]):
if api_key is None or api_key.strip() == "":
- self.llm = ProxyServer(
- unique_id="", model=self.model, write_log=kwargs["write_log"])
+ self.llm = ProxyServer(model=self.model)
else:
- self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"])
+ self.llm = OpenAI(model=self.model)
- await self.llm.start(api_key=api_key, **kwargs)
+ await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id)
async def stop(self):
await self.llm.stop()
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index deb6df4c..16428d4e 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -36,7 +36,8 @@ class OpenAI(LLM):
write_log: Optional[Callable[[str], None]] = None
api_key: str = None
- async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):
+ self.write_log = write_log
self.api_key = api_key
openai.api_key = self.api_key
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 44734b1c..5ee8ad90 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -34,11 +34,11 @@ class ProxyServer(LLM):
requires_unique_id = True
requires_write_log = True
- async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs):
self._client_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl_context=ssl_context))
- self.write_log = kwargs["write_log"]
- self.unique_id = kwargs["unique_id"]
+ self.write_log = write_log
+ self.unique_id = unique_id
async def stop(self):
await self._client_session.close()