diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-30 19:55:18 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-30 19:55:18 -0700 | 
| commit | 28f5d7bedab05a8b061e4e7ee9055a5403786bbc (patch) | |
| tree | 8e32e9a0edcddf3dd3bf5dbf76e14fb09b15ca8e | |
| parent | a0e2e2d3d606d8bf465eac541a84aa57316ee271 (diff) | |
| download | sncontinue-28f5d7bedab05a8b061e4e7ee9055a5403786bbc.tar.gz sncontinue-28f5d7bedab05a8b061e4e7ee9055a5403786bbc.tar.bz2 sncontinue-28f5d7bedab05a8b061e4e7ee9055a5403786bbc.zip | |
fix: :art: many small improvements
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/anthropic.py | 19 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/llamacpp.py | 12 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ollama.py | 149 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/prompts/chat.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/paths.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 11 | ||||
| -rw-r--r-- | extension/src/activation/environmentSetup.ts | 9 | 
10 files changed, 93 insertions, 135 deletions
| diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 778f81b3..37992b67 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -228,6 +228,8 @@ class ContinueSDK(AbstractContinueSDK):          spec.loader.exec_module(config)          self._last_valid_config = config.config +        logger.debug("Loaded Continue config file from %s", path) +          return config.config      def get_code_context( diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 9a7d0ac9..16bc2fce 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -4,12 +4,7 @@ from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic  from ...core.main import ChatMessage  from ..llm import LLM -from ..util.count_tokens import ( -    DEFAULT_ARGS, -    compile_chat_messages, -    count_tokens, -    format_chat_messages, -) +from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens  class AnthropicLLM(LLM): @@ -118,9 +113,10 @@ class AnthropicLLM(LLM):          )          completion = "" -        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") +        prompt = self.__messages_to_prompt(messages) +        self.write_log(f"Prompt: \n\n{prompt}")          async for chunk in await self._async_client.completions.create( -            prompt=self.__messages_to_prompt(messages), **args +            prompt=prompt, **args          ):              yield {"role": "assistant", "content": chunk.completion}              completion += chunk.completion @@ -143,11 +139,10 @@ class AnthropicLLM(LLM):              system_message=self.system_message,          ) -        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") +        prompt = self.__messages_to_prompt(messages) +        self.write_log(f"Prompt: \n\n{prompt}")          resp = ( -            await self._async_client.completions.create( -                prompt=self.__messages_to_prompt(messages), **args -            ) +            await self._async_client.completions.create(prompt=prompt, **args)          ).completion          self.write_log(f"Completion: \n\n{resp}") diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index be82c445..db3aaed7 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -147,16 +147,6 @@ class GGML(LLM):      ) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} -        # messages = compile_chat_messages( -        #     args["model"], -        #     with_history, -        #     self.context_length, -        #     args["max_tokens"], -        #     prompt, -        #     functions=None, -        #     system_message=self.system_message, -        # ) -          self.write_log(f"Prompt: \n\n{prompt}")          async with aiohttp.ClientSession(              connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py index 9e424fde..6625065e 100644 --- a/continuedev/src/continuedev/libs/llm/llamacpp.py +++ b/continuedev/src/continuedev/libs/llm/llamacpp.py @@ -6,12 +6,7 @@ import aiohttp  from ...core.main import ChatMessage  from ..llm import LLM -from ..util.count_tokens import ( -    DEFAULT_ARGS, -    compile_chat_messages, -    count_tokens, -    format_chat_messages, -) +from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens  from .prompts.chat import code_llama_template_messages @@ -108,7 +103,8 @@ class LlamaCpp(LLM):              system_message=self.system_message,          ) -        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") +        prompt = self.convert_to_chat(messages) +        self.write_log(f"Prompt: \n\n{prompt}")          completion = ""          async with aiohttp.ClientSession(              connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) @@ -116,7 +112,7 @@ class LlamaCpp(LLM):              async with client_session.post(                  f"{self.server_url}/completion",                  json={ -                    "prompt": self.convert_to_chat(messages), +                    "prompt": prompt,                      **self._transform_args(args),                  },                  headers={"Content-Type": "application/json"}, diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py index c754e54d..240d922b 100644 --- a/continuedev/src/continuedev/libs/llm/ollama.py +++ b/continuedev/src/continuedev/libs/llm/ollama.py @@ -8,6 +8,7 @@ import aiohttp  from ...core.main import ChatMessage  from ..llm import LLM  from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from .prompts.chat import llama2_template_messages  class Ollama(LLM): @@ -57,43 +58,6 @@ class Ollama(LLM):      def count_tokens(self, text: str):          return count_tokens(self.name, text) -    def convert_to_chat(self, msgs: ChatMessage) -> str: -        if len(msgs) == 0: -            return "" - -        prompt = "" -        has_system = msgs[0]["role"] == "system" -        if has_system and msgs[0]["content"] == "": -            has_system = False -            msgs.pop(0) - -        # TODO: Instead make stream_complete and stream_chat the same method. -        if len(msgs) == 1 and "[INST]" in msgs[0]["content"]: -            return msgs[0]["content"] - -        if has_system: -            system_message = dedent( -                f"""\ -                <<SYS>> -                {self.system_message} -                <</SYS>> -                 -                """ -            ) -            if len(msgs) > 1: -                prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]" -            else: -                prompt += f"[INST] {system_message} [/INST]" -                return - -        for i in range(2 if has_system else 0, len(msgs)): -            if msgs[i]["role"] == "user": -                prompt += f"[INST] {msgs[i]['content']} [/INST]" -            else: -                prompt += msgs[i]["content"] - -        return prompt -      async def stream_complete(          self, prompt, with_history: List[ChatMessage] = None, **kwargs      ) -> Generator[Union[Any, List, Dict], None, None]: @@ -107,37 +71,36 @@ class Ollama(LLM):              functions=None,              system_message=self.system_message,          ) -        prompt = self.convert_to_chat(messages) +        prompt = llama2_template_messages(messages)          async with self._client_session.post(              f"{self.server_url}/api/generate",              json={ -                "prompt": prompt, +                "template": prompt,                  "model": self.model, +                "system": self.system_message, +                "options": {"temperature": args["temperature"]},              },          ) as resp:              url_decode_buffer = ""              async for line in resp.content.iter_any():                  if line: -                    try: -                        json_chunk = line.decode("utf-8") -                        chunks = json_chunk.split("\n") -                        for chunk in chunks: -                            if chunk.strip() != "": -                                j = json.loads(chunk) -                                if "response" in j: -                                    url_decode_buffer += j["response"] - -                                    if ( -                                        "&" in url_decode_buffer -                                        and url_decode_buffer.index("&") -                                        > len(url_decode_buffer) - 5 -                                    ): -                                        continue -                                    yield urllib.parse.unquote(url_decode_buffer) -                                    url_decode_buffer = "" -                    except: -                        raise Exception(str(line[0])) +                    json_chunk = line.decode("utf-8") +                    chunks = json_chunk.split("\n") +                    for chunk in chunks: +                        if chunk.strip() != "": +                            j = json.loads(chunk) +                            if "response" in j: +                                url_decode_buffer += j["response"] + +                                if ( +                                    "&" in url_decode_buffer +                                    and url_decode_buffer.index("&") +                                    > len(url_decode_buffer) - 5 +                                ): +                                    continue +                                yield urllib.parse.unquote(url_decode_buffer) +                                url_decode_buffer = ""      async def stream_chat(          self, messages: List[ChatMessage] = None, **kwargs @@ -152,67 +115,63 @@ class Ollama(LLM):              functions=None,              system_message=self.system_message,          ) -        prompt = self.convert_to_chat(messages) +        prompt = llama2_template_messages(messages)          self.write_log(f"Prompt: {prompt}")          async with self._client_session.post(              f"{self.server_url}/api/generate",              json={ -                "prompt": prompt, +                "template": prompt,                  "model": self.model, +                "system": self.system_message, +                "options": {"temperature": args["temperature"]},              },          ) as resp:              # This is streaming application/json instaed of text/event-stream              url_decode_buffer = ""              async for line in resp.content.iter_chunks():                  if line[1]: -                    try: -                        json_chunk = line[0].decode("utf-8") -                        chunks = json_chunk.split("\n") -                        for chunk in chunks: -                            if chunk.strip() != "": -                                j = json.loads(chunk) -                                if "response" in j: -                                    url_decode_buffer += j["response"] -                                    if ( -                                        "&" in url_decode_buffer -                                        and url_decode_buffer.index("&") -                                        > len(url_decode_buffer) - 5 -                                    ): -                                        continue -                                    yield { -                                        "role": "assistant", -                                        "content": urllib.parse.unquote( -                                            url_decode_buffer -                                        ), -                                    } -                                    url_decode_buffer = "" -                    except: -                        raise Exception(str(line[0])) +                    json_chunk = line[0].decode("utf-8") +                    chunks = json_chunk.split("\n") +                    for chunk in chunks: +                        if chunk.strip() != "": +                            j = json.loads(chunk) +                            if "response" in j: +                                url_decode_buffer += j["response"] +                                if ( +                                    "&" in url_decode_buffer +                                    and url_decode_buffer.index("&") +                                    > len(url_decode_buffer) - 5 +                                ): +                                    continue +                                yield { +                                    "role": "assistant", +                                    "content": urllib.parse.unquote(url_decode_buffer), +                                } +                                url_decode_buffer = ""      async def complete(          self, prompt: str, with_history: List[ChatMessage] = None, **kwargs      ) -> Coroutine[Any, Any, str]:          completion = "" - +        args = {**self.default_args, **kwargs}          async with self._client_session.post(              f"{self.server_url}/api/generate",              json={ -                "prompt": prompt, +                "template": prompt,                  "model": self.model, +                "system": self.system_message, +                "options": {"temperature": args["temperature"]},              },          ) as resp:              async for line in resp.content.iter_any():                  if line: -                    try: -                        json_chunk = line.decode("utf-8") -                        chunks = json_chunk.split("\n") -                        for chunk in chunks: -                            if chunk.strip() != "": -                                j = json.loads(chunk) -                                if "response" in j: -                                    completion += urllib.parse.unquote(j["response"]) -                    except: -                        raise Exception(str(line[0])) +                    json_chunk = line.decode("utf-8") +                    chunks = json_chunk.split("\n") +                    for chunk in chunks: +                        if chunk.strip() != "": +                            j = json.loads(chunk) +                            if "response" in j: +                                completion += urllib.parse.unquote(j["response"])          return completion diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py index 110dfaae..c7c208c0 100644 --- a/continuedev/src/continuedev/libs/llm/prompts/chat.py +++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py @@ -7,6 +7,11 @@ def llama2_template_messages(msgs: ChatMessage) -> str:      if len(msgs) == 0:          return "" +    if msgs[0]["role"] == "assistant": +        # These models aren't trained to handle assistant message coming first, +        # and typically these are just introduction messages from Continue +        msgs.pop(0) +      prompt = ""      has_system = msgs[0]["role"] == "system" diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py index 3dd4646c..eed43054 100644 --- a/continuedev/src/continuedev/libs/util/edit_config.py +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -5,12 +5,14 @@ import redbaron  from .paths import getConfigFilePath +  def get_config_source():      config_file_path = getConfigFilePath()      with open(config_file_path, "r") as file:          source_code = file.read()      return source_code +  def load_red():      source_code = get_config_source() @@ -104,6 +106,8 @@ def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:  def create_string_node(string: str) -> redbaron.RedBaron: +    if "\n" in string: +        return redbaron.RedBaron(f'"""{string}"""')[0]      return redbaron.RedBaron(f'"{string}"')[0] diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index a411c5c3..b3e9ecc1 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -60,13 +60,6 @@ def getConfigFilePath() -> str:          if existing_content.strip() == "":              with open(path, "w") as f:                  f.write(default_config) -        elif " continuedev.core" in existing_content: -            with open(path, "w") as f: -                f.write( -                    existing_content.replace( -                        " continuedev.", " continuedev.src.continuedev." -                    ) -                )      return path diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 08c5efc5..dbf9ba0d 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -182,7 +182,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):          name = "continue_logs.txt"          logs = "\n\n############################################\n\n".join(              [ -                "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step" +                "This is a log of the prompt/completion pairs sent/received from the LLM during this step"              ]              + self.session.autopilot.continue_sdk.history.timeline[index].logs          ) @@ -244,8 +244,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):                      if prev_model is not None:                          exists = False                          for other in unused_models: -                            if display_llm_class(prev_model) == display_llm_class( -                                other +                            if ( +                                prev_model.__class__.__name__ +                                == other.__class__.__name__ +                                and ( +                                    not other.name.startswith("gpt") +                                    or prev_model.name == other.name +                                )                              ):                                  exists = True                                  break diff --git a/extension/src/activation/environmentSetup.ts b/extension/src/activation/environmentSetup.ts index 7ca87768..e67a5852 100644 --- a/extension/src/activation/environmentSetup.ts +++ b/extension/src/activation/environmentSetup.ts @@ -346,4 +346,13 @@ export async function startContinuePythonServer(redownload: boolean = true) {    // Write the current version of vscode extension to a file called server_version.txt    fs.writeFileSync(serverVersionPath(), getExtensionVersion()); + +  // If running on remote, forward the port +  if ( +    vscode.env.remoteName && +    vscode.extensions.getExtension("continue.continue")?.extensionKind === +      vscode.ExtensionKind.Workspace +  ) { +    await vscode.env.asExternalUri(vscode.Uri.parse(getContinueServerUrl())); +  }  } | 
