summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-28 17:06:38 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-28 17:06:38 -0700
commitcb0c815ad799050ecc0abdf3d15981e9832b9829 (patch)
treeffa35d2bb595fbfa59e4ae03886d46bc048dd0a2 /continuedev
parent99ece78c8640495fbdabd95d30c26d620045b0ec (diff)
downloadsncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.gz
sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.bz2
sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.zip
feat: :sparkles: allow custom OpenAI base_url
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/config.py11
-rw-r--r--continuedev/src/continuedev/core/sdk.py2
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py20
3 files changed, 18 insertions, 15 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index cb9c8977..e367e06c 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -25,10 +25,11 @@ class OnTracebackSteps(BaseModel):
params: Optional[Dict] = {}
-class AzureInfo(BaseModel):
- endpoint: str
- engine: str
- api_version: str
+class OpenAIServerInfo(BaseModel):
+ api_base: Optional[str] = None
+ engine: Optional[str] = None
+ api_version: Optional[str] = None
+ api_type: Literal["azure", "openai"] = "openai"
class ContinueConfig(BaseModel):
@@ -49,7 +50,7 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- azure_openai_info: Optional[AzureInfo] = None
+ openai_server_info: Optional[OpenAIServerInfo] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d75aac00..9ee9ea06 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -81,7 +81,7 @@ class Models:
api_key = self.provider_keys["openai"]
if api_key == "":
return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log)
+ return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log)
def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
api_key = self.provider_keys["hf_inference_api"]
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a0773c1d..654c7326 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -6,27 +6,29 @@ from ...core.main import ChatMessage
import openai
from ..llm import LLM
from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
-from ...core.config import AzureInfo
+from ...core.config import OpenAIServerInfo
class OpenAI(LLM):
api_key: str
default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
+ def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
- self.azure_info = azure_info
+ self.openai_server_info = openai_server_info
self.write_log = write_log
openai.api_key = api_key
# Using an Azure OpenAI deployment
- if azure_info is not None:
- openai.api_type = "azure"
- openai.api_base = azure_info.endpoint
- openai.api_version = azure_info.api_version
+ if openai_server_info is not None:
+ openai.api_type = openai_server_info.api_type
+ if openai_server_info.api_base is not None:
+ openai.api_base = openai_server_info.api_base
+ if openai_server_info.api_version is not None:
+ openai.api_version = openai_server_info.api_version
@cached_property
def name(self):
@@ -35,8 +37,8 @@ class OpenAI(LLM):
@property
def default_args(self):
args = {**DEFAULT_ARGS, "model": self.default_model}
- if self.azure_info is not None:
- args["engine"] = self.azure_info.engine
+ if self.openai_server_info is not None:
+ args["engine"] = self.openai_server_info.engine
return args
def count_tokens(self, text: str):