Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 2 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ollama.py | 27 |
| -rw-r--r-- | continuedev/src/continuedev/libs/util/templating.py | 33 |
| -rw-r--r-- | continuedev/src/continuedev/plugins/steps/core/core.py | 35 |
4 files changed, 79 insertions, 18 deletions
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 294e2c8b..90ef7934 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -12,6 +12,8 @@ class LLM(ContinueBaseModel, ABC):
     system_message: Optional[str] = None
 
+    prompt_templates: dict = {}
+
     class Config:
         arbitrary_types_allowed = True
         extra = "allow"
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index 5647b702..df2b2238 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -16,12 +16,28 @@ class Ollama(LLM):
     max_context_length: int = 2048
 
     _client_session: aiohttp.ClientSession = None
+    requires_write_log = True
+
+    prompt_templates = {
+        "edit": dedent(
+            """\
+            [INST] Consider the following code:
+            ```
+            {{code_to_edit}}
+            ```
+            Edit the code to perfectly satisfy the following user request:
+            {{user_input}}
+            Output nothing except for the code. No code block, no English explanation, no start/end tags.
+            [/INST]"""
+        ),
+    }
 
     class Config:
         arbitrary_types_allowed = True
 
-    async def start(self, **kwargs):
+    async def start(self, write_log, **kwargs):
         self._client_session = aiohttp.ClientSession()
+        self.write_log = write_log
 
     async def stop(self):
         await self._client_session.close()
@@ -47,6 +63,14 @@ class Ollama(LLM):
         prompt = ""
         has_system = msgs[0]["role"] == "system"
 
+        if has_system and msgs[0]["content"] == "":
+            has_system = False
+            msgs.pop(0)
+
+        # TODO: Instead make stream_complete and stream_chat the same method.
+        if len(msgs) == 1 and "[INST]" in msgs[0]["content"]:
+            return msgs[0]["content"]
+
         if has_system:
             system_message = dedent(
                 f"""\
@@ -120,6 +144,7 @@ class Ollama(LLM):
         )
 
         prompt = self.convert_to_chat(messages)
+        self.write_log(f"Prompt: {prompt}")
         async with self._client_session.post(
             f"{self.server_url}/api/generate",
             json={
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
index edcf2884..8d6a32fc 100644
--- a/continuedev/src/continuedev/libs/util/templating.py
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -1,7 +1,10 @@
 import os
+from typing import Callable, Dict, List, Union
 
 import chevron
 
+from ...core.main import ChatMessage
+
 
 def get_vars_in_template(template):
     """
@@ -41,3 +44,33 @@ def render_templated_string(template: str) -> str:
                 args[escaped_var] = ""
 
     return chevron.render(template, args)
+
+
+"""
+A PromptTemplate can either be a template string (mustache syntax, e.g. {{user_input}}) or
+a function which takes the history and a dictionary of additional key-value pairs and returns
+either a string or a list of ChatMessages.
+If a string is returned, it will be assumed that the chat history should be ignored
+"""
+PromptTemplate = Union[
+    str, Callable[[ChatMessage, Dict[str, str]], Union[str, List[ChatMessage]]]
+]
+
+
+def render_prompt_template(
+    template: PromptTemplate, history: List[ChatMessage], other_data: Dict[str, str]
+) -> str:
+    """
+    Render a prompt template.
+    """
+    if isinstance(template, str):
+        data = {
+            "history": history,
+            **other_data,
+        }
+        if len(history) > 0 and history[0].role == "system":
+            data["system_message"] = history.pop(0).content
+
+        return chevron.render(template, data)
+    else:
+        return template(history, other_data)
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index fe4b8a61..212746f4 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -8,15 +8,14 @@ from pydantic import validator
 
 from ....core.main import ChatMessage, ContinueCustomException, Step
 from ....core.observation import Observation, TextObservation, UserInputObservation
-from ....libs.llm.anthropic import AnthropicLLM
 from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
-from ....libs.llm.openai import OpenAI
 from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
 from ....libs.util.strings import (
     dedent_and_get_common_whitespace,
     remove_quotes_and_escapes,
 )
 from ....libs.util.telemetry import posthog_logger
+from ....libs.util.templating import render_prompt_template
 from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ....models.filesystem_edit import (
     EditDiff,
@@ -639,21 +638,23 @@ Please output the code to be inserted at the cursor in order to fulfill the user
         repeating_file_suffix = False
         line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]
 
-        if not (
-            isinstance(model_to_use, OpenAI)
-            or isinstance(model_to_use, MaybeProxyOpenAI)
-            or isinstance(model_to_use, AnthropicLLM)
-        ):
-            messages = [
-                ChatMessage(
-                    role="user",
-                    content=f'```\n{rif.contents}\n```\n\nUser request: "{self.user_input}"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n',
-                    summary=self.user_input,
-                )
-            ]
-        # elif isinstance(model_to_use, ReplicateLLM):
-        #     messages = [ChatMessage(
-        #         role="user", content=f"// Previous implementation\n\n{rif.contents}\n\n// Updated implementation (after following directions: {self.user_input})\n\n", summary=self.user_input)]
+        # Use custom templates defined by the model
+        if template := model_to_use.prompt_templates.get("edit"):
+            rendered = render_prompt_template(
+                template,
+                messages[:-1],
+                {"code_to_edit": rif.contents, "user_input": self.user_input},
+            )
+            if isinstance(rendered, str):
+                messages = [
+                    ChatMessage(
+                        role="user",
+                        content=rendered,
+                        summary=self.user_input,
+                    )
+                ]
+            else:
+                messages = rendered
 
         generator = model_to_use.stream_chat(
             messages, temperature=sdk.config.temperature, max_tokens=max_tokens
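
For illustration only (not part of the commit), the sketch below mirrors the string-template branch of the new render_prompt_template: a mustache template is rendered with chevron against the chat history plus any extra key-value pairs, with a leading system message split out into its own variable. The ChatMessage dataclass and EDIT_TEMPLATE here are simplified stand-ins, not continuedev's real objects; the real Ollama "edit" template additionally wraps {{code_to_edit}} in a fenced code block.

```python
# Minimal, self-contained sketch of the string-template path of render_prompt_template.
# ChatMessage and EDIT_TEMPLATE are hypothetical stand-ins, not continuedev's real objects.
from dataclasses import dataclass
from textwrap import dedent
from typing import Dict, List

import chevron  # the same mustache renderer templating.py uses


@dataclass
class ChatMessage:  # stand-in for continuedev.core.main.ChatMessage
    role: str
    content: str


# Trimmed version of Ollama's "edit" template; the real one also fences {{code_to_edit}}.
EDIT_TEMPLATE = dedent(
    """\
    [INST] Consider the following code:
    {{code_to_edit}}
    Edit the code to perfectly satisfy the following user request:
    {{user_input}}
    Output nothing except for the code.
    [/INST]"""
)


def render_string_template(
    template: str, history: List[ChatMessage], other_data: Dict[str, str]
) -> str:
    # Mirrors the isinstance(template, str) branch: expose the history and the extra
    # key-value pairs to mustache, pulling a leading system message out separately.
    data = {"history": history, **other_data}
    if len(history) > 0 and history[0].role == "system":
        data["system_message"] = history.pop(0).content
    return chevron.render(template, data)


if __name__ == "__main__":
    prompt = render_string_template(
        EDIT_TEMPLATE,
        history=[],  # no prior chat turns in this example
        other_data={
            "code_to_edit": "def add(a, b):\n    return a - b",
            "user_input": "fix the bug in add",
        },
    )
    print(prompt)  # rendered [INST] ... [/INST] prompt, ready to send to /api/generate
```

In the core.py hunk above, a rendered string like this becomes the single user ChatMessage passed to stream_chat, while a template function may instead return a full list of ChatMessages.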
