diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-17 13:33:29 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-17 13:33:29 -0700 |
commit | 05d665e65aaef62254a4da9a7a381f9984ff0db5 (patch) | |
tree | 9b5c08baa5c7c1da051e4109ae34fb8a141c2754 /continuedev/src/continuedev/libs/llm/openai.py | |
parent | 868e0b7ef5357b89186119c3c2fa8bd427b8db30 (diff) | |
parent | 6e95cb64cd5b2e2d55200bf979106f18d395bb97 (diff) | |
download | sncontinue-05d665e65aaef62254a4da9a7a381f9984ff0db5.tar.gz sncontinue-05d665e65aaef62254a4da9a7a381f9984ff0db5.tar.bz2 sncontinue-05d665e65aaef62254a4da9a7a381f9984ff0db5.zip |
Merge branch 'main' of https://github.com/continuedev/continue into anthropic
Diffstat (limited to 'continuedev/src/continuedev/libs/llm/openai.py')
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 25 |
1 file changed, 18 insertions, 7 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index f0877d90..33d10985 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,30 +1,41 @@ from functools import cached_property -import time from typing import Any, Coroutine, Dict, Generator, List, Union + from ...core.main import ChatMessage import openai from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ...core.config import AzureInfo class OpenAI(LLM): api_key: str default_model: str - def __init__(self, api_key: str, default_model: str, system_message: str = None): + def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None): self.api_key = api_key self.default_model = default_model self.system_message = system_message + self.azure_info = azure_info openai.api_key = api_key + # Using an Azure OpenAI deployment + if azure_info is not None: + openai.api_type = "azure" + openai.api_base = azure_info.endpoint + openai.api_version = azure_info.api_version + @cached_property def name(self): return self.default_model @property def default_args(self): - return {**DEFAULT_ARGS, "model": self.default_model} + args = {**DEFAULT_ARGS, "model": self.default_model} + if self.azure_info is not None: + args["engine"] = self.azure_info.engine + return args def count_tokens(self, text: str): return count_tokens(self.default_model, text) @@ -37,7 +48,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: async for chunk in await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None), + args["model"], with_history, args["max_tokens"], 
prompt, functions=None, system_message=self.system_message), **args, ): if "content" in chunk.choices[0].delta: @@ -58,7 +69,7 @@ class OpenAI(LLM): async for chunk in await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], messages, args["max_tokens"], functions=args.get("functions", None)), + args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message), **args, ): yield chunk.choices[0].delta @@ -69,7 +80,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: resp = (await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None), + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), **args, )).choices[0].message.content else: |