diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-28 17:06:38 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-28 17:06:38 -0700 |
commit | cb0c815ad799050ecc0abdf3d15981e9832b9829 (patch) | |
tree | ffa35d2bb595fbfa59e4ae03886d46bc048dd0a2 /continuedev | |
parent | 99ece78c8640495fbdabd95d30c26d620045b0ec (diff) | |
download | sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.gz sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.bz2 sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.zip |
feat: :sparkles: allow custom OpenAI base_url
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 20 |
3 files changed, 18 insertions, 15 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index cb9c8977..e367e06c 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -25,10 +25,11 @@ class OnTracebackSteps(BaseModel): params: Optional[Dict] = {} -class AzureInfo(BaseModel): - endpoint: str - engine: str - api_version: str +class OpenAIServerInfo(BaseModel): + api_base: Optional[str] = None + engine: Optional[str] = None + api_version: Optional[str] = None + api_type: Literal["azure", "openai"] = "openai" class ContinueConfig(BaseModel): @@ -49,7 +50,7 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = [] on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None - azure_openai_info: Optional[AzureInfo] = None + openai_server_info: Optional[OpenAIServerInfo] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index d75aac00..9ee9ea06 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -81,7 +81,7 @@ class Models: api_key = self.provider_keys["openai"] if api_key == "": return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) - return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log) + return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log) def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: api_key = self.provider_keys["hf_inference_api"] diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a0773c1d..654c7326 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -6,27 +6,29 @@ from ...core.main import ChatMessage import openai from ..llm import LLM from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top -from ...core.config import AzureInfo +from ...core.config import OpenAIServerInfo class OpenAI(LLM): api_key: str default_model: str - def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): + def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None): self.api_key = api_key self.default_model = default_model self.system_message = system_message - self.azure_info = azure_info + self.openai_server_info = openai_server_info self.write_log = write_log openai.api_key = api_key # Using an Azure OpenAI deployment - if azure_info is not None: - openai.api_type = "azure" - openai.api_base = azure_info.endpoint - openai.api_version = azure_info.api_version + if openai_server_info is not None: + openai.api_type = openai_server_info.api_type + if openai_server_info.api_base is not None: + openai.api_base = openai_server_info.api_base + if openai_server_info.api_version is not None: + openai.api_version = openai_server_info.api_version @cached_property def name(self): @@ -35,8 +37,8 @@ class OpenAI(LLM): @property def default_args(self): args = {**DEFAULT_ARGS, "model": self.default_model} - if self.azure_info is not None: - args["engine"] = self.azure_info.engine + if self.openai_server_info is not None: + args["engine"] = self.openai_server_info.engine return args def count_tokens(self, text: str): |