summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/libs/steps/main.py
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-05-29 18:31:25 -0400
committerNate Sesti <sestinj@gmail.com>2023-05-29 18:31:25 -0400
commit8d59100b3194cc8d122708523226968899efb5e1 (patch)
tree88fe742114c87d6df0424f46dfc86077d716a074 /continuedev/src/continuedev/libs/steps/main.py
parent8c00cddb9345daaf2052d3b2650fa136f39813be (diff)
downloadsncontinue-8d59100b3194cc8d122708523226968899efb5e1.tar.gz
sncontinue-8d59100b3194cc8d122708523226968899efb5e1.tar.bz2
sncontinue-8d59100b3194cc8d122708523226968899efb5e1.zip
(much!) faster inference with starcoder
Diffstat (limited to 'continuedev/src/continuedev/libs/steps/main.py')
-rw-r--r--continuedev/src/continuedev/libs/steps/main.py57
1 files changed, 54 insertions, 3 deletions
diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py
index 4f4f80e3..c8a85800 100644
--- a/continuedev/src/continuedev/libs/steps/main.py
+++ b/continuedev/src/continuedev/libs/steps/main.py
@@ -1,4 +1,6 @@
-from typing import Callable, Coroutine, List, Union
+from typing import Coroutine, List, Union
+
+from pydantic import BaseModel
from ..util.traceback_parsers import parse_python_traceback
from ..llm import LLM
@@ -8,9 +10,10 @@ from ...models.filesystem import RangeInFile, RangeInFileWithContents
from ...core.observation import Observation, TextObservation, TracebackObservation
from ..llm.prompt_utils import MarkdownStyleEncoderDecoder
from textwrap import dedent
-from ...core.main import History, Policy, Step, ContinueSDK, Observation
+from ...core.main import Step
+from ...core.sdk import ContinueSDK
+from ...core.observation import Observation
import subprocess
-import json
from .core.core import EditCodeStep
@@ -36,6 +39,10 @@ class RunCodeStep(Step):
return None
+class Policy(BaseModel):
+ pass
+
+
class RunPolicyUntilDoneStep(Step):
policy: "Policy"
@@ -206,6 +213,50 @@ class FasterEditHighlightedCodeStep(Step):
return None
+class StarCoderEditHighlightedCodeStep(Step):
+ user_input: str
+ hide = True
+ _prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>"
+
+ async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ return "Editing highlighted code"
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ range_in_files = await sdk.ide.getHighlightedCode()
+ if len(range_in_files) == 0:
+ # Get the full contents of all open files
+ files = await sdk.ide.getOpenFiles()
+ contents = {}
+ for file in files:
+ contents[file] = await sdk.ide.readFile(file)
+
+ range_in_files = [RangeInFile.from_entire_file(
+ filepath, content) for filepath, content in contents.items()]
+
+ rif_with_contents = []
+ for range_in_file in range_in_files:
+ file_contents = await sdk.ide.readRangeInFile(range_in_file)
+ rif_with_contents.append(
+ RangeInFileWithContents.from_range_in_file(range_in_file, file_contents))
+
+ rif_dict = {}
+ for rif in rif_with_contents:
+ rif_dict[rif.filepath] = rif.contents
+
+ for rif in rif_with_contents:
+ prompt = self._prompt.format(
+ code=rif.contents, user_request=self.user_input)
+ completion = str((await sdk.models.starcoder()).complete(prompt))
+ eot_token = "<|endoftext|>"
+ if completion.endswith(eot_token):
+ completion = completion[:completion.rindex(eot_token)]
+
+ await sdk.ide.applyFileSystemEdit(
+ FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion))
+ await sdk.ide.saveFile(rif.filepath)
+ await sdk.ide.setFileOpen(rif.filepath)
+
+
class EditHighlightedCodeStep(Step):
user_input: str
hide = True