diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-05-29 18:31:25 -0400 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-05-29 18:31:25 -0400 |
commit | 22245d2cbf90daa9033d8551207aa986069d8b24 (patch) | |
tree | 6a7cdaa88b365a5838c6d2c178fdbba48a667f33 /continuedev/src/continuedev/libs/steps/main.py | |
parent | 9a221fda9b44d0dc7ab2637c9a25d1be226b2b32 (diff) | |
download | sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.tar.gz sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.tar.bz2 sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.zip |
(much!) faster inference with starcoder
Diffstat (limited to 'continuedev/src/continuedev/libs/steps/main.py')
-rw-r--r-- | continuedev/src/continuedev/libs/steps/main.py | 57 |
1 files changed, 54 insertions, 3 deletions
diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py index 4f4f80e3..c8a85800 100644 --- a/continuedev/src/continuedev/libs/steps/main.py +++ b/continuedev/src/continuedev/libs/steps/main.py @@ -1,4 +1,6 @@ -from typing import Callable, Coroutine, List, Union +from typing import Coroutine, List, Union + +from pydantic import BaseModel from ..util.traceback_parsers import parse_python_traceback from ..llm import LLM @@ -8,9 +10,10 @@ from ...models.filesystem import RangeInFile, RangeInFileWithContents from ...core.observation import Observation, TextObservation, TracebackObservation from ..llm.prompt_utils import MarkdownStyleEncoderDecoder from textwrap import dedent -from ...core.main import History, Policy, Step, ContinueSDK, Observation +from ...core.main import Step +from ...core.sdk import ContinueSDK +from ...core.observation import Observation import subprocess -import json from .core.core import EditCodeStep @@ -36,6 +39,10 @@ class RunCodeStep(Step): return None +class Policy(BaseModel): + pass + + class RunPolicyUntilDoneStep(Step): policy: "Policy" @@ -206,6 +213,50 @@ class FasterEditHighlightedCodeStep(Step): return None +class StarCoderEditHighlightedCodeStep(Step): + user_input: str + hide = True + _prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>" + + async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + return "Editing highlighted code" + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + range_in_files = await sdk.ide.getHighlightedCode() + if len(range_in_files) == 0: + # Get the full contents of all open files + files = await sdk.ide.getOpenFiles() + contents = {} + for file in files: + contents[file] = await sdk.ide.readFile(file) + + range_in_files = [RangeInFile.from_entire_file( + filepath, content) for filepath, content in contents.items()] + + rif_with_contents = [] + for range_in_file in range_in_files: + file_contents = await sdk.ide.readRangeInFile(range_in_file) + rif_with_contents.append( + RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) + + rif_dict = {} + for rif in rif_with_contents: + rif_dict[rif.filepath] = rif.contents + + for rif in rif_with_contents: + prompt = self._prompt.format( + code=rif.contents, user_request=self.user_input) + completion = str((await sdk.models.starcoder()).complete(prompt)) + eot_token = "<|endoftext|>" + if completion.endswith(eot_token): + completion = completion[:completion.rindex(eot_token)] + + await sdk.ide.applyFileSystemEdit( + FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion)) + await sdk.ide.saveFile(rif.filepath) + await sdk.ide.setFileOpen(rif.filepath) + + class EditHighlightedCodeStep(Step): user_input: str hide = True |