Diffstat (limited to 'continuedev')
 continuedev/src/continuedev/core/models.py           |  4 ++++
 continuedev/src/continuedev/libs/llm/ggml.py         |  1 +
 continuedev/src/continuedev/libs/llm/openai.py       | 25 +++++++++++++------------
 continuedev/src/continuedev/libs/llm/proxy_server.py | 25 +++++++++++++------------
4 files changed, 31 insertions(+), 24 deletions(-)
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 8b1b1f00..ec89d503 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -38,6 +38,10 @@ class Models(BaseModel):
         kwargs = {}
         if llm.required_api_key:
             kwargs["api_key"] = await self.sdk.get_api_key(llm.required_api_key)
+        if llm.required_unique_id:
+            kwargs["unique_id"] = self.sdk.ide.unique_id
+        if llm.required_write_log:
+            kwargs["write_log"] = self.sdk.write_log
         await llm.start(**kwargs)
 
     async def start(self, sdk: "ContinueSDK"):
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 52e44bfe..401709c9 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -16,6 +16,7 @@ class GGML(LLM):
     def __init__(self, system_message: str = None):
         self.system_message = system_message
 
+    @property
     def name(self):
         return "ggml"
 
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index ef8830a6..5ac4d211 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,6 +1,6 @@
 from functools import cached_property
 import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Union, Optional
 
 from pydantic import BaseModel
 from ...core.main import ChatMessage
@@ -16,14 +16,13 @@ class AzureInfo(BaseModel):
 
 
 class OpenAI(LLM):
+    model: str
+    system_message: Optional[str] = None
+    azure_info: Optional[AzureInfo] = None
+    write_log: Optional[Callable[[str], None]] = None
+
     required_api_key = "OPENAI_API_KEY"
-    default_model: str
-
-    def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
-        self.default_model = default_model
-        self.system_message = system_message
-        self.azure_info = azure_info
-        self.write_log = write_log
+    required_write_log = True
 
     async def start(self, *, api_key):
         self.api_key = api_key
@@ -38,18 +37,19 @@ class OpenAI(LLM):
     async def stop(self):
         pass
 
+    @property
     def name(self):
-        return self.default_model
+        return self.model
 
     @property
     def default_args(self):
-        args = {**DEFAULT_ARGS, "model": self.default_model}
+        args = {**DEFAULT_ARGS, "model": self.model}
         if self.azure_info is not None:
             args["engine"] = self.azure_info.engine
         return args
 
     def count_tokens(self, text: str):
-        return count_tokens(self.default_model, text)
+        return count_tokens(self.model, text)
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
@@ -85,7 +85,8 @@ class OpenAI(LLM):
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
-        args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+        # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
+        args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
         if not args["model"].endswith("0613") and "functions" in args:
             del args["functions"]
 
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index c0e2a403..2c0e1dc4 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,7 +1,7 @@
 import json
 import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
 
 import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
@@ -19,29 +19,30 @@ SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
 
 
 class ProxyServer(LLM):
     unique_id: str
-    name: str
-    default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+    model: str
+    system_message: Optional[str]
     write_log: Callable[[str], None]
 
-    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
-        self.unique_id = unique_id
-        self.default_model = default_model
-        self.system_message = system_message
-        self.name = default_model
-        self.write_log = write_log
+    required_unique_id = True
+    required_write_log = True
 
     async def start(self):
+        # TODO put ClientSession here
         pass
 
     async def stop(self):
         pass
 
     @property
+    def name(self):
+        return self.model
+
+    @property
     def default_args(self):
-        return {**DEFAULT_ARGS, "model": self.default_model}
+        return {**DEFAULT_ARGS, "model": self.model}
 
     def count_tokens(self, text: str):
-        return count_tokens(self.default_model, text)
+        return count_tokens(self.model, text)
 
     def get_headers(self):
         # headers with unique id
@@ -103,7 +104,7 @@ class ProxyServer(LLM):
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+            self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: