author    | Nate Sesti <sestinj@gmail.com> | 2023-09-01 13:36:43 -0700
committer | Nate Sesti <sestinj@gmail.com> | 2023-09-01 13:36:43 -0700
commit    | c1e5039731941eb6b6eea166edd433cd49d4e858 (patch)
tree      | fbc2b533b9622dbbb11ecec10769b0183e70c9bd
parent    | 25cc0c438f78b1f5029debc36cc10a36e9fb542a (diff)
download  | sncontinue-c1e5039731941eb6b6eea166edd433cd49d4e858.tar.gz
          | sncontinue-c1e5039731941eb6b6eea166edd433cd49d4e858.tar.bz2
          | sncontinue-c1e5039731941eb6b6eea166edd433cd49d4e858.zip
feat: :sparkles: select custom model to use with edit step
4 files changed, 18 insertions, 10 deletions
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 6625065e..e6f38cd0 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -7,7 +7,7 @@ import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
 from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
-from .prompts.chat import code_llama_template_messages
+from .prompts.chat import llama2_template_messages
 
 
 class LlamaCpp(LLM):
@@ -15,7 +15,7 @@ class LlamaCpp(LLM):
     server_url: str = "http://localhost:8080"
     verify_ssl: Optional[bool] = None
 
-    template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+    template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
 
     llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"], "grammar": "root ::= "}
 
     use_command: Optional[str] = None
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 464c6420..a61103b9 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -148,10 +148,7 @@ class OpenAI(LLM):
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
-        # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
-        args["model"] = (
-            self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
-        )
+
         if not args["model"].endswith("0613") and "functions" in args:
             del args["functions"]
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 17b325ab..1529fe1b 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -2,12 +2,13 @@ import difflib
 import traceback
 from textwrap import dedent
-from typing import Any, Coroutine, List, Union
+from typing import Any, Coroutine, List, Optional, Union
 
 from pydantic import validator
 
 from ....core.main import ChatMessage, ContinueCustomException, Step
 from ....core.observation import Observation, TextObservation, UserInputObservation
+from ....libs.llm import LLM
 from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
 from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
 from ....libs.util.strings import (
@@ -161,6 +162,7 @@ class ShellCommandsStep(Step):
 
 class DefaultModelEditCodeStep(Step):
     user_input: str
+    model: Optional[LLM] = None
     range_in_files: List[RangeInFile]
     name: str = "Editing Code"
     hide = False
@@ -241,7 +243,10 @@
         # We don't know here all of the functions being passed in.
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
-        model_to_use = sdk.models.edit
+        if self.model is not None:
+            await sdk.start_model(self.model)
+
+        model_to_use = self.model or sdk.models.edit
         max_tokens = int(model_to_use.context_length / 2)
 
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
@@ -836,6 +841,7 @@ class EditFileStep(Step):
     filepath: str
     prompt: str
     hide: bool = True
+    model: Optional[LLM] = None
 
     async def describe(self, models: Models) -> Coroutine[str, None, None]:
         return "Editing file: " + self.filepath
@@ -848,6 +854,7 @@
                 RangeInFile.from_entire_file(self.filepath, file_contents)
             ],
             user_input=self.prompt,
+            model=self.model,
         )
     )
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ab5775c6..2ceb82c5 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -1,12 +1,13 @@
 import os
 from textwrap import dedent
-from typing import Coroutine, List, Union
+from typing import Coroutine, List, Optional, Union
 
 from pydantic import BaseModel, Field
 
 from ...core.main import ContinueCustomException, Step
 from ...core.observation import Observation
 from ...core.sdk import ContinueSDK, Models
+from ...libs.llm import LLM
 from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
 from ...libs.util.calculate_diff import calculate_diff2
 from ...libs.util.logging import logger
@@ -240,6 +241,7 @@ class EditHighlightedCodeStep(Step):
         title="User Input",
         description="The natural language request describing how to edit the code",
     )
+    model: Optional[LLM] = None
     hide = True
     description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user asks seems to be asking for a code change."
@@ -293,7 +295,9 @@ class EditHighlightedCodeStep(Step):
 
         await sdk.run_step(
             DefaultModelEditCodeStep(
-                user_input=self.user_input, range_in_files=range_in_files
+                user_input=self.user_input,
+                range_in_files=range_in_files,
+                model=self.model,
            )
         )
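
In short, the edit-related steps (`DefaultModelEditCodeStep`, `EditFileStep`, `EditHighlightedCodeStep`) gain an optional `model: Optional[LLM]` field; when it is set, the step starts that model via `sdk.start_model(...)` and uses it in place of the default `sdk.models.edit`, otherwise behavior is unchanged. Below is a minimal sketch of how the new field might be used from a custom step. The class `EditWithLlamaCppStep` and the `continuedev.src.continuedev...` import paths are assumptions based on the file layout shown in this diff, not something defined by the commit itself.

```python
# Hypothetical usage sketch (not part of this commit): run the highlighted-code
# edit step with an explicitly chosen LLM instead of sdk.models.edit.
from continuedev.src.continuedev.core.main import Step
from continuedev.src.continuedev.core.sdk import ContinueSDK
from continuedev.src.continuedev.libs.llm.llamacpp import LlamaCpp
from continuedev.src.continuedev.plugins.steps.main import EditHighlightedCodeStep


class EditWithLlamaCppStep(Step):
    user_input: str

    async def run(self, sdk: ContinueSDK):
        # Passing `model` makes the edit step start and use this LLM; leaving it
        # as None falls back to sdk.models.edit, exactly as before this commit.
        await sdk.run_step(
            EditHighlightedCodeStep(
                user_input=self.user_input,
                model=LlamaCpp(server_url="http://localhost:8080"),
            )
        )
```

The fallback is handled by `model_to_use = self.model or sdk.models.edit` inside `DefaultModelEditCodeStep`, so existing configurations that never set `model` keep their current edit model.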