summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/models.py4
-rw-r--r--continuedev/src/continuedev/libs/llm/ggml.py1
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py25
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py25
4 files changed, 31 insertions, 24 deletions
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 8b1b1f00..ec89d503 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -38,6 +38,10 @@ class Models(BaseModel):
kwargs = {}
if llm.required_api_key:
kwargs["api_key"] = await self.sdk.get_api_key(llm.required_api_key)
+ if llm.required_unique_id:
+ kwargs["unique_id"] = self.sdk.ide.unique_id
+ if llm.required_write_log:
+ kwargs["write_log"] = self.sdk.write_log
await llm.start(**kwargs)
async def start(self, sdk: "ContinueSDK"):
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 52e44bfe..401709c9 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -16,6 +16,7 @@ class GGML(LLM):
def __init__(self, system_message: str = None):
self.system_message = system_message
+ @property
def name(self):
return "ggml"
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index ef8830a6..5ac4d211 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,6 +1,6 @@
from functools import cached_property
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Union, Optional
from pydantic import BaseModel
from ...core.main import ChatMessage
@@ -16,14 +16,13 @@ class AzureInfo(BaseModel):
class OpenAI(LLM):
+ model: str
+ system_message: Optional[str] = None
+ azure_info: Optional[AzureInfo] = None
+ write_log: Optional[Callable[[str], None]] = None
+
required_api_key = "OPENAI_API_KEY"
- default_model: str
-
- def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
- self.default_model = default_model
- self.system_message = system_message
- self.azure_info = azure_info
- self.write_log = write_log
+ required_write_log = True
async def start(self, *, api_key):
self.api_key = api_key
@@ -38,18 +37,19 @@ class OpenAI(LLM):
async def stop(self):
pass
+ @property
def name(self):
- return self.default_model
+ return self.model
@property
def default_args(self):
- args = {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.model}
if self.azure_info is not None:
args["engine"] = self.azure_info.engine
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -85,7 +85,8 @@ class OpenAI(LLM):
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
- args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+ # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
+ args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index c0e2a403..2c0e1dc4 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,7 +1,7 @@
import json
import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
@@ -19,29 +19,30 @@ SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
class ProxyServer(LLM):
unique_id: str
- name: str
- default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+ model: str
+ system_message: Optional[str]
write_log: Callable[[str], None]
- def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
- self.unique_id = unique_id
- self.default_model = default_model
- self.system_message = system_message
- self.name = default_model
- self.write_log = write_log
+ required_unique_id = True
+ required_write_log = True
async def start(self):
+ # TODO put ClientSession here
pass
async def stop(self):
pass
@property
+ def name(self):
+ return self.model
+
+ @property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
def get_headers(self):
# headers with unique id
@@ -103,7 +104,7 @@ class ProxyServer(LLM):
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: