From c1e5039731941eb6b6eea166edd433cd49d4e858 Mon Sep 17 00:00:00 2001
From: Nate Sesti <sestinj@gmail.com>
Date: Fri, 1 Sep 2023 13:36:43 -0700
Subject: feat: :sparkles: select custom model to use with edit step

---
 continuedev/src/continuedev/libs/llm/llamacpp.py       |  4 ++--
 continuedev/src/continuedev/libs/llm/openai.py         |  5 +----
 continuedev/src/continuedev/plugins/steps/core/core.py | 11 +++++++++--
 continuedev/src/continuedev/plugins/steps/main.py      |  8 ++++++--
 4 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 6625065e..e6f38cd0 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -7,7 +7,7 @@ import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
 from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
-from .prompts.chat import code_llama_template_messages
+from .prompts.chat import llama2_template_messages
 
 
 class LlamaCpp(LLM):
@@ -15,7 +15,7 @@ class LlamaCpp(LLM):
     server_url: str = "http://localhost:8080"
     verify_ssl: Optional[bool] = None
 
-    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
     llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "}
 
     use_command: Optional[str] = None
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 464c6420..a61103b9 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -148,10 +148,7 @@ class OpenAI(LLM):
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
-        # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
-        args["model"] = (
-            self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
-        )
+
         if not args["model"].endswith("0613") and "functions" in args:
             del args["functions"]
 
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 17b325ab..1529fe1b 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -2,12 +2,13 @@
 import difflib
 import traceback
 from textwrap import dedent
-from typing import Any, Coroutine, List, Union
+from typing import Any, Coroutine, List, Optional, Union
 
 from pydantic import validator
 
 from ....core.main import ChatMessage, ContinueCustomException, Step
 from ....core.observation import Observation, TextObservation, UserInputObservation
+from ....libs.llm import LLM
 from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
 from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
 from ....libs.util.strings import (
@@ -161,6 +162,7 @@ class ShellCommandsStep(Step):
 
 class DefaultModelEditCodeStep(Step):
     user_input: str
+    model: Optional[LLM] = None
     range_in_files: List[RangeInFile]
     name: str = "Editing Code"
     hide = False
@@ -241,7 +243,10 @@ class DefaultModelEditCodeStep(Step):
         # We don't know here all of the functions being passed in.
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
-        model_to_use = sdk.models.edit
+        if self.model is not None:
+            await sdk.start_model(self.model)
+
+        model_to_use = self.model or sdk.models.edit
         max_tokens = int(model_to_use.context_length / 2)
 
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
@@ -836,6 +841,7 @@ class EditFileStep(Step):
     filepath: str
     prompt: str
     hide: bool = True
+    model: Optional[LLM] = None
 
     async def describe(self, models: Models) -> Coroutine[str, None, None]:
         return "Editing file: " + self.filepath
@@ -848,6 +854,7 @@ class EditFileStep(Step):
                     RangeInFile.from_entire_file(self.filepath, file_contents)
                 ],
                 user_input=self.prompt,
+                model=self.model,
             )
         )
 
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ab5775c6..2ceb82c5 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -1,12 +1,13 @@
 import os
 from textwrap import dedent
-from typing import Coroutine, List, Union
+from typing import Coroutine, List, Optional, Union
 
 from pydantic import BaseModel, Field
 
 from ...core.main import ContinueCustomException, Step
 from ...core.observation import Observation
 from ...core.sdk import ContinueSDK, Models
+from ...libs.llm import LLM
 from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
 from ...libs.util.calculate_diff import calculate_diff2
 from ...libs.util.logging import logger
@@ -240,6 +241,7 @@ class EditHighlightedCodeStep(Step):
         title="User Input",
         description="The natural language request describing how to edit the code",
     )
+    model: Optional[LLM] = None
     hide = True
     description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user asks seems to be asking for a code change."
 
@@ -293,7 +295,9 @@ class EditHighlightedCodeStep(Step):
 
         await sdk.run_step(
             DefaultModelEditCodeStep(
-                user_input=self.user_input, range_in_files=range_in_files
+                user_input=self.user_input,
+                range_in_files=range_in_files,
+                model=self.model,
             )
         )
 
-- 
cgit v1.2.3-70-g09d2