Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/main.py                1
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                28
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py         47
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py   33
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py   7
-rw-r--r--  continuedev/src/continuedev/server/gui.py               9
-rw-r--r--  continuedev/src/continuedev/server/ide.py              14
-rw-r--r--  continuedev/src/continuedev/server/ide_protocol.py      4
-rw-r--r--  continuedev/src/continuedev/server/session_manager.py   2
9 files changed, 121 insertions, 24 deletions
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 88690c83..5931d978 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -102,6 +102,7 @@ class HistoryNode(ContinueBaseModel):
     depth: int
     deleted: bool = False
     active: bool = True
+    logs: List[str] = []
 
     def to_chat_messages(self) -> List[ChatMessage]:
         if self.step.description is None or self.step.manage_own_chat_context:
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 280fefa8..53214384 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -37,6 +37,25 @@ class Models:
     model_providers: List[ModelProvider]
     system_message: str
 
+    """
+    Better to have sdk.llm.stream_chat(messages, model="claude-2").
+    Then you also don't care that it' async.
+    And it's easier to add more models.
+    And intermediate shared code is easier to add.
+    And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+    PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+    Easy to reason about, can place anywhere.
+    And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+    This can all happen inside of Models?
+
+    class Prompt:
+        def __init__(self, ...info):
+            '''take whatever info is needed to describe the prompt'''
+
+        def to_string(self, model: str) -> str:
+            '''depending on the model, return the single prompt string'''
+    """
+
     def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
         self.sdk = sdk
         self.model_providers = model_providers
@@ -59,8 +78,8 @@ class Models:
     def __load_openai_model(self, model: str) -> OpenAI:
         api_key = self.provider_keys["openai"]
         if api_key == "":
-            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
-        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info)
+            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
+        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log)
 
     def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
         api_key = self.provider_keys["hf_inference_api"]
@@ -156,6 +175,9 @@ class ContinueSDK(AbstractContinueSDK):
     def history(self) -> History:
         return self.__autopilot.history
 
+    def write_log(self, message: str):
+        self.history.timeline[self.history.current_index].logs.append(message)
+
     async def _ensure_absolute_path(self, path: str) -> str:
         if os.path.isabs(path):
             return path
@@ -263,7 +285,7 @@ class ContinueSDK(AbstractContinueSDK):
 
         for rif in highlighted_code:
             msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",
-                              role="system", summary=f"{preface}: {rif.filepath}")
+                              role="user", summary=f"{preface}: {rif.filepath}")
 
             # Don't insert after latest user message or function call
             i = -1
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 33d10985..64bb39a2 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,10 +1,11 @@
 from functools import cached_property
-from typing import Any, Coroutine, Dict, Generator, List, Union
+import json
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
 from ...core.main import ChatMessage
 import openai
 
 from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
 from ...core.config import AzureInfo
 
@@ -12,11 +13,12 @@ class OpenAI(LLM):
     api_key: str
     default_model: str
 
-    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):
+    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None):
         self.api_key = api_key
         self.default_model = default_model
         self.system_message = system_message
         self.azure_info = azure_info
+        self.write_log = write_log
 
         openai.api_key = api_key
@@ -46,18 +48,29 @@ class OpenAI(LLM):
         args["stream"] = True
 
         if args["model"] in CHAT_MODELS:
+            messages = compile_chat_messages(
+                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+            self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+            completion = ""
             async for chunk in await openai.ChatCompletion.acreate(
-                messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+                messages=messages,
                 **args,
             ):
                 if "content" in chunk.choices[0].delta:
                     yield chunk.choices[0].delta.content
+                    completion += chunk.choices[0].delta.content
                 else:
                     continue
+
+            self.write_log(f"Completion: \n\n{completion}")
         else:
+            self.write_log(f"Prompt:\n\n{prompt}")
+            completion = ""
             async for chunk in await openai.Completion.acreate(prompt=prompt, **args):
                 yield chunk.choices[0].text
+                completion += chunk.choices[0].text
+
+            self.write_log(f"Completion:\n\n{completion}")
 
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
@@ -67,27 +80,39 @@ class OpenAI(LLM):
         if not args["model"].endswith("0613") and "functions" in args:
             del args["functions"]
 
+        messages = compile_chat_messages(
+            args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        completion = ""
         async for chunk in await openai.ChatCompletion.acreate(
-            messages=compile_chat_messages(
-                args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
+            messages=messages,
             **args,
         ):
             yield chunk.choices[0].delta
+            if "content" in chunk.choices[0].delta:
+                completion += chunk.choices[0].delta.content
+        self.write_log(f"Completion: \n\n{completion}")
 
     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
         if args["model"] in CHAT_MODELS:
+            messages = compile_chat_messages(
+                args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+            self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
             resp = (await openai.ChatCompletion.acreate(
-                messages=compile_chat_messages(
-                    args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+                messages=messages,
                 **args,
             )).choices[0].message.content
+            self.write_log(f"Completion: \n\n{resp}")
         else:
+            prompt = prune_raw_prompt_from_top(
+                args["model"], prompt, args["max_tokens"])
+            self.write_log(f"Prompt:\n\n{prompt}")
            resp = (await openai.Completion.acreate(
-                prompt=prune_raw_prompt_from_top(
-                    args["model"], prompt, args["max_tokens"]),
+                prompt=prompt,
                 **args,
             )).choices[0].text
+            self.write_log(f"Completion:\n\n{resp}")
 
         return resp
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 3ec492f3..91b5842a 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,10 +1,11 @@
+
 from functools import cached_property
 import json
-from typing import Any, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
 import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens
+from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens, format_chat_messages
 import certifi
 import ssl
@@ -19,12 +20,14 @@ class ProxyServer(LLM):
     unique_id: str
     name: str
     default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+    write_log: Callable[[str], None]
 
-    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):
+    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
         self.unique_id = unique_id
         self.default_model = default_model
         self.system_message = system_message
         self.name = default_model
+        self.write_log = write_log
 
     @property
     def default_args(self):
@@ -36,14 +39,19 @@ class ProxyServer(LLM):
 
     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
+        messages = compile_chat_messages(
+            args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+                "messages": messages,
                 "unique_id": self.unique_id,
                 **args
             }) as resp:
                 try:
-                    return await resp.text()
+                    response_text = await resp.text()
+                    self.write_log(f"Completion: \n\n{response_text}")
+                    return response_text
                 except:
                     raise Exception(await resp.text())
@@ -51,6 +59,7 @@ class ProxyServer(LLM):
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
             self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -59,6 +68,7 @@ class ProxyServer(LLM):
                 **args
             }) as resp:
                 # This is streaming application/json instaed of text/event-stream
+                completion = ""
                 async for line in resp.content.iter_chunks():
                     if line[1]:
                         try:
@@ -67,14 +77,19 @@ class ProxyServer(LLM):
                             chunks = json_chunk.split("\n")
                             for chunk in chunks:
                                 if chunk.strip() != "":
-                                    yield json.loads(chunk)
+                                    loaded_chunk = json.loads(chunk)
+                                    yield loaded_chunk
+                                    if "content" in loaded_chunk:
+                                        completion += loaded_chunk["content"]
                         except:
                             raise Exception(str(line[0]))
+                self.write_log(f"Completion: \n\n{completion}")
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
             self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
@@ -82,9 +97,13 @@ class ProxyServer(LLM):
                 "unique_id": self.unique_id,
                 **args
             }) as resp:
+                completion = ""
                 async for line in resp.content.iter_any():
                     if line:
                         try:
-                            yield line.decode("utf-8")
+                            decoded_line = line.decode("utf-8")
+                            yield decoded_line
+                            completion += decoded_line
                         except:
                             raise Exception(str(line))
+                self.write_log(f"Completion: \n\n{completion}")
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 1d5d6729..13de7990 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -107,3 +107,10 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
         })
 
     return history
+
+
+def format_chat_messages(messages: List[ChatMessage]) -> str:
+    formatted = ""
+    for msg in messages:
+        formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n"
+    return formatted
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 4201353e..ae57c0b6 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -99,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
                 self.on_set_editing_at_indices(data["indices"])
             elif message_type == "set_pinned_at_indices":
                 self.on_set_pinned_at_indices(data["indices"])
+            elif message_type == "show_logs_at_index":
+                self.on_show_logs_at_index(data["index"])
         except Exception as e:
             print(e)
 
@@ -166,6 +168,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
                 indices), self.session.autopilot.continue_sdk.ide.unique_id
         )
 
+    def on_show_logs_at_index(self, index: int):
+        name = f"continue_logs.txt"
+        logs = "\n\n############################################\n\n".join(
+            ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
+        create_async_task(
+            self.session.autopilot.ide.showVirtualFile(name, logs))
+
 
 @router.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 43538407..aeff5623 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -224,6 +224,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
             "open": open
         })
 
+    async def showVirtualFile(self, name: str, contents: str):
+        await self._send_json("showVirtualFile", {
+            "name": name,
+            "contents": contents
+        })
+
     async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
         # Lock suggestions in the file so they don't ruin the offset before others are inserted
         await self._send_json("setSuggestionsLocked", {
@@ -288,6 +294,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
         pass
 
     def __get_autopilot(self):
+        if self.session_id not in self.session_manager.sessions:
+            return None
         return self.session_manager.sessions[self.session_id].autopilot
 
     def onFileEdits(self, edits: List[FileEditWithFullContents]):
@@ -442,7 +450,8 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
         if session_id is not None:
             session_manager.registered_ides[session_id] = ideProtocolServer
         other_msgs = await ideProtocolServer.initialize(session_id)
-        capture_event(ideProtocolServer.unique_id, "session_started", { "session_id": ideProtocolServer.session_id })
+        capture_event(ideProtocolServer.unique_id, "session_started", {
+                      "session_id": ideProtocolServer.session_id})
 
         for other_msg in other_msgs:
             handle_msg(other_msg)
@@ -463,5 +472,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
         if websocket.client_state != WebSocketState.DISCONNECTED:
             await websocket.close()
 
-        capture_event(ideProtocolServer.unique_id, "session_ended", { "session_id": ideProtocolServer.session_id })
+        capture_event(ideProtocolServer.unique_id, "session_ended", {
+                      "session_id": ideProtocolServer.session_id})
 
         session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index d0fb0bf8..0ae7e7fa 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -24,6 +24,10 @@ class AbstractIdeProtocolServer(ABC):
         """Set whether a file is open"""
 
     @abstractmethod
+    async def showVirtualFile(self, name: str, contents: str):
+        """Show a virtual file"""
+
+    @abstractmethod
     async def setSuggestionsLocked(self, filepath: str, locked: bool = True):
         """Set whether suggestions are locked"""
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 6d109ca6..90172a4e 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -100,7 +100,7 @@ class SessionManager:
         if session_id not in self.sessions:
             raise SessionNotFound(f"Session {session_id} not found")
         if self.sessions[session_id].ws is None:
-            print(f"Session {session_id} has no websocket")
+            # print(f"Session {session_id} has no websocket")
             return
 
         await self.sessions[session_id].ws.send_json({
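
Taken together, these changes thread a write_log callback from ContinueSDK down into the OpenAI and ProxyServer wrappers, which record the compiled prompt and the accumulated completion on the current HistoryNode's logs list; the GUI's show_logs_at_index handler then joins those entries into a virtual continue_logs.txt file shown by the IDE. Below is a minimal, self-contained sketch of that flow; the classes are simplified stand-ins for illustration, not the actual continuedev modules.

# Sketch of the prompt/completion logging flow introduced by this commit.
# HistoryNode and FakeLLM are simplified stand-ins, not imports from continuedev.
from typing import Callable, Dict, List


def format_chat_messages(messages: List[Dict[str, str]]) -> str:
    # Same formatting as the helper added to count_tokens.py above.
    formatted = ""
    for msg in messages:
        formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n"
    return formatted


class HistoryNode:
    def __init__(self) -> None:
        self.logs: List[str] = []  # per-step log, as added in core/main.py


class FakeLLM:
    """Stand-in for OpenAI/ProxyServer: logs what it sends and receives."""

    def __init__(self, write_log: Callable[[str], None]) -> None:
        self.write_log = write_log

    def complete(self, messages: List[Dict[str, str]]) -> str:
        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
        completion = "example model output"  # placeholder instead of a real API call
        self.write_log(f"Completion: \n\n{completion}")
        return completion


# ContinueSDK.write_log appends to the current step's logs; here that is just list.append.
node = HistoryNode()
llm = FakeLLM(write_log=node.logs.append)
llm.complete([{"role": "user", "content": "Hello"}])

# on_show_logs_at_index joins the collected entries into one virtual file's contents.
print("\n\n############################################\n\n".join(node.logs))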
