summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/libs/llm/openai.py
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-17 13:33:29 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-17 13:33:29 -0700
commit05d665e65aaef62254a4da9a7a381f9984ff0db5 (patch)
tree9b5c08baa5c7c1da051e4109ae34fb8a141c2754 /continuedev/src/continuedev/libs/llm/openai.py
parent868e0b7ef5357b89186119c3c2fa8bd427b8db30 (diff)
parent6e95cb64cd5b2e2d55200bf979106f18d395bb97 (diff)
downloadsncontinue-05d665e65aaef62254a4da9a7a381f9984ff0db5.tar.gz
sncontinue-05d665e65aaef62254a4da9a7a381f9984ff0db5.tar.bz2
sncontinue-05d665e65aaef62254a4da9a7a381f9984ff0db5.zip
Merge branch 'main' of https://github.com/continuedev/continue into anthropic
Diffstat (limited to 'continuedev/src/continuedev/libs/llm/openai.py')
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py25
1 files changed, 18 insertions, 7 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index f0877d90..33d10985 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,30 +1,41 @@
from functools import cached_property
-import time
from typing import Any, Coroutine, Dict, Generator, List, Union
+
from ...core.main import ChatMessage
import openai
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ...core.config import AzureInfo
class OpenAI(LLM):
api_key: str
default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
+ def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
+ self.azure_info = azure_info
openai.api_key = api_key
+ # Using an Azure OpenAI deployment
+ if azure_info is not None:
+ openai.api_type = "azure"
+ openai.api_base = azure_info.endpoint
+ openai.api_version = azure_info.api_version
+
@cached_property
def name(self):
return self.default_model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.default_model}
+ if self.azure_info is not None:
+ args["engine"] = self.azure_info.engine
+ return args
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)
@@ -37,7 +48,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -58,7 +69,7 @@ class OpenAI(LLM):
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
+ args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
**args,
):
yield chunk.choices[0].delta
@@ -69,7 +80,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args,
)).choices[0].message.content
else: