Diffstat (limited to 'continuedev/src/continuedev/libs/llm/openai.py')
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py | 47
1 file changed, 36 insertions(+), 11 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 33d10985..64bb39a2 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,10 +1,11 @@
 from functools import cached_property
-from typing import Any, Coroutine, Dict, Generator, List, Union
+import json
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Union

 from ...core.main import ChatMessage
 import openai

 from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
 from ...core.config import AzureInfo

@@ -12,11 +13,12 @@ class OpenAI(LLM):
     api_key: str
     default_model: str

-    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):
+    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message
         self.azure_info = azure_info
+        self.write_log = write_log

         openai.api_key = api_key
@@ -46,18 +48,29 @@ class OpenAI(LLM):
         args["stream"] = True

         if args["model"] in CHAT_MODELS:
+            messages = compile_chat_messages(
+                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+            self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+            completion = ""
             async for chunk in await openai.ChatCompletion.acreate(
-                messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+                messages=messages,
                 **args,
             ):
                 if "content" in chunk.choices[0].delta:
                     yield chunk.choices[0].delta.content
+                    completion += chunk.choices[0].delta.content
                 else:
                     continue
+
+            self.write_log(f"Completion: \n\n{completion}")
         else:
+            self.write_log(f"Prompt:\n\n{prompt}")
+            completion = ""
             async for chunk in await openai.Completion.acreate(prompt=prompt, **args):
                 yield chunk.choices[0].text
+                completion += chunk.choices[0].text
+
+            self.write_log(f"Completion:\n\n{completion}")

     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
@@ -67,27 +80,39 @@ class OpenAI(LLM):
         if not args["model"].endswith("0613") and "functions" in args:
             del args["functions"]

+        messages = compile_chat_messages(
+            args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        completion = ""
         async for chunk in await openai.ChatCompletion.acreate(
-            messages=compile_chat_messages(
-                args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
+            messages=messages,
             **args,
         ):
             yield chunk.choices[0].delta
+            if "content" in chunk.choices[0].delta:
+                completion += chunk.choices[0].delta.content
+        self.write_log(f"Completion: \n\n{completion}")

     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}

         if args["model"] in CHAT_MODELS:
+            messages = compile_chat_messages(
+                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+            self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
             resp = (await openai.ChatCompletion.acreate(
-                messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+                messages=messages,
                 **args,
             )).choices[0].message.content
+            self.write_log(f"Completion: \n\n{resp}")
         else:
+            prompt = prune_raw_prompt_from_top(
+                args["model"], prompt, args["max_tokens"])
+            self.write_log(f"Prompt:\n\n{prompt}")
             resp = (await openai.Completion.acreate(
-                prompt=prune_raw_prompt_from_top(
-                    args["model"], prompt, args["max_tokens"]),
+                prompt=prompt,
                 **args,
             )).choices[0].text
+            self.write_log(f"Completion:\n\n{resp}")

         return resp
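
Below is a minimal usage sketch of the write_log hook this commit threads through the class. It is illustrative only: the import path is inferred from the file's location in the repo, and the model name, log file path, and prompt string are assumptions, not anything shown in the diff.

# Hedged sketch: wiring a simple file-appending logger into the updated
# OpenAI class. Import path, model name, and log path are assumptions.
import asyncio

from continuedev.src.continuedev.libs.llm.openai import OpenAI


def write_log(message: str) -> None:
    # Called by the wrapper with the formatted prompt or completion text.
    with open("llm.log", "a") as f:
        f.write(message + "\n")


async def main() -> None:
    llm = OpenAI(api_key="sk-...", default_model="gpt-3.5-turbo",
                 write_log=write_log)
    # complete() logs the compiled prompt before the request and the
    # completion after it, per the + lines above.
    print(await llm.complete("Write a one-line greeting."))


asyncio.run(main())

Note that the logging calls added to complete, stream_complete, and stream_chat are unconditional, so although write_log defaults to None in __init__, callers of these methods effectively need to supply a callable.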