author     Nate Sesti <sestinj@gmail.com>    2023-08-30 19:55:18 -0700
committer  Nate Sesti <sestinj@gmail.com>    2023-08-30 19:55:18 -0700
commit     28f5d7bedab05a8b061e4e7ee9055a5403786bbc (patch)
tree       8e32e9a0edcddf3dd3bf5dbf76e14fb09b15ca8e /continuedev/src
parent     a0e2e2d3d606d8bf465eac541a84aa57316ee271 (diff)
fix: :art: many small improvements
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/sdk.py               |   2
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py     |  19
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py          |  10
-rw-r--r--  continuedev/src/continuedev/libs/llm/llamacpp.py      |  12
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py        | 149
-rw-r--r--  continuedev/src/continuedev/libs/llm/prompts/chat.py  |   5
-rw-r--r--  continuedev/src/continuedev/libs/util/edit_config.py  |   4
-rw-r--r--  continuedev/src/continuedev/libs/util/paths.py        |   7
-rw-r--r--  continuedev/src/continuedev/server/gui.py             |  11
9 files changed, 84 insertions(+), 135 deletions(-)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 778f81b3..37992b67 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -228,6 +228,8 @@ class ContinueSDK(AbstractContinueSDK):
spec.loader.exec_module(config)
self._last_valid_config = config.config
+ logger.debug("Loaded Continue config file from %s", path)
+
return config.config
def get_code_context(
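Context for the change above: the config file is loaded as a regular Python module through importlib, and the new logger.debug line records which file was actually picked up. A minimal sketch of that loading pattern, assuming a module-level logger (names here are illustrative, not the exact sdk.py code):

import importlib.util
import logging

logger = logging.getLogger(__name__)


def load_config_module(path: str):
    # Build a module spec from the config file path and execute it in place.
    spec = importlib.util.spec_from_file_location("continue_config", path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    logger.debug("Loaded Continue config file from %s", path)
    # The file is expected to define a top-level `config` object.
    return config_module.config
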
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 9a7d0ac9..16bc2fce 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -4,12 +4,7 @@ from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import (
- DEFAULT_ARGS,
- compile_chat_messages,
- count_tokens,
- format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
class AnthropicLLM(LLM):
@@ -118,9 +113,10 @@ class AnthropicLLM(LLM):
)
completion = ""
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ prompt = self.__messages_to_prompt(messages)
+ self.write_log(f"Prompt: \n\n{prompt}")
async for chunk in await self._async_client.completions.create(
- prompt=self.__messages_to_prompt(messages), **args
+ prompt=prompt, **args
):
yield {"role": "assistant", "content": chunk.completion}
completion += chunk.completion
@@ -143,11 +139,10 @@ class AnthropicLLM(LLM):
system_message=self.system_message,
)
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ prompt = self.__messages_to_prompt(messages)
+ self.write_log(f"Prompt: \n\n{prompt}")
resp = (
- await self._async_client.completions.create(
- prompt=self.__messages_to_prompt(messages), **args
- )
+ await self._async_client.completions.create(prompt=prompt, **args)
).completion
self.write_log(f"Completion: \n\n{resp}")
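The refactor above builds the Anthropic-style prompt once with __messages_to_prompt and reuses the same string for the log line and the completions call, which also lets format_chat_messages be dropped from the imports. A simplified sketch of what such a message-to-prompt conversion looks like (not the exact method on AnthropicLLM):

from anthropic import AI_PROMPT, HUMAN_PROMPT  # "\n\nAssistant:" and "\n\nHuman:"


def messages_to_prompt(messages: list) -> str:
    # Flatten chat messages into Anthropic's Human/Assistant turn format.
    prompt = ""
    for msg in messages:
        if msg["role"] == "assistant":
            prompt += f"{AI_PROMPT} {msg['content']}"
        else:
            prompt += f"{HUMAN_PROMPT} {msg['content']}"
    # End with the Assistant marker so the model continues from there.
    return prompt + AI_PROMPT
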
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index be82c445..db3aaed7 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -147,16 +147,6 @@ class GGML(LLM):
) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- # messages = compile_chat_messages(
- # args["model"],
- # with_history,
- # self.context_length,
- # args["max_tokens"],
- # prompt,
- # functions=None,
- # system_message=self.system_message,
- # )
-
self.write_log(f"Prompt: \n\n{prompt}")
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 9e424fde..6625065e 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -6,12 +6,7 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import (
- DEFAULT_ARGS,
- compile_chat_messages,
- count_tokens,
- format_chat_messages,
-)
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
from .prompts.chat import code_llama_template_messages
@@ -108,7 +103,8 @@ class LlamaCpp(LLM):
system_message=self.system_message,
)
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ prompt = self.convert_to_chat(messages)
+ self.write_log(f"Prompt: \n\n{prompt}")
completion = ""
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
@@ -116,7 +112,7 @@ class LlamaCpp(LLM):
async with client_session.post(
f"{self.server_url}/completion",
json={
- "prompt": self.convert_to_chat(messages),
+ "prompt": prompt,
**self._transform_args(args),
},
headers={"Content-Type": "application/json"},
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index c754e54d..240d922b 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -8,6 +8,7 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from .prompts.chat import llama2_template_messages
class Ollama(LLM):
@@ -57,43 +58,6 @@ class Ollama(LLM):
def count_tokens(self, text: str):
return count_tokens(self.name, text)
- def convert_to_chat(self, msgs: ChatMessage) -> str:
- if len(msgs) == 0:
- return ""
-
- prompt = ""
- has_system = msgs[0]["role"] == "system"
- if has_system and msgs[0]["content"] == "":
- has_system = False
- msgs.pop(0)
-
- # TODO: Instead make stream_complete and stream_chat the same method.
- if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
- return msgs[0]["content"]
-
- if has_system:
- system_message = dedent(
- f"""\
- <<SYS>>
- {self.system_message}
- <</SYS>>
-
- """
- )
- if len(msgs) > 1:
- prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
- else:
- prompt += f"[INST] {system_message} [/INST]"
- return
-
- for i in range(2 if has_system else 0, len(msgs)):
- if msgs[i]["role"] == "user":
- prompt += f"[INST] {msgs[i]['content']} [/INST]"
- else:
- prompt += msgs[i]["content"]
-
- return prompt
-
async def stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
@@ -107,37 +71,36 @@ class Ollama(LLM):
functions=None,
system_message=self.system_message,
)
- prompt = self.convert_to_chat(messages)
+ prompt = llama2_template_messages(messages)
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
- "prompt": prompt,
+ "template": prompt,
"model": self.model,
+ "system": self.system_message,
+ "options": {"temperature": args["temperature"]},
},
) as resp:
url_decode_buffer = ""
async for line in resp.content.iter_any():
if line:
- try:
- json_chunk = line.decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- url_decode_buffer += j["response"]
-
- if (
- "&" in url_decode_buffer
- and url_decode_buffer.index("&")
- > len(url_decode_buffer) - 5
- ):
- continue
- yield urllib.parse.unquote(url_decode_buffer)
- url_decode_buffer = ""
- except:
- raise Exception(str(line[0]))
+ json_chunk = line.decode("utf-8")
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ j = json.loads(chunk)
+ if "response" in j:
+ url_decode_buffer += j["response"]
+
+ if (
+ "&" in url_decode_buffer
+ and url_decode_buffer.index("&")
+ > len(url_decode_buffer) - 5
+ ):
+ continue
+ yield urllib.parse.unquote(url_decode_buffer)
+ url_decode_buffer = ""
async def stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
@@ -152,67 +115,63 @@ class Ollama(LLM):
functions=None,
system_message=self.system_message,
)
- prompt = self.convert_to_chat(messages)
+ prompt = llama2_template_messages(messages)
self.write_log(f"Prompt: {prompt}")
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
- "prompt": prompt,
+ "template": prompt,
"model": self.model,
+ "system": self.system_message,
+ "options": {"temperature": args["temperature"]},
},
) as resp:
# This is streaming application/json instead of text/event-stream
url_decode_buffer = ""
async for line in resp.content.iter_chunks():
if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- url_decode_buffer += j["response"]
- if (
- "&" in url_decode_buffer
- and url_decode_buffer.index("&")
- > len(url_decode_buffer) - 5
- ):
- continue
- yield {
- "role": "assistant",
- "content": urllib.parse.unquote(
- url_decode_buffer
- ),
- }
- url_decode_buffer = ""
- except:
- raise Exception(str(line[0]))
+ json_chunk = line[0].decode("utf-8")
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ j = json.loads(chunk)
+ if "response" in j:
+ url_decode_buffer += j["response"]
+ if (
+ "&" in url_decode_buffer
+ and url_decode_buffer.index("&")
+ > len(url_decode_buffer) - 5
+ ):
+ continue
+ yield {
+ "role": "assistant",
+ "content": urllib.parse.unquote(url_decode_buffer),
+ }
+ url_decode_buffer = ""
async def complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
completion = ""
-
+ args = {**self.default_args, **kwargs}
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
- "prompt": prompt,
+ "template": prompt,
"model": self.model,
+ "system": self.system_message,
+ "options": {"temperature": args["temperature"]},
},
) as resp:
async for line in resp.content.iter_any():
if line:
- try:
- json_chunk = line.decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- completion += urllib.parse.unquote(j["response"])
- except:
- raise Exception(str(line[0]))
+ json_chunk = line.decode("utf-8")
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ j = json.loads(chunk)
+ if "response" in j:
+ completion += urllib.parse.unquote(j["response"])
return completion
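Ollama's /api/generate endpoint streams newline-delimited JSON objects, each carrying a partial "response" string; the loops above split every network chunk on newlines, parse each object, and buffer the decoded text, holding back output while the tail of the buffer may still be an incomplete escape sequence. A stripped-down sketch of that streaming parse (the request fields mirror the diff, and the URL-unquoting mirrors the handling above; buffering is omitted):

import json
import urllib.parse

import aiohttp


async def stream_generate(session: aiohttp.ClientSession, server_url: str,
                          model: str, template: str):
    async with session.post(
        f"{server_url}/api/generate",
        json={"template": template, "model": model},
    ) as resp:
        # Each network chunk may contain several newline-delimited JSON objects.
        async for raw in resp.content.iter_any():
            for line in raw.decode("utf-8").split("\n"):
                if line.strip() == "":
                    continue
                obj = json.loads(line)
                if "response" in obj:
                    # Decode the streamed text the same way the code above does.
                    yield urllib.parse.unquote(obj["response"])
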
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
index 110dfaae..c7c208c0 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/chat.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -7,6 +7,11 @@ def llama2_template_messages(msgs: ChatMessage) -> str:
if len(msgs) == 0:
return ""
+ if msgs[0]["role"] == "assistant":
+ # These models aren't trained to handle an assistant message coming first,
+ # and typically these are just introduction messages from Continue
+ msgs.pop(0)
+
prompt = ""
has_system = msgs[0]["role"] == "system"
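The new guard simply drops a leading assistant turn (typically Continue's introduction message) before the remaining messages are folded into Llama-2's [INST] ... [/INST] chat format. A simplified sketch of what such a template function does (not the exact prompts/chat.py implementation):

def llama2_template(msgs: list) -> str:
    # Llama-2 chat expects the conversation to start with a user (or system) turn.
    if msgs and msgs[0]["role"] == "assistant":
        msgs = msgs[1:]

    prompt = ""
    # An optional system message is wrapped in <<SYS>> tags inside the first [INST] block.
    if msgs and msgs[0]["role"] == "system":
        prompt += f"[INST] <<SYS>>\n{msgs[0]['content']}\n<</SYS>>\n\n"
        msgs = msgs[1:]
        if msgs:
            prompt += f"{msgs[0]['content']} [/INST]"
            msgs = msgs[1:]
        else:
            prompt += "[/INST]"

    for msg in msgs:
        if msg["role"] == "user":
            prompt += f"[INST] {msg['content']} [/INST]"
        else:
            prompt += msg["content"]
    return prompt
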
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 3dd4646c..eed43054 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -5,12 +5,14 @@ import redbaron
from .paths import getConfigFilePath
+
def get_config_source():
config_file_path = getConfigFilePath()
with open(config_file_path, "r") as file:
source_code = file.read()
return source_code
+
def load_red():
source_code = get_config_source()
@@ -104,6 +106,8 @@ def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
def create_string_node(string: str) -> redbaron.RedBaron:
+ if "\n" in string:
+ return redbaron.RedBaron(f'"""{string}"""')[0]
return redbaron.RedBaron(f'"{string}"')[0]
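The triple-quote branch matters because redbaron parses the generated literal back as Python source, and a value containing a newline is not valid inside ordinary double quotes. A small usage sketch of the helper as changed above:

import redbaron

# Multi-line values need triple quotes to stay valid Python source.
single = redbaron.RedBaron('"hello"')[0]
multi = redbaron.RedBaron('"""first line\nsecond line"""')[0]
print(single.dumps())  # "hello"
print(multi.dumps())   # a triple-quoted literal spanning two lines
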
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index a411c5c3..b3e9ecc1 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -60,13 +60,6 @@ def getConfigFilePath() -> str:
if existing_content.strip() == "":
with open(path, "w") as f:
f.write(default_config)
- elif " continuedev.core" in existing_content:
- with open(path, "w") as f:
- f.write(
- existing_content.replace(
- " continuedev.", " continuedev.src.continuedev."
- )
- )
return path
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 08c5efc5..dbf9ba0d 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -182,7 +182,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
name = "continue_logs.txt"
logs = "\n\n############################################\n\n".join(
[
- "This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"
+ "This is a log of the prompt/completion pairs sent/received from the LLM during this step"
]
+ self.session.autopilot.continue_sdk.history.timeline[index].logs
)
@@ -244,8 +244,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
if prev_model is not None:
exists = False
for other in unused_models:
- if display_llm_class(prev_model) == display_llm_class(
- other
+ if (
+ prev_model.__class__.__name__
+ == other.__class__.__name__
+ and (
+ not other.name.startswith("gpt")
+ or prev_model.name == other.name
+ )
):
exists = True
break
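The replacement comparison treats two saved models as interchangeable when they are instances of the same LLM class, and additionally requires matching model names only for GPT-family models, so that, for example, a gpt-3.5-turbo entry is not silently reused in place of gpt-4. A standalone sketch of that predicate (the function name is illustrative, not from gui.py):

def is_same_saved_model(prev_model, other) -> bool:
    # The provider class must match in every case.
    if prev_model.__class__.__name__ != other.__class__.__name__:
        return False
    # GPT models must also match by name; other providers match on class alone.
    if other.name.startswith("gpt"):
        return prev_model.name == other.name
    return True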