-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 2
-rw-r--r-- | continuedev/src/continuedev/libs/llm/anthropic.py | 19
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 10
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_tgi.py | 127
-rw-r--r-- | continuedev/src/continuedev/libs/llm/llamacpp.py | 12
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ollama.py | 148
-rw-r--r-- | continuedev/src/continuedev/libs/llm/prompts/chat.py | 5
-rw-r--r-- | continuedev/src/continuedev/libs/util/edit_config.py | 4
-rw-r--r-- | continuedev/src/continuedev/libs/util/paths.py | 7
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/chat.py | 30
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 11
-rw-r--r-- | extension/react-app/src/components/Layout.tsx | 2
12 files changed, 233 insertions, 144 deletions
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 778f81b3..37992b67 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -228,6 +228,8 @@ class ContinueSDK(AbstractContinueSDK):
         spec.loader.exec_module(config)
         self._last_valid_config = config.config
 
+        logger.debug("Loaded Continue config file from %s", path)
+
         return config.config
 
     def get_code_context(
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 9a7d0ac9..16bc2fce 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -4,12 +4,7 @@ from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
 
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import (
-    DEFAULT_ARGS,
-    compile_chat_messages,
-    count_tokens,
-    format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
 
 
 class AnthropicLLM(LLM):
@@ -118,9 +113,10 @@ class AnthropicLLM(LLM):
         )
 
         completion = ""
-        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        prompt = self.__messages_to_prompt(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
         async for chunk in await self._async_client.completions.create(
-            prompt=self.__messages_to_prompt(messages), **args
+            prompt=prompt, **args
         ):
             yield {"role": "assistant", "content": chunk.completion}
             completion += chunk.completion
@@ -143,11 +139,10 @@ class AnthropicLLM(LLM):
             system_message=self.system_message,
         )
 
-        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        prompt = self.__messages_to_prompt(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
         resp = (
-            await self._async_client.completions.create(
-                prompt=self.__messages_to_prompt(messages), **args
-            )
+            await self._async_client.completions.create(prompt=prompt, **args)
         ).completion
 
         self.write_log(f"Completion: \n\n{resp}")
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index be82c445..db3aaed7 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -147,16 +147,6 @@ class GGML(LLM):
     ) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
-        # messages = compile_chat_messages(
-        #     args["model"],
-        #     with_history,
-        #     self.context_length,
-        #     args["max_tokens"],
-        #     prompt,
-        #     functions=None,
-        #     system_message=self.system_message,
-        # )
-
         self.write_log(f"Prompt: \n\n{prompt}")
         async with aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
new file mode 100644
index 00000000..f04e700d
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -0,0 +1,127 @@
+import json
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+
+import aiohttp
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import code_llama_template_messages
+
+
+class HuggingFaceTGI(LLM):
+    model: str = "huggingface-tgi"
+    max_context_length: int = 2048
+    server_url: str = "http://localhost:8080"
+    verify_ssl: Optional[bool] = None
+
+    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+
+    requires_write_log = True
+
+    write_log: Optional[Callable[[str], None]] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def start(self, write_log: Callable[[str], None], **kwargs):
+        self.write_log = write_log
+
+    async def stop(self):
+        pass
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def _transform_args(self, args):
+        args = {
+            **args,
+            "max_new_tokens": args.get("max_tokens", 1024),
+        }
+        args.pop("max_tokens", None)
+        return args
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        prompt = self.template_messages(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
+        completion = ""
+        async with aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        ) as client_session:
+            async with client_session.post(
+                f"{self.server_url}",
+                json={"inputs": prompt, **self._transform_args(args)},
+            ) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        chunk = line.decode("utf-8")
+                        json_chunk = json.loads(chunk)
+                        text = json_chunk["details"]["best_of_sequences"][0][
+                            "generated_text"
+                        ]
+                        yield text
+                        completion += text
+
+        self.write_log(f"Completion: \n\n{completion}")
+
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        async for chunk in self.stream_complete(
+            None, self.template_messages(messages), **args
+        ):
+            yield {
+                "role": "assistant",
+                "content": chunk,
+            }
+
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        completion = ""
+        async for chunk in self.stream_complete(prompt, with_history, **args):
+            completion += chunk
+
+        return completion
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 9e424fde..6625065e 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -6,12 +6,7 @@ import aiohttp
 
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import (
-    DEFAULT_ARGS,
-    compile_chat_messages,
-    count_tokens,
-    format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
 from .prompts.chat import code_llama_template_messages
 
 
@@ -108,7 +103,8 @@ class LlamaCpp(LLM):
             system_message=self.system_message,
         )
 
-        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+        prompt = self.convert_to_chat(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
         completion = ""
         async with aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
@@ -116,7 +112,7 @@ class LlamaCpp(LLM):
             async with client_session.post(
                 f"{self.server_url}/completion",
                 json={
-                    "prompt": self.convert_to_chat(messages),
+                    "prompt": prompt,
                     **self._transform_args(args),
                 },
                 headers={"Content-Type": "application/json"},
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index c754e54d..03300435 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -8,6 +8,7 @@ import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
 from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import llama2_template_messages
 
 
 class Ollama(LLM):
@@ -57,43 +58,6 @@ def count_tokens(self, text: str):
         return count_tokens(self.name, text)
 
-    def convert_to_chat(self, msgs: ChatMessage) -> str:
-        if len(msgs) == 0:
-            return ""
-
-        prompt = ""
-        has_system = msgs[0]["role"] == "system"
-        if has_system and msgs[0]["content"] == "":
-            has_system = False
-            msgs.pop(0)
-
-        # TODO: Instead make stream_complete and stream_chat the same method.
-        if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
-            return msgs[0]["content"]
-
-        if has_system:
-            system_message = dedent(
-                f"""\
-                <<SYS>>
-                {self.system_message}
-                <</SYS>>
-
-                """
-            )
-            if len(msgs) > 1:
-                prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
-            else:
-                prompt += f"[INST] {system_message} [/INST]"
-            return
-
-        for i in range(2 if has_system else 0, len(msgs)):
-            if msgs[i]["role"] == "user":
-                prompt += f"[INST] {msgs[i]['content']} [/INST]"
-            else:
-                prompt += msgs[i]["content"]
-
-        return prompt
-
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
     ) -> Generator[Union[Any, List, Dict], None, None]:
@@ -107,37 +71,36 @@ class Ollama(LLM):
             functions=None,
             system_message=self.system_message,
         )
-        prompt = self.convert_to_chat(messages)
+        prompt = llama2_template_messages(messages)
 
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
-                "prompt": prompt,
+                "template": prompt,
                 "model": self.model,
+                "system": self.system_message,
+                "options": {"temperature": args["temperature"]},
             },
         ) as resp:
             url_decode_buffer = ""
             async for line in resp.content.iter_any():
                 if line:
-                    try:
-                        json_chunk = line.decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    url_decode_buffer += j["response"]
-
-                                    if (
-                                        "&" in url_decode_buffer
-                                        and url_decode_buffer.index("&")
-                                        > len(url_decode_buffer) - 5
-                                    ):
-                                        continue
-                                    yield urllib.parse.unquote(url_decode_buffer)
-                                    url_decode_buffer = ""
-                    except:
-                        raise Exception(str(line[0]))
+                    json_chunk = line.decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                url_decode_buffer += j["response"]
+
+                                if (
+                                    "&" in url_decode_buffer
+                                    and url_decode_buffer.index("&")
+                                    > len(url_decode_buffer) - 5
+                                ):
+                                    continue
+                                yield urllib.parse.unquote(url_decode_buffer)
+                                url_decode_buffer = ""
 
     async def stream_chat(
         self, messages: List[ChatMessage] = None, **kwargs
@@ -152,67 +115,56 @@
             functions=None,
             system_message=self.system_message,
         )
-        prompt = self.convert_to_chat(messages)
+        prompt = llama2_template_messages(messages)
 
-        self.write_log(f"Prompt: {prompt}")
+        self.write_log(f"Prompt:\n{prompt}")
+        completion = ""
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
-                "prompt": prompt,
+                "template": prompt,
                 "model": self.model,
+                "system": self.system_message,
+                "options": {"temperature": args["temperature"]},
             },
         ) as resp:
-            # This is streaming application/json instaed of text/event-stream
-            url_decode_buffer = ""
             async for line in resp.content.iter_chunks():
                 if line[1]:
-                    try:
-                        json_chunk = line[0].decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    url_decode_buffer += j["response"]
-                                    if (
-                                        "&" in url_decode_buffer
-                                        and url_decode_buffer.index("&")
-                                        > len(url_decode_buffer) - 5
-                                    ):
-                                        continue
-                                    yield {
-                                        "role": "assistant",
-                                        "content": urllib.parse.unquote(
-                                            url_decode_buffer
-                                        ),
-                                    }
-                                    url_decode_buffer = ""
-                    except:
-                        raise Exception(str(line[0]))
+                    json_chunk = line[0].decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                yield {
+                                    "role": "assistant",
+                                    "content": j["response"],
+                                }
+                                completion += j["response"]
+        self.write_log(f"Completion:\n{completion}")
 
     async def complete(
         self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
     ) -> Coroutine[Any, Any, str]:
         completion = ""
-
+        args = {**self.default_args, **kwargs}
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
-                "prompt": prompt,
+                "template": prompt,
                 "model": self.model,
+                "system": self.system_message,
+                "options": {"temperature": args["temperature"]},
             },
         ) as resp:
             async for line in resp.content.iter_any():
                 if line:
-                    try:
-                        json_chunk = line.decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    completion += urllib.parse.unquote(j["response"])
-                    except:
-                        raise Exception(str(line[0]))
+                    json_chunk = line.decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                completion += urllib.parse.unquote(j["response"])
 
         return completion
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
index 110dfaae..c7c208c0 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/chat.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -7,6 +7,11 @@ def llama2_template_messages(msgs: ChatMessage) -> str:
     if len(msgs) == 0:
         return ""
 
+    if msgs[0]["role"] == "assistant":
+        # These models aren't trained to handle assistant message coming first,
+        # and typically these are just introduction messages from Continue
+        msgs.pop(0)
+
     prompt = ""
     has_system = msgs[0]["role"] == "system"
 
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 3dd4646c..eed43054 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -5,12 +5,14 @@ import redbaron
 
 from .paths import getConfigFilePath
 
+
 def get_config_source():
     config_file_path = getConfigFilePath()
     with open(config_file_path, "r") as file:
         source_code = file.read()
     return source_code
 
+
 def load_red():
     source_code = get_config_source()
 
@@ -104,6 +106,8 @@ def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron
 
 
 def create_string_node(string: str) -> redbaron.RedBaron:
+    if "\n" in string:
+        return redbaron.RedBaron(f'"""{string}"""')[0]
     return redbaron.RedBaron(f'"{string}"')[0]
 
 
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index a411c5c3..b3e9ecc1 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -60,13 +60,6 @@ def getConfigFilePath() -> str:
         if existing_content.strip() == "":
             with open(path, "w") as f:
                 f.write(default_config)
-        elif " continuedev.core" in existing_content:
-            with open(path, "w") as f:
-                f.write(
-                    existing_content.replace(
-                        " continuedev.", " continuedev.src.continuedev."
-                    )
-                )
 
     return path
 
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index cbd94fe2..857183bc 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -1,3 +1,4 @@
+import html
 import json
 import os
 from textwrap import dedent
@@ -24,6 +25,12 @@ openai.api_key = OPENAI_API_KEY
 FREE_USAGE_STEP_NAME = "Please enter OpenAI API key"
 
 
+def add_ellipsis(text: str, max_length: int = 200) -> str:
+    if len(text) > max_length:
+        return text[: max_length - 3] + "..."
+    return text
+
+
 class SimpleChatStep(Step):
     name: str = "Generating Response..."
     manage_own_chat_context: bool = True
@@ -91,13 +98,26 @@ class SimpleChatStep(Step):
                 if "content" in chunk:
                     self.description += chunk["content"]
+
+                    # HTML unencode
+                    end_size = len(chunk["content"]) - 6
+                    if "&" in self.description[-end_size:]:
+                        self.description = self.description[:-end_size] + html.unescape(
+                            self.description[-end_size:]
+                        )
+
                     await sdk.update_ui()
 
-        self.name = remove_quotes_and_escapes(
-            await sdk.models.medium.complete(
-                f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
-                max_tokens=20,
-            )
+        self.name = "Generating title..."
+        await sdk.update_ui()
+        self.name = add_ellipsis(
+            remove_quotes_and_escapes(
+                await sdk.models.medium.complete(
+                    f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
+                    max_tokens=20,
+                )
+            ),
+            200,
         )
 
         self.chat_context.append(
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 08c5efc5..dbf9ba0d 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -182,7 +182,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
         name = "continue_logs.txt"
         logs = "\n\n############################################\n\n".join(
             [
-                "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
+                "This is a log of the prompt/completion pairs sent/received from the LLM during this step"
             ]
             + self.session.autopilot.continue_sdk.history.timeline[index].logs
         )
@@ -244,8 +244,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
             if prev_model is not None:
                 exists = False
                 for other in unused_models:
-                    if display_llm_class(prev_model) == display_llm_class(
-                        other
+                    if (
+                        prev_model.__class__.__name__
+                        == other.__class__.__name__
+                        and (
+                            not other.name.startswith("gpt")
+                            or prev_model.name == other.name
+                        )
                     ):
                         exists = True
                         break
diff --git a/extension/react-app/src/components/Layout.tsx b/extension/react-app/src/components/Layout.tsx
index 500dc921..c328a206 100644
--- a/extension/react-app/src/components/Layout.tsx
+++ b/extension/react-app/src/components/Layout.tsx
@@ -101,7 +101,7 @@ const Layout = () => {
       if (event.metaKey && event.altKey && event.code === "KeyN") {
         client?.loadSession(undefined);
       }
-      if (event.metaKey && event.code === "KeyC") {
+      if ((event.metaKey || event.ctrlKey) && event.code === "KeyC") {
        const selection = window.getSelection()?.toString();
        if (selection) {
          // Copy to clipboard
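
The largest addition in this patch is the new HuggingFaceTGI provider. Below is a minimal usage sketch, not part of the patch itself: it assumes the module path shown in the diff, a text-generation-inference server already listening on the class's default URL, and a placeholder prompt string.

```python
import asyncio

from continuedev.src.continuedev.libs.llm.hf_tgi import HuggingFaceTGI


async def main():
    # Point the provider at a locally running text-generation-inference server
    # (http://localhost:8080 is the class default in the patch).
    llm = HuggingFaceTGI(server_url="http://localhost:8080")

    # start() installs the write_log callback used for the "Prompt:"/"Completion:" logs.
    await llm.start(write_log=print)

    # complete() drives stream_complete() and returns the accumulated text.
    print(await llm.complete("Write a short docstring for a binary search function."))

    await llm.stop()


asyncio.run(main())
```

In practice the provider would presumably be selected through the user's Continue config.py rather than instantiated directly, but a standalone call like this is enough to exercise the new request path and the prompt/completion logging added here.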