diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-07-28 17:06:38 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-28 17:06:38 -0700 | 
| commit | cb0c815ad799050ecc0abdf3d15981e9832b9829 (patch) | |
| tree | ffa35d2bb595fbfa59e4ae03886d46bc048dd0a2 /continuedev/src | |
| parent | 99ece78c8640495fbdabd95d30c26d620045b0ec (diff) | |
| download | sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.gz sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.tar.bz2 sncontinue-cb0c815ad799050ecc0abdf3d15981e9832b9829.zip | |
feat: :sparkles: allow custom OpenAI base_url
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 11 |
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 2 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 20 |
3 files changed, 18 insertions, 15 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index cb9c8977..e367e06c 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -25,10 +25,11 @@ class OnTracebackSteps(BaseModel):
     params: Optional[Dict] = {}
 
 
-class AzureInfo(BaseModel):
-    endpoint: str
-    engine: str
-    api_version: str
+class OpenAIServerInfo(BaseModel):
+    api_base: Optional[str] = None
+    engine: Optional[str] = None
+    api_version: Optional[str] = None
+    api_type: Literal["azure", "openai"] = "openai"
 
 
 class ContinueConfig(BaseModel):
@@ -49,7 +50,7 @@ class ContinueConfig(BaseModel):
     slash_commands: Optional[List[SlashCommand]] = []
     on_traceback: Optional[List[OnTracebackSteps]] = []
     system_message: Optional[str] = None
-    azure_openai_info: Optional[AzureInfo] = None
+    openai_server_info: Optional[OpenAIServerInfo] = None
     context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d75aac00..9ee9ea06 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -81,7 +81,7 @@ class Models:
         api_key = self.provider_keys["openai"]
         if api_key == "":
             return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
-        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log)
+        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log)
 
     def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
         api_key = self.provider_keys["hf_inference_api"]
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a0773c1d..654c7326 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -6,27 +6,29 @@ from ...core.main import ChatMessage
 import openai
 from ..llm import LLM
 from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
-from ...core.config import AzureInfo
+from ...core.config import OpenAIServerInfo
 
 
 class OpenAI(LLM):
     api_key: str
     default_model: str
 
-    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
+    def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message
-        self.azure_info = azure_info
+        self.openai_server_info = openai_server_info
         self.write_log = write_log
 
         openai.api_key = api_key
 
         # Using an Azure OpenAI deployment
-        if azure_info is not None:
-            openai.api_type = "azure"
-            openai.api_base = azure_info.endpoint
-            openai.api_version = azure_info.api_version
+        if openai_server_info is not None:
+            openai.api_type = openai_server_info.api_type
+            if openai_server_info.api_base is not None:
+                openai.api_base = openai_server_info.api_base
+            if openai_server_info.api_version is not None:
+                openai.api_version = openai_server_info.api_version
 
     @cached_property
     def name(self):
@@ -35,8 +37,8 @@ class OpenAI(LLM):
     @property
     def default_args(self):
         args = {**DEFAULT_ARGS, "model": self.default_model}
-        if self.azure_info is not None:
-            args["engine"] = self.azure_info.engine
+        if self.openai_server_info is not None:
+            args["engine"] = self.openai_server_info.engine
         return args
 
     def count_tokens(self, text: str):
