-rw-r--r--  continuedev/src/continuedev/core/sdk.py               |   2
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py     |  19
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py          |  10
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_tgi.py        | 127
-rw-r--r--  continuedev/src/continuedev/libs/llm/llamacpp.py      |  12
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py        | 148
-rw-r--r--  continuedev/src/continuedev/libs/llm/prompts/chat.py  |   5
-rw-r--r--  continuedev/src/continuedev/libs/util/edit_config.py  |   4
-rw-r--r--  continuedev/src/continuedev/libs/util/paths.py        |   7
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chat.py     |  30
-rw-r--r--  continuedev/src/continuedev/server/gui.py             |  11
-rw-r--r--  extension/react-app/src/components/Layout.tsx         |   2
12 files changed, 233 insertions(+), 144 deletions(-)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 778f81b3..37992b67 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -228,6 +228,8 @@ class ContinueSDK(AbstractContinueSDK):
spec.loader.exec_module(config)
self._last_valid_config = config.config
+ logger.debug("Loaded Continue config file from %s", path)
+
return config.config
def get_code_context(
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 9a7d0ac9..16bc2fce 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -4,12 +4,7 @@ from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import (
- DEFAULT_ARGS,
- compile_chat_messages,
- count_tokens,
- format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
class AnthropicLLM(LLM):
@@ -118,9 +113,10 @@ class AnthropicLLM(LLM):
)
completion = ""
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ prompt = self.__messages_to_prompt(messages)
+ self.write_log(f"Prompt: \n\n{prompt}")
async for chunk in await self._async_client.completions.create(
- prompt=self.__messages_to_prompt(messages), **args
+ prompt=prompt, **args
):
yield {"role": "assistant", "content": chunk.completion}
completion += chunk.completion
@@ -143,11 +139,10 @@ class AnthropicLLM(LLM):
system_message=self.system_message,
)
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ prompt = self.__messages_to_prompt(messages)
+ self.write_log(f"Prompt: \n\n{prompt}")
resp = (
- await self._async_client.completions.create(
- prompt=self.__messages_to_prompt(messages), **args
- )
+ await self._async_client.completions.create(prompt=prompt, **args)
).completion
self.write_log(f"Completion: \n\n{resp}")
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index be82c445..db3aaed7 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -147,16 +147,6 @@ class GGML(LLM):
) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- # messages = compile_chat_messages(
- # args["model"],
- # with_history,
- # self.context_length,
- # args["max_tokens"],
- # prompt,
- # functions=None,
- # system_message=self.system_message,
- # )
-
self.write_log(f"Prompt: \n\n{prompt}")
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
new file mode 100644
index 00000000..f04e700d
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -0,0 +1,127 @@
+import json
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+
+import aiohttp
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import code_llama_template_messages
+
+
+class HuggingFaceTGI(LLM):
+    model: str = "huggingface-tgi"
+    max_context_length: int = 2048
+    server_url: str = "http://localhost:8080"
+    verify_ssl: Optional[bool] = None
+
+    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+
+    requires_write_log = True
+
+    write_log: Optional[Callable[[str], None]] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def start(self, write_log: Callable[[str], None], **kwargs):
+        self.write_log = write_log
+
+    async def stop(self):
+        pass
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def _transform_args(self, args):
+        args = {
+            **args,
+            "max_new_tokens": args.get("max_tokens", 1024),
+        }
+        args.pop("max_tokens", None)
+        return args
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        prompt = self.template_messages(messages)
+        self.write_log(f"Prompt: \n\n{prompt}")
+        completion = ""
+        async with aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        ) as client_session:
+            async with client_session.post(
+                f"{self.server_url}",
+                json={"inputs": prompt, **self._transform_args(args)},
+            ) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        chunk = line.decode("utf-8")
+                        json_chunk = json.loads(chunk)
+                        text = json_chunk["details"]["best_of_sequences"][0][
+                            "generated_text"
+                        ]
+                        yield text
+                        completion += text
+
+        self.write_log(f"Completion: \n\n{completion}")
+
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        async for chunk in self.stream_complete(
+            None, self.template_messages(messages), **args
+        ):
+            yield {
+                "role": "assistant",
+                "content": chunk,
+            }
+
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        completion = ""
+        async for chunk in self.stream_complete(prompt, with_history, **args):
+            completion += chunk
+
+        return completion
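
For context, a minimal usage sketch of the new HuggingFaceTGI class follows; it is not part of the diff. The import path, the print-based write_log, and the prompt text are illustrative assumptions, and it presumes the LLM base class supplies its usual defaults (such as system_message).

import asyncio

# Hypothetical import path, mirroring the file location added in this diff.
from continuedev.src.continuedev.libs.llm.hf_tgi import HuggingFaceTGI


async def demo():
    # server_url matches the class default; write_log=print is a stand-in logger.
    llm = HuggingFaceTGI(server_url="http://localhost:8080")
    await llm.start(write_log=print)
    async for chunk in llm.stream_complete("Write a function that reverses a string."):
        print(chunk, end="")
    await llm.stop()


asyncio.run(demo())
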
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 9e424fde..6625065e 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -6,12 +6,7 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import (
- DEFAULT_ARGS,
- compile_chat_messages,
- count_tokens,
- format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
from .prompts.chat import code_llama_template_messages
@@ -108,7 +103,8 @@ class LlamaCpp(LLM):
system_message=self.system_message,
)
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ prompt = self.convert_to_chat(messages)
+ self.write_log(f"Prompt: \n\n{prompt}")
completion = ""
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
@@ -116,7 +112,7 @@ class LlamaCpp(LLM):
async with client_session.post(
f"{self.server_url}/completion",
json={
- "prompt": self.convert_to_chat(messages),
+ "prompt": prompt,
**self._transform_args(args),
},
headers={"Content-Type": "application/json"},
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index c754e54d..03300435 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -8,6 +8,7 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import llama2_template_messages
class Ollama(LLM):
@@ -57,43 +58,6 @@ class Ollama(LLM):
def count_tokens(self, text: str):
return count_tokens(self.name, text)
- def convert_to_chat(self, msgs: ChatMessage) -> str:
- if len(msgs) == 0:
- return ""
-
- prompt = ""
- has_system = msgs[0]["role"] == "system"
- if has_system and msgs[0]["content"] == "":
- has_system = False
- msgs.pop(0)
-
- # TODO: Instead make stream_complete and stream_chat the same method.
- if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
- return msgs[0]["content"]
-
- if has_system:
- system_message = dedent(
- f"""\
- <<SYS>>
- {self.system_message}
- <</SYS>>
-
- """
- )
- if len(msgs) > 1:
- prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
- else:
- prompt += f"[INST] {system_message} [/INST]"
- return
-
- for i in range(2 if has_system else 0, len(msgs)):
- if msgs[i]["role"] == "user":
- prompt += f"[INST] {msgs[i]['content']} [/INST]"
- else:
- prompt += msgs[i]["content"]
-
- return prompt
-
async def stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
@@ -107,37 +71,36 @@ class Ollama(LLM):
functions=None,
system_message=self.system_message,
)
- prompt = self.convert_to_chat(messages)
+ prompt = llama2_template_messages(messages)
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
- "prompt": prompt,
+ "template": prompt,
"model": self.model,
+ "system": self.system_message,
+ "options": {"temperature": args["temperature"]},
},
) as resp:
url_decode_buffer = ""
async for line in resp.content.iter_any():
if line:
- try:
- json_chunk = line.decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- url_decode_buffer += j["response"]
-
- if (
- "&" in url_decode_buffer
- and url_decode_buffer.index("&")
- > len(url_decode_buffer) - 5
- ):
- continue
- yield urllib.parse.unquote(url_decode_buffer)
- url_decode_buffer = ""
- except:
- raise Exception(str(line[0]))
+ json_chunk = line.decode("utf-8")
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ j = json.loads(chunk)
+ if "response" in j:
+ url_decode_buffer += j["response"]
+
+ if (
+ "&" in url_decode_buffer
+ and url_decode_buffer.index("&")
+ > len(url_decode_buffer) - 5
+ ):
+ continue
+ yield urllib.parse.unquote(url_decode_buffer)
+ url_decode_buffer = ""
async def stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
@@ -152,67 +115,56 @@ class Ollama(LLM):
functions=None,
system_message=self.system_message,
)
- prompt = self.convert_to_chat(messages)
+ prompt = llama2_template_messages(messages)
- self.write_log(f"Prompt: {prompt}")
+ self.write_log(f"Prompt:\n{prompt}")
+ completion = ""
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
- "prompt": prompt,
+ "template": prompt,
"model": self.model,
+ "system": self.system_message,
+ "options": {"temperature": args["temperature"]},
},
) as resp:
- # This is streaming application/json instaed of text/event-stream
- url_decode_buffer = ""
async for line in resp.content.iter_chunks():
if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- url_decode_buffer += j["response"]
- if (
- "&" in url_decode_buffer
- and url_decode_buffer.index("&")
- > len(url_decode_buffer) - 5
- ):
- continue
- yield {
- "role": "assistant",
- "content": urllib.parse.unquote(
- url_decode_buffer
- ),
- }
- url_decode_buffer = ""
- except:
- raise Exception(str(line[0]))
+ json_chunk = line[0].decode("utf-8")
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ j = json.loads(chunk)
+ if "response" in j:
+ yield {
+ "role": "assistant",
+ "content": j["response"],
+ }
+ completion += j["response"]
+ self.write_log(f"Completion:\n{completion}")
async def complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
completion = ""
-
+ args = {**self.default_args, **kwargs}
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
- "prompt": prompt,
+ "template": prompt,
"model": self.model,
+ "system": self.system_message,
+ "options": {"temperature": args["temperature"]},
},
) as resp:
async for line in resp.content.iter_any():
if line:
- try:
- json_chunk = line.decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- completion += urllib.parse.unquote(j["response"])
- except:
- raise Exception(str(line[0]))
+ json_chunk = line.decode("utf-8")
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ j = json.loads(chunk)
+ if "response" in j:
+ completion += urllib.parse.unquote(j["response"])
return completion
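
For reference, the updated Ollama methods above all POST the same body shape to /api/generate; the values below are invented, and only the key names come from the diff.

# Illustrative payload for the updated /api/generate requests (values invented).
# The rendered llama2-style prompt now travels under "template" instead of "prompt",
# and the system message and temperature are passed through explicitly.
payload = {
    "template": "[INST] Write a haiku about version control. [/INST]",
    "model": "codellama",
    "system": "You are a helpful coding assistant.",
    "options": {"temperature": 0.5},
}
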
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
index 110dfaae..c7c208c0 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/chat.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -7,6 +7,11 @@ def llama2_template_messages(msgs: ChatMessage) -> str:
if len(msgs) == 0:
return ""
+ if msgs[0]["role"] == "assistant":
+ # These models aren't trained to handle an assistant message coming first,
+ # and typically these are just introduction messages from Continue
+ msgs.pop(0)
+
prompt = ""
has_system = msgs[0]["role"] == "system"
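
An invented example of the new guard in llama2_template_messages: a leading assistant message, typically Continue's introduction, is dropped before the prompt is assembled.

# Invented input: the introductory assistant message is popped by the new guard,
# so templating proceeds as if the conversation began with the user message.
msgs = [
    {"role": "assistant", "content": "Hi, I'm Continue. How can I help?"},
    {"role": "user", "content": "Explain what this traceback means."},
]
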
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 3dd4646c..eed43054 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -5,12 +5,14 @@ import redbaron
from .paths import getConfigFilePath
+
def get_config_source():
config_file_path = getConfigFilePath()
with open(config_file_path, "r") as file:
source_code = file.read()
return source_code
+
def load_red():
source_code = get_config_source()
@@ -104,6 +106,8 @@ def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
def create_string_node(string: str) -> redbaron.RedBaron:
+ if "\n" in string:
+ return redbaron.RedBaron(f'"""{string}"""')[0]
return redbaron.RedBaron(f'"{string}"')[0]
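
A brief illustration of why the multi-line branch in create_string_node is needed: a literal newline cannot appear inside an ordinary double-quoted Python string, so multi-line config values are now emitted as triple-quoted strings. Example values are invented.

# Invented examples of the updated create_string_node behaviour:
create_string_node("one line")        # parsed from  "one line"
create_string_node("line 1\nline 2")  # parsed from  """line 1
                                      #              line 2"""  (triple-quoted)
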
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index a411c5c3..b3e9ecc1 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -60,13 +60,6 @@ def getConfigFilePath() -> str:
if existing_content.strip() == "":
with open(path, "w") as f:
f.write(default_config)
- elif " continuedev.core" in existing_content:
- with open(path, "w") as f:
- f.write(
- existing_content.replace(
- " continuedev.", " continuedev.src.continuedev."
- )
- )
return path
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index cbd94fe2..857183bc 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -1,3 +1,4 @@
+import html
import json
import os
from textwrap import dedent
@@ -24,6 +25,12 @@ openai.api_key = OPENAI_API_KEY
FREE_USAGE_STEP_NAME = "Please enter OpenAI API key"
+def add_ellipsis(text: str, max_length: int = 200) -> str:
+ if len(text) > max_length:
+ return text[: max_length - 3] + "..."
+ return text
+
+
class SimpleChatStep(Step):
name: str = "Generating Response..."
manage_own_chat_context: bool = True
@@ -91,13 +98,26 @@ class SimpleChatStep(Step):
if "content" in chunk:
self.description += chunk["content"]
+
+ # HTML-unescape the tail of the streamed description if it contains an entity
+ end_size = len(chunk["content"]) - 6
+ if "&" in self.description[-end_size:]:
+ self.description = self.description[:-end_size] + html.unescape(
+ self.description[-end_size:]
+ )
+
await sdk.update_ui()
- self.name = remove_quotes_and_escapes(
- await sdk.models.medium.complete(
- f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
- max_tokens=20,
- )
+ self.name = "Generating title..."
+ await sdk.update_ui()
+ self.name = add_ellipsis(
+ remove_quotes_and_escapes(
+ await sdk.models.medium.complete(
+ f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
+ max_tokens=20,
+ )
+ ),
+ 200,
)
self.chat_context.append(
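
The new add_ellipsis helper simply truncates over-long titles; two invented calls for illustration:

# Invented examples of the new add_ellipsis helper:
add_ellipsis("A short title")   # returned unchanged (under the 200-character default)
add_ellipsis("x" * 300, 200)    # 197 "x" characters followed by "..."
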
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 08c5efc5..dbf9ba0d 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -182,7 +182,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
name = "continue_logs.txt"
logs = "\n\n############################################\n\n".join(
[
- "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
+ "This is a log of the prompt/completion pairs sent/received from the LLM during this step"
]
+ self.session.autopilot.continue_sdk.history.timeline[index].logs
)
@@ -244,8 +244,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
if prev_model is not None:
exists = False
for other in unused_models:
- if display_llm_class(prev_model) == display_llm_class(
- other
+ if (
+ prev_model.__class__.__name__
+ == other.__class__.__name__
+ and (
+ not other.name.startswith("gpt")
+ or prev_model.name == other.name
+ )
):
exists = True
break
diff --git a/extension/react-app/src/components/Layout.tsx b/extension/react-app/src/components/Layout.tsx
index 500dc921..c328a206 100644
--- a/extension/react-app/src/components/Layout.tsx
+++ b/extension/react-app/src/components/Layout.tsx
@@ -101,7 +101,7 @@ const Layout = () => {
if (event.metaKey && event.altKey && event.code === "KeyN") {
client?.loadSession(undefined);
}
- if (event.metaKey && event.code === "KeyC") {
+ if ((event.metaKey || event.ctrlKey) && event.code === "KeyC") {
const selection = window.getSelection()?.toString();
if (selection) {
// Copy to clipboard