summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
author    Nate Sesti <sestinj@gmail.com>    2023-07-28 17:06:38 -0700
committer Nate Sesti <sestinj@gmail.com>    2023-07-28 17:06:38 -0700
commitcb0c815ad799050ecc0abdf3d15981e9832b9829 (patch)
treeffa35d2bb595fbfa59e4ae03886d46bc048dd0a2 /continuedev/src/continuedev/core
parent99ece78c8640495fbdabd95d30c26d620045b0ec (diff)
download    sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.gz
sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.bz2
sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.zip
feat: :sparkles: allow custom OpenAI base_url
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--    continuedev/src/continuedev/core/config.py    | 11
-rw-r--r--    continuedev/src/continuedev/core/sdk.py       | 2
2 files changed, 7 insertions, 6 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index cb9c8977..e367e06c 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -25,10 +25,11 @@ class OnTracebackSteps(BaseModel):
params: Optional[Dict] = {}
-class AzureInfo(BaseModel):
- endpoint: str
- engine: str
- api_version: str
+class OpenAIServerInfo(BaseModel):
+ api_base: Optional[str] = None
+ engine: Optional[str] = None
+ api_version: Optional[str] = None
+ api_type: Literal["azure", "openai"] = "openai"
class ContinueConfig(BaseModel):
@@ -49,7 +50,7 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- azure_openai_info: Optional[AzureInfo] = None
+ openai_server_info: Optional[OpenAIServerInfo] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d75aac00..9ee9ea06 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -81,7 +81,7 @@ class Models:
api_key = self.provider_keys["openai"]
if api_key == "":
return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log)
+ return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log)
def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
api_key = self.provider_keys["hf_inference_api"]