-rw-r--r--   continuedev/src/continuedev/libs/llm/__init__.py       |  2
-rw-r--r--   continuedev/src/continuedev/libs/llm/ollama.py         | 27
-rw-r--r--   continuedev/src/continuedev/libs/util/templating.py    | 33
-rw-r--r--   continuedev/src/continuedev/plugins/steps/core/core.py | 35
4 files changed, 79 insertions, 18 deletions
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 294e2c8b..90ef7934 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -12,6 +12,8 @@ class LLM(ContinueBaseModel, ABC):
     system_message: Optional[str] = None
 
+    prompt_templates: dict = {}
+
     class Config:
         arbitrary_types_allowed = True
         extra = "allow"
 
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index 5647b702..df2b2238 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -16,12 +16,28 @@ class Ollama(LLM):
     max_context_length: int = 2048
 
     _client_session: aiohttp.ClientSession = None
+    requires_write_log = True
+
+    prompt_templates = {
+        "edit": dedent(
+            """\
+            [INST] Consider the following code:
+            ```
+            {{code_to_edit}}
+            ```
+            Edit the code to perfectly satisfy the following user request:
+            {{user_input}}
+            Output nothing except for the code. No code block, no English explanation, no start/end tags.
+            [/INST]"""
+        ),
+    }
 
     class Config:
         arbitrary_types_allowed = True
 
-    async def start(self, **kwargs):
+    async def start(self, write_log, **kwargs):
         self._client_session = aiohttp.ClientSession()
+        self.write_log = write_log
 
     async def stop(self):
         await self._client_session.close()
@@ -47,6 +63,14 @@ class Ollama(LLM):
 
         prompt = ""
         has_system = msgs[0]["role"] == "system"
+        if has_system and msgs[0]["content"] == "":
+            has_system = False
+            msgs.pop(0)
+
+        # TODO: Instead make stream_complete and stream_chat the same method.
+        if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
+            return msgs[0]["content"]
+
         if has_system:
             system_message = dedent(
                 f"""\
@@ -120,6 +144,7 @@ class Ollama(LLM):
         )
         prompt = self.convert_to_chat(messages)
+        self.write_log(f"Prompt: {prompt}")
 
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
index edcf2884..8d6a32fc 100644
--- a/continuedev/src/continuedev/libs/util/templating.py
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -1,7 +1,10 @@
 import os
+from typing import Callable, Dict, List, Union
 
 import chevron
 
+from ...core.main import ChatMessage
+
 
 def get_vars_in_template(template):
     """
@@ -41,3 +44,33 @@ def render_templated_string(template: str) -> str:
             args[escaped_var] = ""
 
     return chevron.render(template, args)
+
+
+"""
+A PromptTemplate can either be a template string (mustache syntax, e.g. {{user_input}}) or
+a function which takes the history and a dictionary of additional key-value pairs and returns
+either a string or a list of ChatMessages.
+If a string is returned, it will be assumed that the chat history should be ignored
+"""
+PromptTemplate = Union[
+    str, Callable[[ChatMessage, Dict[str, str]], Union[str, List[ChatMessage]]]
+]
+
+
+def render_prompt_template(
+    template: PromptTemplate, history: List[ChatMessage], other_data: Dict[str, str]
+) -> str:
+    """
+    Render a prompt template.
+    """
+    if isinstance(template, str):
+        data = {
+            "history": history,
+            **other_data,
+        }
+        if len(history) > 0 and history[0].role == "system":
+            data["system_message"] = history.pop(0).content
+
+        return chevron.render(template, data)
+    else:
+        return template(history, other_data)
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index fe4b8a61..212746f4 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -8,15 +8,14 @@ from pydantic import validator
 
 from ....core.main import ChatMessage, ContinueCustomException, Step
 from ....core.observation import Observation, TextObservation, UserInputObservation
-from ....libs.llm.anthropic import AnthropicLLM
 from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
-from ....libs.llm.openai import OpenAI
 from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
 from ....libs.util.strings import (
     dedent_and_get_common_whitespace,
     remove_quotes_and_escapes,
 )
 from ....libs.util.telemetry import posthog_logger
+from ....libs.util.templating import render_prompt_template
 from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ....models.filesystem_edit import (
     EditDiff,
@@ -639,21 +638,23 @@ Please output the code to be inserted at the cursor in order to fulfill the user
                 repeating_file_suffix = False
                 line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]
 
-        if not (
-            isinstance(model_to_use, OpenAI)
-            or isinstance(model_to_use, MaybeProxyOpenAI)
-            or isinstance(model_to_use, AnthropicLLM)
-        ):
-            messages = [
-                ChatMessage(
-                    role="user",
-                    content=f'```\n{rif.contents}\n```\n\nUser request: "{self.user_input}"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n',
-                    summary=self.user_input,
-                )
-            ]
-        # elif isinstance(model_to_use, ReplicateLLM):
-        #     messages = [ChatMessage(
-        #         role="user", content=f"// Previous implementation\n\n{rif.contents}\n\n// Updated implementation (after following directions: {self.user_input})\n\n", summary=self.user_input)]
+        # Use custom templates defined by the model
+        if template := model_to_use.prompt_templates.get("edit"):
+            rendered = render_prompt_template(
+                template,
+                messages[:-1],
+                {"code_to_edit": rif.contents, "user_input": self.user_input},
+            )
+            if isinstance(rendered, str):
+                messages = [
+                    ChatMessage(
+                        role="user",
+                        content=rendered,
+                        summary=self.user_input,
+                    )
+                ]
+            else:
+                messages = rendered
 
         generator = model_to_use.stream_chat(
             messages, temperature=sdk.config.temperature, max_tokens=max_tokens