summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/libs
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src/continuedev/libs')
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py47
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py33
-rw-r--r--continuedev/src/continuedev/libs/util/count_tokens.py7
3 files changed, 69 insertions, 18 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 33d10985..64bb39a2 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,10 +1,11 @@
from functools import cached_property
-from typing import Any, Coroutine, Dict, Generator, List, Union
+import json
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
from ...core.main import ChatMessage
import openai
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
from ...core.config import AzureInfo
@@ -12,11 +13,12 @@ class OpenAI(LLM):
api_key: str
default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):
+ def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
self.azure_info = azure_info
+ self.write_log = write_log
openai.api_key = api_key
@@ -46,18 +48,29 @@ class OpenAI(LLM):
args["stream"] = True
if args["model"] in CHAT_MODELS:
+ messages = compile_chat_messages(
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ completion = ""
async for chunk in await openai.ChatCompletion.acreate(
- messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ messages=messages,
**args,
):
if "content" in chunk.choices[0].delta:
yield chunk.choices[0].delta.content
+ completion += chunk.choices[0].delta.content
else:
continue
+
+ self.write_log(f"Completion: \n\n{completion}")
else:
+ self.write_log(f"Prompt:\n\n{prompt}")
+ completion = ""
async for chunk in await openai.Completion.acreate(prompt=prompt, **args):
yield chunk.choices[0].text
+ completion += chunk.choices[0].text
+
+ self.write_log(f"Completion:\n\n{completion}")
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -67,27 +80,39 @@ class OpenAI(LLM):
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
+ messages = compile_chat_messages(
+ args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ completion = ""
async for chunk in await openai.ChatCompletion.acreate(
- messages=compile_chat_messages(
- args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
+ messages=messages,
**args,
):
yield chunk.choices[0].delta
+ if "content" in chunk.choices[0].delta:
+ completion += chunk.choices[0].delta.content
+ self.write_log(f"Completion: \n\n{completion}")
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
if args["model"] in CHAT_MODELS:
+ messages = compile_chat_messages(
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = (await openai.ChatCompletion.acreate(
- messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ messages=messages,
**args,
)).choices[0].message.content
+ self.write_log(f"Completion: \n\n{resp}")
else:
+ prompt = prune_raw_prompt_from_top(
+ args["model"], prompt, args["max_tokens"])
+ self.write_log(f"Prompt:\n\n{prompt}")
resp = (await openai.Completion.acreate(
- prompt=prune_raw_prompt_from_top(
- args["model"], prompt, args["max_tokens"]),
+ prompt=prompt,
**args,
)).choices[0].text
+ self.write_log(f"Completion:\n\n{resp}")
return resp
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 3ec492f3..91b5842a 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,10 +1,11 @@
+
from functools import cached_property
import json
-from typing import Any, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens
+from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens, format_chat_messages
import certifi
import ssl
@@ -19,12 +20,14 @@ class ProxyServer(LLM):
unique_id: str
name: str
default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+ write_log: Callable[[str], None]
- def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):
+ def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
self.unique_id = unique_id
self.default_model = default_model
self.system_message = system_message
self.name = default_model
+ self.write_log = write_log
@property
def default_args(self):
@@ -36,14 +39,19 @@ class ProxyServer(LLM):
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/complete", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ "messages": messages,
"unique_id": self.unique_id,
**args
}) as resp:
try:
- return await resp.text()
+ response_text = await resp.text()
+ self.write_log(f"Completion: \n\n{response_text}")
+ return response_text
except:
raise Exception(await resp.text())
@@ -51,6 +59,7 @@ class ProxyServer(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -59,6 +68,7 @@ class ProxyServer(LLM):
**args
}) as resp:
# This is streaming application/json instaed of text/event-stream
+ completion = ""
async for line in resp.content.iter_chunks():
if line[1]:
try:
@@ -67,14 +77,19 @@ class ProxyServer(LLM):
chunks = json_chunk.split("\n")
for chunk in chunks:
if chunk.strip() != "":
- yield json.loads(chunk)
+ loaded_chunk = json.loads(chunk)
+ yield loaded_chunk
+ if "content" in loaded_chunk:
+ completion += loaded_chunk["content"]
except:
raise Exception(str(line[0]))
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
@@ -82,9 +97,13 @@ class ProxyServer(LLM):
"unique_id": self.unique_id,
**args
}) as resp:
+ completion = ""
async for line in resp.content.iter_any():
if line:
try:
- yield line.decode("utf-8")
+ decoded_line = line.decode("utf-8")
+ yield decoded_line
+ completion += decoded_line
except:
raise Exception(str(line))
+ self.write_log(f"Completion: \n\n{completion}")
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 1d5d6729..13de7990 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -107,3 +107,10 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
})
return history
+
+
+def format_chat_messages(messages: List[ChatMessage]) -> str:
+ formatted = ""
+ for msg in messages:
+ formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n"
+ return formatted