Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/main.py               1
-rw-r--r--  continuedev/src/continuedev/core/sdk.py               28
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py        47
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  33
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py  7
-rw-r--r--  continuedev/src/continuedev/server/gui.py              9
-rw-r--r--  continuedev/src/continuedev/server/ide.py             14
-rw-r--r--  continuedev/src/continuedev/server/ide_protocol.py     4
-rw-r--r--  continuedev/src/continuedev/server/session_manager.py  2
9 files changed, 121 insertions, 24 deletions
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 88690c83..5931d978 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -102,6 +102,7 @@ class HistoryNode(ContinueBaseModel):
depth: int
deleted: bool = False
active: bool = True
+ logs: List[str] = []
def to_chat_messages(self) -> List[ChatMessage]:
if self.step.description is None or self.step.manage_own_chat_context:
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 280fefa8..53214384 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -37,6 +37,25 @@ class Models:
model_providers: List[ModelProvider]
system_message: str
+ """
+ Better to have sdk.llm.stream_chat(messages, model="claude-2").
+ Then you also don't care that it's async.
+ And it's easier to add more models.
+ And intermediate shared code is easier to add.
+ And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+ PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+ Easy to reason about, can place anywhere.
+ And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+ This can all happen inside of Models?
+
+ class Prompt:
+ def __init__(self, ...info):
+ '''take whatever info is needed to describe the prompt'''
+
+ def to_string(self, model: str) -> str:
+ '''depending on the model, return the single prompt string'''
+ """
+
def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
self.sdk = sdk
self.model_providers = model_providers
@@ -59,8 +78,8 @@ class Models:
def __load_openai_model(self, model: str) -> OpenAI:
api_key = self.provider_keys["openai"]
if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info)
+ return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
+ return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log)
def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
api_key = self.provider_keys["hf_inference_api"]
@@ -156,6 +175,9 @@ class ContinueSDK(AbstractContinueSDK):
def history(self) -> History:
return self.__autopilot.history
+ def write_log(self, message: str):
+ self.history.timeline[self.history.current_index].logs.append(message)
+
async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
return path
@@ -263,7 +285,7 @@ class ContinueSDK(AbstractContinueSDK):
for rif in highlighted_code:
msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",
- role="system", summary=f"{preface}: {rif.filepath}")
+ role="user", summary=f"{preface}: {rif.filepath}")
# Don't insert after latest user message or function call
i = -1
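The docstring added above is a design note rather than working code. A minimal sketch of what the proposed Prompt abstraction could look like, assuming the hypothetical names from the note (ContinueModels, to_string) and made-up per-model formatting rules:

    # Sketch only: ContinueModels, the model strings, and the formatting
    # branches below are illustrative, not part of this codebase.
    class ContinueModels:
        GPT35 = "gpt-3.5-turbo"
        CLAUDE2 = "claude-2"

    class Prompt:
        """Take whatever info is needed to describe the prompt."""

        def __init__(self, system_message: str, user_input: str):
            self.system_message = system_message
            self.user_input = user_input

        def to_string(self, model: str) -> str:
            """Depending on the model, return the single prompt string."""
            if model == ContinueModels.CLAUDE2:
                # Anthropic-style turn markers
                return f"{self.system_message}\n\nHuman: {self.user_input}\n\nAssistant:"
            # Plain concatenation for completion-style models
            return f"{self.system_message}\n\n{self.user_input}"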
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 33d10985..64bb39a2 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,10 +1,11 @@
from functools import cached_property
-from typing import Any, Coroutine, Dict, Generator, List, Union
+import json
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
from ...core.main import ChatMessage
import openai
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
from ...core.config import AzureInfo
@@ -12,11 +13,12 @@ class OpenAI(LLM):
api_key: str
default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):
+ def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
self.azure_info = azure_info
+ self.write_log = write_log
openai.api_key = api_key
@@ -46,18 +48,29 @@ class OpenAI(LLM):
args["stream"] = True
if args["model"] in CHAT_MODELS:
+ messages = compile_chat_messages(
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ completion = ""
async for chunk in await openai.ChatCompletion.acreate(
- messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ messages=messages,
**args,
):
if "content" in chunk.choices[0].delta:
yield chunk.choices[0].delta.content
+ completion += chunk.choices[0].delta.content
else:
continue
+
+ self.write_log(f"Completion: \n\n{completion}")
else:
+ self.write_log(f"Prompt:\n\n{prompt}")
+ completion = ""
async for chunk in await openai.Completion.acreate(prompt=prompt, **args):
yield chunk.choices[0].text
+ completion += chunk.choices[0].text
+
+ self.write_log(f"Completion:\n\n{completion}")
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -67,27 +80,39 @@ class OpenAI(LLM):
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
+ messages = compile_chat_messages(
+ args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ completion = ""
async for chunk in await openai.ChatCompletion.acreate(
- messages=compile_chat_messages(
- args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
+ messages=messages,
**args,
):
yield chunk.choices[0].delta
+ if "content" in chunk.choices[0].delta:
+ completion += chunk.choices[0].delta.content
+ self.write_log(f"Completion: \n\n{completion}")
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
if args["model"] in CHAT_MODELS:
+ messages = compile_chat_messages(
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = (await openai.ChatCompletion.acreate(
- messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ messages=messages,
**args,
)).choices[0].message.content
+ self.write_log(f"Completion: \n\n{resp}")
else:
+ prompt = prune_raw_prompt_from_top(
+ args["model"], prompt, args["max_tokens"])
+ self.write_log(f"Prompt:\n\n{prompt}")
resp = (await openai.Completion.acreate(
- prompt=prune_raw_prompt_from_top(
- args["model"], prompt, args["max_tokens"]),
+ prompt=prompt,
**args,
)).choices[0].text
+ self.write_log(f"Completion:\n\n{resp}")
return resp
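Every method in this file now follows the same shape: compile the messages, log the prompt, accumulate the streamed chunks, then log the full completion. A condensed sketch of that pattern, with a None guard added since write_log defaults to None in the constructor (the guard is an assumption on my part, not in the diff):

    from typing import AsyncIterator, Callable, Optional

    async def stream_with_logging(
        stream: AsyncIterator[str],
        prompt: str,
        write_log: Optional[Callable[[str], None]] = None,
    ) -> AsyncIterator[str]:
        # Log the prompt up front, then replay chunks while accumulating them
        if write_log is not None:
            write_log(f"Prompt: \n\n{prompt}")
        completion = ""
        async for chunk in stream:
            yield chunk
            completion += chunk
        if write_log is not None:
            write_log(f"Completion: \n\n{completion}")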
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 3ec492f3..91b5842a 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,10 +1,11 @@
+
from functools import cached_property
import json
-from typing import Any, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens
+from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens, format_chat_messages
import certifi
import ssl
@@ -19,12 +20,14 @@ class ProxyServer(LLM):
unique_id: str
name: str
default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+ write_log: Callable[[str], None]
- def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):
+ def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
self.unique_id = unique_id
self.default_model = default_model
self.system_message = system_message
self.name = default_model
+ self.write_log = write_log
@property
def default_args(self):
@@ -36,14 +39,19 @@ class ProxyServer(LLM):
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/complete", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ "messages": messages,
"unique_id": self.unique_id,
**args
}) as resp:
try:
- return await resp.text()
+ response_text = await resp.text()
+ self.write_log(f"Completion: \n\n{response_text}")
+ return response_text
except:
raise Exception(await resp.text())
@@ -51,6 +59,7 @@ class ProxyServer(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -59,6 +68,7 @@ class ProxyServer(LLM):
**args
}) as resp:
# This is streaming application/json instead of text/event-stream
+ completion = ""
async for line in resp.content.iter_chunks():
if line[1]:
try:
@@ -67,14 +77,19 @@ class ProxyServer(LLM):
chunks = json_chunk.split("\n")
for chunk in chunks:
if chunk.strip() != "":
- yield json.loads(chunk)
+ loaded_chunk = json.loads(chunk)
+ yield loaded_chunk
+ if "content" in loaded_chunk:
+ completion += loaded_chunk["content"]
except:
raise Exception(str(line[0]))
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
@@ -82,9 +97,13 @@ class ProxyServer(LLM):
"unique_id": self.unique_id,
**args
}) as resp:
+ completion = ""
async for line in resp.content.iter_any():
if line:
try:
- yield line.decode("utf-8")
+ decoded_line = line.decode("utf-8")
+ yield decoded_line
+ completion += decoded_line
except:
raise Exception(str(line))
+ self.write_log(f"Completion: \n\n{completion}")
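As the comment in the /stream_chat handler notes, the proxy streams newline-delimited JSON rather than text/event-stream; the parsing and accumulation above boil down to this (the chunk data is invented for illustration):

    import json

    # Illustrative newline-delimited JSON, shaped like the proxy's chunks
    raw = '{"role": "assistant"}\n{"content": "Hello"}\n{"content": " world"}'

    completion = ""
    for chunk in raw.split("\n"):
        if chunk.strip() != "":
            loaded_chunk = json.loads(chunk)
            if "content" in loaded_chunk:
                completion += loaded_chunk["content"]

    print(completion)  # Hello world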
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 1d5d6729..13de7990 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -107,3 +107,10 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
})
return history
+
+
+def format_chat_messages(messages: List[ChatMessage]) -> str:
+ formatted = ""
+ for msg in messages:
+ formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n"
+ return formatted
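For reference, format_chat_messages renders each message as a role-tagged block, which is what ends up in the prompt half of each log entry:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
    ]
    print(format_chat_messages(messages))
    # <System>
    # You are a helpful assistant.
    #
    # <User>
    # Hi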
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 4201353e..ae57c0b6 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -99,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_set_editing_at_indices(data["indices"])
elif message_type == "set_pinned_at_indices":
self.on_set_pinned_at_indices(data["indices"])
+ elif message_type == "show_logs_at_index":
+ self.on_show_logs_at_index(data["index"])
except Exception as e:
print(e)
@@ -166,6 +168,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
indices), self.session.autopilot.continue_sdk.ide.unique_id
)
+ def on_show_logs_at_index(self, index: int):
+ name = "continue_logs.txt"
+ logs = "\n\n############################################\n\n".join(
+ ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
+ create_async_task(
+ self.session.autopilot.ide.showVirtualFile(name, logs))
+
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
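Concretely, on_show_logs_at_index builds the contents of continue_logs.txt by joining the header line with each step's logged prompt/completion pairs (the sample log entries below are invented):

    SEPARATOR = "\n\n############################################\n\n"
    logs = [
        "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step",
        "Prompt: \n\n<User>\nHi",    # invented sample entry
        "Completion: \n\nHello!",    # invented sample entry
    ]
    contents = SEPARATOR.join(logs)
    # `contents` is then shown in the IDE via showVirtualFile("continue_logs.txt", contents)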
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 43538407..aeff5623 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -224,6 +224,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"open": open
})
+ async def showVirtualFile(self, name: str, contents: str):
+ await self._send_json("showVirtualFile", {
+ "name": name,
+ "contents": contents
+ })
+
async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
# Lock suggestions in the file so they don't ruin the offset before others are inserted
await self._send_json("setSuggestionsLocked", {
@@ -288,6 +294,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
pass
def __get_autopilot(self):
+ if self.session_id not in self.session_manager.sessions:
+ return None
return self.session_manager.sessions[self.session_id].autopilot
def onFileEdits(self, edits: List[FileEditWithFullContents]):
@@ -442,7 +450,8 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
if session_id is not None:
session_manager.registered_ides[session_id] = ideProtocolServer
other_msgs = await ideProtocolServer.initialize(session_id)
- capture_event(ideProtocolServer.unique_id, "session_started", { "session_id": ideProtocolServer.session_id })
+ capture_event(ideProtocolServer.unique_id, "session_started", {
+ "session_id": ideProtocolServer.session_id})
for other_msg in other_msgs:
handle_msg(other_msg)
@@ -463,5 +472,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
- capture_event(ideProtocolServer.unique_id, "session_ended", { "session_id": ideProtocolServer.session_id })
+ capture_event(ideProtocolServer.unique_id, "session_ended", {
+ "session_id": ideProtocolServer.session_id})
session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index d0fb0bf8..0ae7e7fa 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -24,6 +24,10 @@ class AbstractIdeProtocolServer(ABC):
"""Set whether a file is open"""
@abstractmethod
+ async def showVirtualFile(self, name: str, contents: str):
+ """Show a virtual file"""
+
+ @abstractmethod
async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
"""Set whether suggestions are locked"""
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 6d109ca6..90172a4e 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -100,7 +100,7 @@ class SessionManager:
if session_id not in self.sessions:
raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
- print(f"Session {session_id} has no websocket")
+ # print(f"Session {session_id} has no websocket")
return
await self.sessions[session_id].ws.send_json({