Diffstat (limited to 'continuedev/src')
 continuedev/src/continuedev/core/sdk.py              |   2
 continuedev/src/continuedev/libs/llm/anthropic.py    |  19
 continuedev/src/continuedev/libs/llm/ggml.py         |  10
 continuedev/src/continuedev/libs/llm/hf_tgi.py       | 127
 continuedev/src/continuedev/libs/llm/llamacpp.py     |  12
 continuedev/src/continuedev/libs/llm/ollama.py       | 148
 continuedev/src/continuedev/libs/llm/prompts/chat.py |   5
 continuedev/src/continuedev/libs/util/edit_config.py |   4
 continuedev/src/continuedev/libs/util/paths.py       |   7
 continuedev/src/continuedev/plugins/steps/chat.py    |  30
 continuedev/src/continuedev/server/gui.py            |  11
11 files changed, 232 insertions(+), 143 deletions(-)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 778f81b3..37992b67 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -228,6 +228,8 @@ class ContinueSDK(AbstractContinueSDK):
         spec.loader.exec_module(config)
         self._last_valid_config = config.config
 
+        logger.debug("Loaded Continue config file from %s", path)
+
         return config.config
 
     def get_code_context(
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 9a7d0ac9..16bc2fce 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -4,12 +4,7 @@ from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
 
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import (
-    DEFAULT_ARGS,
-    compile_chat_messages,
-    count_tokens,
-    format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
 
 
 class AnthropicLLM(LLM):
@@ -118,9 +113,10 @@ class AnthropicLLM(LLM):
         )
 
         completion = ""
-        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        prompt = self.__messages_to_prompt(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
         async for chunk in await self._async_client.completions.create(
-            prompt=self.__messages_to_prompt(messages), **args
+            prompt=prompt, **args
         ):
             yield {"role": "assistant", "content": chunk.completion}
             completion += chunk.completion
@@ -143,11 +139,10 @@ class AnthropicLLM(LLM):
             system_message=self.system_message,
         )
 
-        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        prompt = self.__messages_to_prompt(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
         resp = (
-            await self._async_client.completions.create(
-                prompt=self.__messages_to_prompt(messages), **args
-            )
+            await self._async_client.completions.create(prompt=prompt, **args)
         ).completion
 
         self.write_log(f"Completion: \n\n{resp}")
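The anthropic.py change logs the exact string produced by `__messages_to_prompt` instead of a generically formatted transcript, so the log now matches what the API actually receives. The method body is not part of this diff; the sketch below is only a guess at its shape, inferred from the `HUMAN_PROMPT`/`AI_PROMPT` constants it imports:

```python
# Hypothetical sketch, not the project's actual implementation: interleave
# chat messages with Anthropic's HUMAN_PROMPT ("\n\nHuman:") and
# AI_PROMPT ("\n\nAssistant:") markers, ending with AI_PROMPT so the
# completions API knows it is the assistant's turn.
from anthropic import AI_PROMPT, HUMAN_PROMPT


def messages_to_prompt(messages: list) -> str:
    prompt = ""
    for msg in messages:
        if msg["role"] == "assistant":
            prompt += f"{AI_PROMPT} {msg['content']}"
        else:
            # user and system messages both become Human turns here
            prompt += f"{HUMAN_PROMPT} {msg['content']}"
    return prompt + AI_PROMPT
```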
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index be82c445..db3aaed7 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -147,16 +147,6 @@ class GGML(LLM):
     ) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
-        # messages = compile_chat_messages(
-        #     args["model"],
-        #     with_history,
-        #     self.context_length,
-        #     args["max_tokens"],
-        #     prompt,
-        #     functions=None,
-        #     system_message=self.system_message,
-        # )
-
         self.write_log(f"Prompt: \n\n{prompt}")
         async with aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
new file mode 100644
index 00000000..f04e700d
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -0,0 +1,127 @@
+import json
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+
+import aiohttp
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import code_llama_template_messages
+
+
+class HuggingFaceTGI(LLM):
+    model: str = "huggingface-tgi"
+    max_context_length: int = 2048
+    server_url: str = "http://localhost:8080"
+    verify_ssl: Optional[bool] = None
+
+    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+
+    requires_write_log = True
+
+    write_log: Optional[Callable[[str], None]] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def start(self, write_log: Callable[[str], None], **kwargs):
+        self.write_log = write_log
+
+    async def stop(self):
+        pass
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def _transform_args(self, args):
+        args = {
+            **args,
+            "max_new_tokens": args.get("max_tokens", 1024),
+        }
+        args.pop("max_tokens", None)
+        return args
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        prompt = self.template_messages(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
+        completion = ""
+        async with aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        ) as client_session:
+            async with client_session.post(
+                f"{self.server_url}",
+                json={"inputs": prompt, **self._transform_args(args)},
+            ) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        chunk = line.decode("utf-8")
+                        json_chunk = json.loads(chunk)
+                        text = json_chunk["details"]["best_of_sequences"][0][
+                            "generated_text"
+                        ]
+                        yield text
+                        completion += text
+
+        self.write_log(f"Completion: \n\n{completion}")
+
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        async for chunk in self.stream_complete(
+            None, self.template_messages(messages), **args
+        ):
+            yield {
+                "role": "assistant",
+                "content": chunk,
+            }
+
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        completion = ""
+        async for chunk in self.stream_complete(prompt, with_history, **args):
+            completion += chunk
+
+        return completion
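For comparison with the new `HuggingFaceTGI` class, a minimal standalone call against a text-generation-inference server might look like the sketch below. The `/generate` endpoint and the `inputs`/`parameters` payload shape follow TGI's public API as I understand it (they are not taken from this repo), and the parameter values are arbitrary:

```python
# Minimal, self-contained sketch of a non-streaming TGI request; endpoint
# and payload shape are assumptions from TGI's documented API.
import asyncio

import aiohttp


async def tgi_generate(server_url: str, prompt: str) -> str:
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{server_url}/generate",
            json={
                "inputs": prompt,
                "parameters": {"max_new_tokens": 256},
            },
        ) as resp:
            data = await resp.json()
            return data["generated_text"]


if __name__ == "__main__":
    print(asyncio.run(tgi_generate("http://localhost:8080", "def fib(n):")))
```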
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 9e424fde..6625065e 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -6,12 +6,7 @@ import aiohttp
 
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import (
-    DEFAULT_ARGS,
-    compile_chat_messages,
-    count_tokens,
-    format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
 from .prompts.chat import code_llama_template_messages
 
 
@@ -108,7 +103,8 @@ class LlamaCpp(LLM):
             system_message=self.system_message,
         )
 
-        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        prompt = self.convert_to_chat(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
         completion = ""
         async with aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
@@ -116,7 +112,8 @@ class LlamaCpp(LLM):
             async with client_session.post(
                 f"{self.server_url}/completion",
                 json={
-                    "prompt": self.convert_to_chat(messages),
+                    "prompt": prompt,
                     **self._transform_args(args),
                 },
                 headers={"Content-Type": "application/json"},
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index c754e54d..03300435 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -8,6 +8,7 @@ import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
 from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import llama2_template_messages
 
 
 class Ollama(LLM):
@@ -57,43 +58,6 @@ class Ollama(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
 
-    def convert_to_chat(self, msgs: ChatMessage) -> str:
-        if len(msgs) == 0:
-            return ""
-
-        prompt = ""
-        has_system = msgs[0]["role"] == "system"
-        if has_system and msgs[0]["content"] == "":
-            has_system = False
-            msgs.pop(0)
-
-        # TODO: Instead make stream_complete and stream_chat the same method.
-        if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
-            return msgs[0]["content"]
-
-        if has_system:
-            system_message = dedent(
-                f"""\
-                <<SYS>>
-                {self.system_message}
-                <</SYS>>
-
-                """
-            )
-            if len(msgs) > 1:
-                prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
-            else:
-                prompt += f"[INST] {system_message} [/INST]"
-                return
-
-        for i in range(2 if has_system else 0, len(msgs)):
-            if msgs[i]["role"] == "user":
-                prompt += f"[INST] {msgs[i]['content']} [/INST]"
-            else:
-                prompt += msgs[i]["content"]
-
-        return prompt
-
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
     ) -> Generator[Union[Any, List, Dict], None, None]:
@@ -107,37 +71,36 @@ class Ollama(LLM):
             functions=None,
             system_message=self.system_message,
         )
-        prompt = self.convert_to_chat(messages)
+        prompt = llama2_template_messages(messages)
 
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
-                "prompt": prompt,
+                "template": prompt,
                 "model": self.model,
+                "system": self.system_message,
+                "options": {"temperature": args["temperature"]},
             },
         ) as resp:
             url_decode_buffer = ""
             async for line in resp.content.iter_any():
                 if line:
-                    try:
-                        json_chunk = line.decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    url_decode_buffer += j["response"]
-
-                                    if (
-                                        "&" in url_decode_buffer
-                                        and url_decode_buffer.index("&")
-                                        > len(url_decode_buffer) - 5
-                                    ):
-                                        continue
-                                    yield urllib.parse.unquote(url_decode_buffer)
-                                    url_decode_buffer = ""
-                    except:
-                        raise Exception(str(line[0]))
+                    json_chunk = line.decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                url_decode_buffer += j["response"]
+
+                                if (
+                                    "&" in url_decode_buffer
+                                    and url_decode_buffer.index("&")
+                                    > len(url_decode_buffer) - 5
+                                ):
+                                    continue
+                                yield urllib.parse.unquote(url_decode_buffer)
+                                url_decode_buffer = ""
 
     async def stream_chat(
         self, messages: List[ChatMessage] = None, **kwargs
@@ -152,67 +115,56 @@ class Ollama(LLM):
             functions=None,
             system_message=self.system_message,
         )
-        prompt = self.convert_to_chat(messages)
+        prompt = llama2_template_messages(messages)
 
-        self.write_log(f"Prompt: {prompt}")
+        self.write_log(f"Prompt:\n{prompt}")
+        completion = ""
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
-                "prompt": prompt,
+                "template": prompt,
                 "model": self.model,
+                "system": self.system_message,
+                "options": {"temperature": args["temperature"]},
             },
         ) as resp:
-            # This is streaming application/json instead of text/event-stream
-            url_decode_buffer = ""
             async for line in resp.content.iter_chunks():
                 if line[1]:
-                    try:
-                        json_chunk = line[0].decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    url_decode_buffer += j["response"]
-                                    if (
-                                        "&" in url_decode_buffer
-                                        and url_decode_buffer.index("&")
-                                        > len(url_decode_buffer) - 5
-                                    ):
-                                        continue
-                                    yield {
-                                        "role": "assistant",
-                                        "content": urllib.parse.unquote(
-                                            url_decode_buffer
-                                        ),
-                                    }
-                                    url_decode_buffer = ""
-                    except:
-                        raise Exception(str(line[0]))
+                    json_chunk = line[0].decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                yield {
+                                    "role": "assistant",
+                                    "content": j["response"],
+                                }
+                                completion += j["response"]
+        self.write_log(f"Completion:\n{completion}")
 
     async def complete(
         self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
     ) -> Coroutine[Any, Any, str]:
         completion = ""
-
+        args = {**self.default_args, **kwargs}
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
-                "prompt": prompt,
+                "template": prompt,
                 "model": self.model,
+                "system": self.system_message,
+                "options": {"temperature": args["temperature"]},
             },
         ) as resp:
            async for line in resp.content.iter_any():
                 if line:
-                    try:
-                        json_chunk = line.decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    completion += urllib.parse.unquote(j["response"])
-                    except:
-                        raise Exception(str(line[0]))
+                    json_chunk = line.decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                completion += urllib.parse.unquote(j["response"])
 
         return completion
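The Ollama rewrite drops the bare try/except-and-re-raise wrappers, moves the rendered prompt into the request's `template` field, and passes the system message and temperature separately. Outside of Continue, consuming `/api/generate` looks roughly like this; Ollama streams newline-delimited JSON objects, each carrying a `response` fragment and a final record with `done: true` (the model name and default port 11434 are just examples):

```python
import asyncio
import json

import aiohttp


async def ollama_stream(server_url: str, model: str, prompt: str):
    # Each line of the body is a JSON object such as
    # {"response": "Hel", "done": false}; the last one has "done": true.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{server_url}/api/generate",
            json={"model": model, "prompt": prompt},
        ) as resp:
            async for line in resp.content:
                if line.strip():
                    j = json.loads(line)
                    if j.get("done"):
                        break
                    print(j.get("response", ""), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(ollama_stream("http://localhost:11434", "llama2", "Hello"))
```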
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
index 110dfaae..c7c208c0 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/chat.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -7,6 +7,11 @@ def llama2_template_messages(msgs: ChatMessage) -> str:
     if len(msgs) == 0:
         return ""
 
+    if msgs[0]["role"] == "assistant":
+        # These models aren't trained to handle an assistant message coming first,
+        # and typically these are just introduction messages from Continue
+        msgs.pop(0)
+
     prompt = ""
     has_system = msgs[0]["role"] == "system"
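For context, the Llama 2 chat format that `llama2_template_messages` targets wraps each user turn in `[INST] ... [/INST]`, with the system prompt in a `<<SYS>>` block inside the first instruction. A rough illustration; the exact whitespace produced by the project's implementation may differ:

```python
# Illustration only -- not the output of the real template function.
msgs = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does this diff do?"},
]

expected = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant.\n"
    "<</SYS>>\n\n"
    "What does this diff do? [/INST]"
)
```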
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 3dd4646c..eed43054 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -5,12 +5,14 @@ import redbaron
 
 from .paths import getConfigFilePath
 
+
 def get_config_source():
     config_file_path = getConfigFilePath()
     with open(config_file_path, "r") as file:
         source_code = file.read()
     return source_code
 
+
 def load_red():
     source_code = get_config_source()
@@ -104,6 +106,8 @@ def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
 
 
 def create_string_node(string: str) -> redbaron.RedBaron:
+    if "\n" in string:
+        return redbaron.RedBaron(f'"""{string}"""')[0]
     return redbaron.RedBaron(f'"{string}"')[0]
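The `create_string_node` change matters because a single-quoted Python literal cannot contain a raw newline, so writing a multiline value (such as a long `system_message`) back into `config.py` needs triple quotes. A quick check with redbaron:

```python
# A single-quoted literal with a real newline is not valid Python source,
# so baron/redbaron fails to parse it; the triple-quoted form is fine.
import redbaron

ok = redbaron.RedBaron('"""line one\nline two"""')[0]  # parses
print(ok.dumps())
# redbaron.RedBaron('"line one\nline two"')  # would raise a parsing error
```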
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index a411c5c3..b3e9ecc1 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -60,13 +60,6 @@ def getConfigFilePath() -> str:
         if existing_content.strip() == "":
             with open(path, "w") as f:
                 f.write(default_config)
-        elif " continuedev.core" in existing_content:
-            with open(path, "w") as f:
-                f.write(
-                    existing_content.replace(
-                        " continuedev.", " continuedev.src.continuedev."
-                    )
-                )
 
     return path
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index cbd94fe2..857183bc 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -1,3 +1,4 @@
+import html
 import json
 import os
 from textwrap import dedent
@@ -24,6 +25,12 @@ openai.api_key = OPENAI_API_KEY
 
 FREE_USAGE_STEP_NAME = "Please enter OpenAI API key"
 
 
+def add_ellipsis(text: str, max_length: int = 200) -> str:
+    if len(text) > max_length:
+        return text[: max_length - 3] + "..."
+    return text
+
+
 class SimpleChatStep(Step):
     name: str = "Generating Response..."
     manage_own_chat_context: bool = True
@@ -91,13 +98,26 @@ class SimpleChatStep(Step):
             if "content" in chunk:
                 self.description += chunk["content"]
+
+                # HTML-unescape the recently streamed tail of the description
+                end_size = len(chunk["content"]) - 6
+                if "&" in self.description[-end_size:]:
+                    self.description = self.description[:-end_size] + html.unescape(
+                        self.description[-end_size:]
+                    )
+
                 await sdk.update_ui()
 
-        self.name = remove_quotes_and_escapes(
-            await sdk.models.medium.complete(
-                f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
-                max_tokens=20,
-            )
+        self.name = "Generating title..."
+        await sdk.update_ui()
+        self.name = add_ellipsis(
+            remove_quotes_and_escapes(
+                await sdk.models.medium.complete(
+                    f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
+                    max_tokens=20,
+                )
+            ),
+            200,
         )
 
         self.chat_context.append(
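A tiny usage sketch of the two pieces introduced above: `add_ellipsis` (copied from the diff) and the standard library's `html.unescape`, which the streaming loop applies to the tail of the description so entities like `&amp;` that straddle chunk boundaries are not mangled:

```python
import html


def add_ellipsis(text: str, max_length: int = 200) -> str:
    if len(text) > max_length:
        return text[: max_length - 3] + "..."
    return text


assert add_ellipsis("short title") == "short title"
assert add_ellipsis("x" * 300, 10) == "xxxxxxx..."
assert html.unescape("if a &amp;&amp; b &lt; c:") == "if a && b < c:"
```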
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 08c5efc5..dbf9ba0d 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -182,7 +182,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
         name = "continue_logs.txt"
         logs = "\n\n############################################\n\n".join(
             [
-                "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
+                "This is a log of the prompt/completion pairs sent/received from the LLM during this step"
             ]
             + self.session.autopilot.continue_sdk.history.timeline[index].logs
         )
@@ -244,8 +244,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
                     if prev_model is not None:
                         exists = False
                         for other in unused_models:
-                            if display_llm_class(prev_model) == display_llm_class(
-                                other
+                            if (
+                                prev_model.__class__.__name__
+                                == other.__class__.__name__
+                                and (
+                                    not other.name.startswith("gpt")
+                                    or prev_model.name == other.name
+                                )
                             ):
                                 exists = True
                                 break
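The second gui.py hunk replaces a comparison of rendered config source with a structural rule: two saved models count as duplicates when they are the same class, except that GPT-family models must also share a name. Restated as a standalone predicate (the class name here is made up for illustration):

```python
class OpenAIModel:  # stand-in for the real LLM classes
    def __init__(self, name: str):
        self.name = name


def is_same_saved_model(a, b) -> bool:
    # Same provider class; for "gpt*" models the name must match too, so
    # gpt-4 and gpt-3.5-turbo remain distinct entries.
    return a.__class__.__name__ == b.__class__.__name__ and (
        not b.name.startswith("gpt") or a.name == b.name
    )


assert is_same_saved_model(OpenAIModel("gpt-4"), OpenAIModel("gpt-4"))
assert not is_same_saved_model(OpenAIModel("gpt-4"), OpenAIModel("gpt-3.5-turbo"))
```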
