summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/plugins
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-09-01 13:36:43 -0700
committerNate Sesti <sestinj@gmail.com>2023-09-01 13:36:43 -0700
commitc1e5039731941eb6b6eea166edd433cd49d4e858 (patch)
treefbc2b533b9622dbbb11ecec10769b0183e70c9bd /continuedev/src/continuedev/plugins
parent25cc0c438f78b1f5029debc36cc10a36e9fb542a (diff)
downloadsncontinue-c1e5039731941eb6b6eea166edd433cd49d4e858.tar.gz
sncontinue-c1e5039731941eb6b6eea166edd433cd49d4e858.tar.bz2
sncontinue-c1e5039731941eb6b6eea166edd433cd49d4e858.zip
feat: :sparkles: select custom model to use with edit step
Diffstat (limited to 'continuedev/src/continuedev/plugins')
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py11
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py8
2 files changed, 15 insertions, 4 deletions
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 17b325ab..1529fe1b 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -2,12 +2,13 @@
import difflib
import traceback
from textwrap import dedent
-from typing import Any, Coroutine, List, Union
+from typing import Any, Coroutine, List, Optional, Union
from pydantic import validator
from ....core.main import ChatMessage, ContinueCustomException, Step
from ....core.observation import Observation, TextObservation, UserInputObservation
+from ....libs.llm import LLM
from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import (
@@ -161,6 +162,7 @@ class ShellCommandsStep(Step):
class DefaultModelEditCodeStep(Step):
user_input: str
+ model: Optional[LLM] = None
range_in_files: List[RangeInFile]
name: str = "Editing Code"
hide = False
@@ -241,7 +243,10 @@ class DefaultModelEditCodeStep(Step):
# We don't know here all of the functions being passed in.
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
- model_to_use = sdk.models.edit
+ if self.model is not None:
+ await sdk.start_model(self.model)
+
+ model_to_use = self.model or sdk.models.edit
max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
@@ -836,6 +841,7 @@ class EditFileStep(Step):
filepath: str
prompt: str
hide: bool = True
+ model: Optional[LLM] = None
async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing file: " + self.filepath
@@ -848,6 +854,7 @@ class EditFileStep(Step):
RangeInFile.from_entire_file(self.filepath, file_contents)
],
user_input=self.prompt,
+ model=self.model,
)
)
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ab5775c6..2ceb82c5 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -1,12 +1,13 @@
import os
from textwrap import dedent
-from typing import Coroutine, List, Union
+from typing import Coroutine, List, Optional, Union
from pydantic import BaseModel, Field
from ...core.main import ContinueCustomException, Step
from ...core.observation import Observation
from ...core.sdk import ContinueSDK, Models
+from ...libs.llm import LLM
from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
from ...libs.util.calculate_diff import calculate_diff2
from ...libs.util.logging import logger
@@ -240,6 +241,7 @@ class EditHighlightedCodeStep(Step):
title="User Input",
description="The natural language request describing how to edit the code",
)
+ model: Optional[LLM] = None
hide = True
description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user asks seems to be asking for a code change."
@@ -293,7 +295,9 @@ class EditHighlightedCodeStep(Step):
await sdk.run_step(
DefaultModelEditCodeStep(
- user_input=self.user_input, range_in_files=range_in_files
+ user_input=self.user_input,
+ range_in_files=range_in_files,
+ model=self.model,
)
)