-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py        |  2
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py          | 27
-rw-r--r--  continuedev/src/continuedev/libs/util/templating.py     | 33
-rw-r--r--  continuedev/src/continuedev/plugins/steps/core/core.py  | 35
4 files changed, 79 insertions(+), 18 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 294e2c8b..90ef7934 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -12,6 +12,8 @@ class LLM(ContinueBaseModel, ABC):
system_message: Optional[str] = None
+ prompt_templates: dict = {}
+
class Config:
arbitrary_types_allowed = True
extra = "allow"
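The empty default added above is what downstream steps key off: a model that does not override prompt_templates falls back to the generic prompt path. A minimal sketch of that lookup, assuming only the field from this hunk (the class name is illustrative, not from the repo):

# Hedged sketch: what the empty prompt_templates default buys callers.
class SomeLLM:
    prompt_templates: dict = {}

template = SomeLLM().prompt_templates.get("edit")
if template is None:
    # No model-specific "edit" template registered; use the default edit prompt.
    pass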
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index 5647b702..df2b2238 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -16,12 +16,28 @@ class Ollama(LLM):
max_context_length: int = 2048
_client_session: aiohttp.ClientSession = None
+ requires_write_log = True
+
+ prompt_templates = {
+ "edit": dedent(
+ """\
+ [INST] Consider the following code:
+ ```
+ {{code_to_edit}}
+ ```
+ Edit the code to perfectly satisfy the following user request:
+ {{user_input}}
+ Output nothing except for the code. No code block, no English explanation, no start/end tags.
+ [/INST]"""
+ ),
+ }
class Config:
arbitrary_types_allowed = True
- async def start(self, **kwargs):
+ async def start(self, write_log, **kwargs):
self._client_session = aiohttp.ClientSession()
+ self.write_log = write_log
async def stop(self):
await self._client_session.close()
@@ -47,6 +63,14 @@ class Ollama(LLM):
prompt = ""
has_system = msgs[0]["role"] == "system"
+ if has_system and msgs[0]["content"] == "":
+ has_system = False
+ msgs.pop(0)
+
+ # TODO: Instead make stream_complete and stream_chat the same method.
+ if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
+ return msgs[0]["content"]
+
if has_system:
system_message = dedent(
f"""\
@@ -120,6 +144,7 @@ class Ollama(LLM):
)
prompt = self.convert_to_chat(messages)
+ self.write_log(f"Prompt: {prompt}")
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
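The two new checks in convert_to_chat matter because the edit step can now hand Ollama an already rendered [INST] prompt; re-wrapping it in a chat transcript would double-template it. A standalone sketch of that pass-through behaviour, with an illustrative function name that is not from the repo:

# Hedged sketch (not the repo's code): drop an empty system message, then forward
# a single already-[INST]-templated message verbatim instead of re-wrapping it.
from typing import Dict, List, Optional

def passthrough_prompt(msgs: List[Dict[str, str]]) -> Optional[str]:
    """Return the raw prompt when it is already [INST]-templated, else None."""
    if msgs and msgs[0]["role"] == "system" and msgs[0]["content"] == "":
        msgs = msgs[1:]  # an empty system message should not block the check
    if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
        return msgs[0]["content"]
    return None

rendered = "[INST] fix the bug [/INST]"
assert passthrough_prompt([{"role": "user", "content": rendered}]) == rendered
assert passthrough_prompt([{"role": "system", "content": "be terse"},
                           {"role": "user", "content": "hi"}]) is None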
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
index edcf2884..8d6a32fc 100644
--- a/continuedev/src/continuedev/libs/util/templating.py
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -1,7 +1,10 @@
import os
+from typing import Callable, Dict, List, Union
import chevron
+from ...core.main import ChatMessage
+
def get_vars_in_template(template):
"""
@@ -41,3 +44,33 @@ def render_templated_string(template: str) -> str:
args[escaped_var] = ""
return chevron.render(template, args)
+
+
+"""
+A PromptTemplate can either be a template string (mustache syntax, e.g. {{user_input}}) or
+a function which takes the history and a dictionary of additional key-value pairs and returns
+either a string or a list of ChatMessages.
+If a string is returned, it is assumed that the chat history should be ignored.
+"""
+PromptTemplate = Union[
+    str, Callable[[List[ChatMessage], Dict[str, str]], Union[str, List[ChatMessage]]]
+]
+
+
+def render_prompt_template(
+    template: PromptTemplate, history: List[ChatMessage], other_data: Dict[str, str]
+) -> Union[str, List[ChatMessage]]:
+ """
+ Render a prompt template.
+ """
+ if isinstance(template, str):
+ data = {
+ "history": history,
+ **other_data,
+ }
+ if len(history) > 0 and history[0].role == "system":
+ data["system_message"] = history.pop(0).content
+
+ return chevron.render(template, data)
+ else:
+ return template(history, other_data)
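A hedged usage sketch of the new helper, assuming ChatMessage takes role/content/summary as it does elsewhere in this diff and that the import paths match the file layout shown here:

from continuedev.src.continuedev.core.main import ChatMessage
from continuedev.src.continuedev.libs.util.templating import render_prompt_template

history = [
    ChatMessage(role="system", content="You are a careful code editor.", summary=""),
    ChatMessage(role="user", content="Please edit my code.", summary="edit request"),
]

# String template: the leading system message is popped into {{system_message}},
# the rest is plain chevron substitution. Pass a copy because the helper pops
# from the list it is given.
prompt = render_prompt_template(
    "{{system_message}}\n[INST] {{code_to_edit}}\n{{user_input}} [/INST]",
    list(history),
    {"code_to_edit": "x = 1", "user_input": "rename x to count"},
)

# Callable template: receives the history plus the extra data and may return a
# list of ChatMessages instead of a single string.
def keep_history(msgs, other_data):
    return msgs + [ChatMessage(role="user", content=other_data["user_input"], summary="")]

messages = render_prompt_template(keep_history, list(history), {"user_input": "rename x"})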
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index fe4b8a61..212746f4 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -8,15 +8,14 @@ from pydantic import validator
from ....core.main import ChatMessage, ContinueCustomException, Step
from ....core.observation import Observation, TextObservation, UserInputObservation
-from ....libs.llm.anthropic import AnthropicLLM
from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
-from ....libs.llm.openai import OpenAI
from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import (
dedent_and_get_common_whitespace,
remove_quotes_and_escapes,
)
from ....libs.util.telemetry import posthog_logger
+from ....libs.util.templating import render_prompt_template
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
from ....models.filesystem_edit import (
EditDiff,
@@ -639,21 +638,23 @@ Please output the code to be inserted at the cursor in order to fulfill the user
repeating_file_suffix = False
line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]
- if not (
- isinstance(model_to_use, OpenAI)
- or isinstance(model_to_use, MaybeProxyOpenAI)
- or isinstance(model_to_use, AnthropicLLM)
- ):
- messages = [
- ChatMessage(
- role="user",
- content=f'```\n{rif.contents}\n```\n\nUser request: "{self.user_input}"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n',
- summary=self.user_input,
- )
- ]
- # elif isinstance(model_to_use, ReplicateLLM):
- # messages = [ChatMessage(
- # role="user", content=f"// Previous implementation\n\n{rif.contents}\n\n// Updated implementation (after following directions: {self.user_input})\n\n", summary=self.user_input)]
+ # Use custom templates defined by the model
+ if template := model_to_use.prompt_templates.get("edit"):
+ rendered = render_prompt_template(
+ template,
+ messages[:-1],
+ {"code_to_edit": rif.contents, "user_input": self.user_input},
+ )
+ if isinstance(rendered, str):
+ messages = [
+ ChatMessage(
+ role="user",
+ content=rendered,
+ summary=self.user_input,
+ )
+ ]
+ else:
+ messages = rendered
generator = model_to_use.stream_chat(
messages, temperature=sdk.config.temperature, max_tokens=max_tokens
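Putting the pieces together, here is a hedged end-to-end sketch of the new edit path: the step looks up the model's "edit" template and, for a string template with an empty history, render_prompt_template reduces to plain chevron rendering. Only chevron is assumed; the template text is copied from the Ollama hunk above, and the sample code/user input are illustrative.

import chevron  # the same mustache renderer used by render_prompt_template

edit_template = (
    "[INST] Consider the following code:\n"
    "```\n{{code_to_edit}}\n```\n"
    "Edit the code to perfectly satisfy the following user request:\n"
    "{{user_input}}\n"
    "Output nothing except for the code. No code block, no English explanation, "
    "no start/end tags.\n[/INST]"
)

prompt = chevron.render(
    edit_template,
    {"code_to_edit": "def add(a, b):\n    return a - b", "user_input": "fix the bug"},
)
# The single [INST]-wrapped message produced here is exactly what
# Ollama.convert_to_chat now passes through untouched.
print(prompt)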