diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-05-29 18:31:25 -0400 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-05-29 18:31:25 -0400 |
commit | 22245d2cbf90daa9033d8551207aa986069d8b24 (patch) | |
tree | 6a7cdaa88b365a5838c6d2c178fdbba48a667f33 | |
parent | 9a221fda9b44d0dc7ab2637c9a25d1be226b2b32 (diff) | |
download | sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.tar.gz sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.tar.bz2 sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.zip |
(much!) faster inference with starcoder
-rw-r--r-- | continuedev/src/continuedev/core/agent.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/env.py | 19 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/policy.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 25 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 25 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/steps/main.py | 57 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/notebook.py | 4 | ||||
-rw-r--r-- | extension/react-app/src/components/StepContainer.tsx | 21 |
8 files changed, 130 insertions, 28 deletions
diff --git a/continuedev/src/continuedev/core/agent.py b/continuedev/src/continuedev/core/agent.py index 6d1f542e..329e3d4c 100644 --- a/continuedev/src/continuedev/core/agent.py +++ b/continuedev/src/continuedev/core/agent.py @@ -49,9 +49,10 @@ class Agent(ContinueBaseModel): async def wait_for_user_input(self) -> str: self._active = False self.update_subscribers() - await self._user_input_queue.get(self.history.current_index) + user_input = await self._user_input_queue.get(self.history.current_index) self._active = True self.update_subscribers() + return user_input _manual_edits_buffer: List[FileEditWithFullContents] = [] diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py index d7275b41..6267ed60 100644 --- a/continuedev/src/continuedev/core/env.py +++ b/continuedev/src/continuedev/core/env.py @@ -1,7 +1,22 @@ from dotenv import load_dotenv import os -load_dotenv() +def get_env_var(var_name: str): + load_dotenv() + return os.getenv(var_name) -openai_api_key = os.getenv("OPENAI_API_KEY") + +def save_env_var(var_name: str, var_value: str): + with open('.env', 'r') as f: + lines = f.readlines() + with open('.env', 'w') as f: + values = {} + for line in lines: + key, value = line.split('=') + value = value.replace('"', '') + values[key] = value + + values[var_name] = var_value + for key, value in values.items(): + f.write(f'{key}="{value}"\n') diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 504d5ff1..07101576 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -6,7 +6,7 @@ from ..models.main import ContinueBaseModel from ..libs.steps.ty import CreatePipelineStep from .main import Step, Validator, History, Policy from .observation import Observation, TracebackObservation, UserInputObservation -from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep +from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep from ..libs.steps.nate import WritePytestsStep, CreateTableStep from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma from ..libs.steps.continue_step import ContinueStepStep @@ -32,7 +32,7 @@ class DemoPolicy(Policy): return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) elif "/step" in observation.user_input: return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:])) - return FasterEditHighlightedCodeStep(user_input=observation.user_input) + return StarCoderEditHighlightedCodeStep(user_input=observation.user_input) state = history.get_current() if state is None or not self.ran_code_last: diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 3559e9d7..750b335d 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -3,10 +3,12 @@ from typing import Coroutine, Union from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory from ..models.filesystem import RangeInFile from ..libs.llm import LLM +from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer from .main import History, Step from ..libs.steps.core.core import * +from .env import get_env_var, save_env_var class Agent: @@ -18,11 +20,22 @@ class ContinueSDKSteps: self.sdk = sdk +class Models: + def __init__(self, sdk: "ContinueSDK"): + self.sdk = sdk + + async def starcoder(self): + api_key = await self.sdk.get_user_secret( + 'HUGGING_FACE_TOKEN', 'Please enter your Hugging Face token') + return HuggingFaceInferenceAPI(api_key=api_key) + + class ContinueSDK: """The SDK provided as parameters to a step""" llm: LLM ide: AbstractIdeProtocolServer steps: ContinueSDKSteps + models: Models __agent: Agent def __init__(self, agent: Agent, llm: Union[LLM, None] = None): @@ -33,6 +46,7 @@ class ContinueSDK: self.ide = agent.ide self.__agent = agent self.steps = ContinueSDKSteps(self) + self.models = Models(self) @property def history(self) -> History: @@ -80,3 +94,14 @@ class ContinueSDK: async def delete_directory(self, path: str): return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) + + async def get_user_secret(self, env_var: str, prompt: str) -> str: + try: + val = get_env_var(env_var) + if val is not None: + return val + except: + pass + val = (await self.run_step(WaitForUserInputStep(prompt=prompt))).text + save_env_var(env_var, val) + return val diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py new file mode 100644 index 00000000..83852d27 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -0,0 +1,25 @@ +from ..llm import LLM +import requests + +DEFAULT_MAX_TOKENS = 2048 +DEFAULT_MAX_TIME = 120. + + +class HuggingFaceInferenceAPI(LLM): + api_key: str + model: str = "bigcode/starcoder" + + def complete(self, prompt: str, **kwargs): + """Return the completion of the text with the given temperature.""" + API_URL = f"https://api-inference.huggingface.co/models/{self.model}" + headers = { + "Authorization": f"Bearer {self.api_key}"} + + response = requests.post(API_URL, headers=headers, json={ + "inputs": prompt, "parameters": { + "max_new_tokens": DEFAULT_MAX_TOKENS, + "max_time": DEFAULT_MAX_TIME, + "return_full_text": False, + } + }) + return response.json()[0]["generated_text"] diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py index 4f4f80e3..c8a85800 100644 --- a/continuedev/src/continuedev/libs/steps/main.py +++ b/continuedev/src/continuedev/libs/steps/main.py @@ -1,4 +1,6 @@ -from typing import Callable, Coroutine, List, Union +from typing import Coroutine, List, Union + +from pydantic import BaseModel from ..util.traceback_parsers import parse_python_traceback from ..llm import LLM @@ -8,9 +10,10 @@ from ...models.filesystem import RangeInFile, RangeInFileWithContents from ...core.observation import Observation, TextObservation, TracebackObservation from ..llm.prompt_utils import MarkdownStyleEncoderDecoder from textwrap import dedent -from ...core.main import History, Policy, Step, ContinueSDK, Observation +from ...core.main import Step +from ...core.sdk import ContinueSDK +from ...core.observation import Observation import subprocess -import json from .core.core import EditCodeStep @@ -36,6 +39,10 @@ class RunCodeStep(Step): return None +class Policy(BaseModel): + pass + + class RunPolicyUntilDoneStep(Step): policy: "Policy" @@ -206,6 +213,50 @@ class FasterEditHighlightedCodeStep(Step): return None +class StarCoderEditHighlightedCodeStep(Step): + user_input: str + hide = True + _prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>" + + async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + return "Editing highlighted code" + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + range_in_files = await sdk.ide.getHighlightedCode() + if len(range_in_files) == 0: + # Get the full contents of all open files + files = await sdk.ide.getOpenFiles() + contents = {} + for file in files: + contents[file] = await sdk.ide.readFile(file) + + range_in_files = [RangeInFile.from_entire_file( + filepath, content) for filepath, content in contents.items()] + + rif_with_contents = [] + for range_in_file in range_in_files: + file_contents = await sdk.ide.readRangeInFile(range_in_file) + rif_with_contents.append( + RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) + + rif_dict = {} + for rif in rif_with_contents: + rif_dict[rif.filepath] = rif.contents + + for rif in rif_with_contents: + prompt = self._prompt.format( + code=rif.contents, user_request=self.user_input) + completion = str((await sdk.models.starcoder()).complete(prompt)) + eot_token = "<|endoftext|>" + if completion.endswith(eot_token): + completion = completion[:completion.rindex(eot_token)] + + await sdk.ide.applyFileSystemEdit( + FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion)) + await sdk.ide.saveFile(rif.filepath) + await sdk.ide.setFileOpen(rif.filepath) + + class EditHighlightedCodeStep(Step): user_input: str hide = True diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py index bfd7a09c..c5dcea31 100644 --- a/continuedev/src/continuedev/server/notebook.py +++ b/continuedev/src/continuedev/server/notebook.py @@ -12,7 +12,7 @@ from ..libs.steps.nate import ImplementAbstractMethodStep from ..core.observation import Observation from ..libs.llm.openai import OpenAI from .ide_protocol import AbstractIdeProtocolServer -from ..core.env import openai_api_key +from ..core.env import get_env_var import asyncio import nest_asyncio nest_asyncio.apply() @@ -75,7 +75,7 @@ class SessionManager: def new_session(self, ide: AbstractIdeProtocolServer) -> str: cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py" - agent = DemoAgent(llm=OpenAI(api_key=openai_api_key), + agent = DemoAgent(llm=OpenAI(api_key=get_env_var("OPENAI_API_KEY")), policy=DemoPolicy(cmd=cmd), ide=ide) session_id = str(uuid4()) session = Session(session_id=session_id, agent=agent) diff --git a/extension/react-app/src/components/StepContainer.tsx b/extension/react-app/src/components/StepContainer.tsx index 03649b66..36b3d99a 100644 --- a/extension/react-app/src/components/StepContainer.tsx +++ b/extension/react-app/src/components/StepContainer.tsx @@ -144,6 +144,9 @@ function StepContainer(props: StepContainerProps) { onSubmit={(ev) => { props.onUserInput(ev.currentTarget.value); }} + onClick={(e) => { + e.stopPropagation(); + }} /> )} {props.historyNode.step.name === "Waiting for user confirmation" && ( @@ -165,24 +168,6 @@ function StepContainer(props: StepContainerProps) { /> </> )} - - {open && ( - <> - {/* {props.historyNode.observation && ( - <SubContainer title="Error"> - <CodeBlock>Error Here</CodeBlock> - </SubContainer> - )} */} - {/* {props.iterationContext.suggestedChanges.map((sc) => { - return ( - <SubContainer title="Suggested Change"> - {sc.filepath} - <CodeBlock>{sc.replacement}</CodeBlock> - </SubContainer> - ); - })} */} - </> - )} </StepContainerDiv> </GradientBorder> |