summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-17 13:30:47 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-17 13:30:47 -0700
commit6b708464d76a92e12dac40081cc51ba35f7fc0d0 (patch)
tree466125a30bc706d5f06c591f680949f98fd6da57 /continuedev
parent6d5c07240763b985a32bdb554d600c7423698497 (diff)
downloadsncontinue-6b708464d76a92e12dac40081cc51ba35f7fc0d0.tar.gz
sncontinue-6b708464d76a92e12dac40081cc51ba35f7fc0d0.tar.bz2
sncontinue-6b708464d76a92e12dac40081cc51ba35f7fc0d0.zip
support for azure endpoints
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py2
-rw-r--r--continuedev/src/continuedev/core/config.py8
-rw-r--r--continuedev/src/continuedev/core/sdk.py2
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py19
4 files changed, 24 insertions, 7 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index fb8da2e8..4e177ac9 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -36,7 +36,7 @@ def get_error_title(e: Exception) -> str:
elif isinstance(e, openai_errors.APIConnectionError):
return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
elif isinstance(e, openai_errors.InvalidRequestError):
- return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'
+ return 'Invalid request sent to OpenAI. Please try again.'
elif e.__str__().startswith("Cannot connect to host"):
return "The request failed. Please check your internet connection and try again."
return e.__str__() or e.__repr__()
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 91a47c8e..98615c64 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -67,13 +67,18 @@ DEFAULT_SLASH_COMMANDS = [
]
+class AzureInfo(BaseModel):
+ endpoint: str
+ engine: str
+ api_version: str
+
+
class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
"""
steps_on_startup: Optional[Dict[str, Dict]] = {}
disallowed_steps: Optional[List[str]] = []
- server_url: Optional[str] = None
allow_anonymous_telemetry: Optional[bool] = True
default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
"gpt-4", "ggml"] = 'gpt-4'
@@ -86,6 +91,7 @@ class ContinueConfig(BaseModel):
on_traceback: Optional[List[OnTracebackSteps]] = [
OnTracebackSteps(step_name="DefaultOnTracebackStep")]
system_message: Optional[str] = None
+ azure_openai_info: Optional[AzureInfo] = None
# Want to force these to be the slash commands for now
@validator('slash_commands', pre=True)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index ac57c122..7e612d3b 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -56,7 +56,7 @@ class Models:
api_key = self.provider_keys["openai"]
if api_key == "":
return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message)
+ return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info)
def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
api_key = self.provider_keys["hf_inference_api"]
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index d973f19e..33d10985 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,30 +1,41 @@
from functools import cached_property
-import time
from typing import Any, Coroutine, Dict, Generator, List, Union
+
from ...core.main import ChatMessage
import openai
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ...core.config import AzureInfo
class OpenAI(LLM):
api_key: str
default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
+ def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
+ self.azure_info = azure_info
openai.api_key = api_key
+ # Using an Azure OpenAI deployment
+ if azure_info is not None:
+ openai.api_type = "azure"
+ openai.api_base = azure_info.endpoint
+ openai.api_version = azure_info.api_version
+
@cached_property
def name(self):
return self.default_model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.default_model}
+ if self.azure_info is not None:
+ args["engine"] = self.azure_info.engine
+ return args
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)