summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-05-29 18:31:25 -0400
committerNate Sesti <sestinj@gmail.com>2023-05-29 18:31:25 -0400
commit22245d2cbf90daa9033d8551207aa986069d8b24 (patch)
tree6a7cdaa88b365a5838c6d2c178fdbba48a667f33
parent9a221fda9b44d0dc7ab2637c9a25d1be226b2b32 (diff)
downloadsncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.tar.gz
sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.tar.bz2
sncontinue-22245d2cbf90daa9033d8551207aa986069d8b24.zip
(much!) faster inference with starcoder
-rw-r--r--continuedev/src/continuedev/core/agent.py3
-rw-r--r--continuedev/src/continuedev/core/env.py19
-rw-r--r--continuedev/src/continuedev/core/policy.py4
-rw-r--r--continuedev/src/continuedev/core/sdk.py25
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py25
-rw-r--r--continuedev/src/continuedev/libs/steps/main.py57
-rw-r--r--continuedev/src/continuedev/server/notebook.py4
-rw-r--r--extension/react-app/src/components/StepContainer.tsx21
8 files changed, 130 insertions, 28 deletions
diff --git a/continuedev/src/continuedev/core/agent.py b/continuedev/src/continuedev/core/agent.py
index 6d1f542e..329e3d4c 100644
--- a/continuedev/src/continuedev/core/agent.py
+++ b/continuedev/src/continuedev/core/agent.py
@@ -49,9 +49,10 @@ class Agent(ContinueBaseModel):
async def wait_for_user_input(self) -> str:
self._active = False
self.update_subscribers()
- await self._user_input_queue.get(self.history.current_index)
+ user_input = await self._user_input_queue.get(self.history.current_index)
self._active = True
self.update_subscribers()
+ return user_input
_manual_edits_buffer: List[FileEditWithFullContents] = []
diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py
index d7275b41..6267ed60 100644
--- a/continuedev/src/continuedev/core/env.py
+++ b/continuedev/src/continuedev/core/env.py
@@ -1,7 +1,22 @@
from dotenv import load_dotenv
import os
-load_dotenv()
+def get_env_var(var_name: str):
+ load_dotenv()
+ return os.getenv(var_name)
-openai_api_key = os.getenv("OPENAI_API_KEY")
+
+def save_env_var(var_name: str, var_value: str):
+ with open('.env', 'r') as f:
+ lines = f.readlines()
+ with open('.env', 'w') as f:
+ values = {}
+ for line in lines:
+ key, value = line.split('=')
+ value = value.replace('"', '')
+ values[key] = value
+
+ values[var_name] = var_value
+ for key, value in values.items():
+ f.write(f'{key}="{value}"\n')
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index 504d5ff1..07101576 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -6,7 +6,7 @@ from ..models.main import ContinueBaseModel
from ..libs.steps.ty import CreatePipelineStep
from .main import Step, Validator, History, Policy
from .observation import Observation, TracebackObservation, UserInputObservation
-from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep
+from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep
from ..libs.steps.nate import WritePytestsStep, CreateTableStep
from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma
from ..libs.steps.continue_step import ContinueStepStep
@@ -32,7 +32,7 @@ class DemoPolicy(Policy):
return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
elif "/step" in observation.user_input:
return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:]))
- return FasterEditHighlightedCodeStep(user_input=observation.user_input)
+ return StarCoderEditHighlightedCodeStep(user_input=observation.user_input)
state = history.get_current()
if state is None or not self.ran_code_last:
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 3559e9d7..750b335d 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -3,10 +3,12 @@ from typing import Coroutine, Union
from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
from ..models.filesystem import RangeInFile
from ..libs.llm import LLM
+from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import History, Step
from ..libs.steps.core.core import *
+from .env import get_env_var, save_env_var
class Agent:
@@ -18,11 +20,22 @@ class ContinueSDKSteps:
self.sdk = sdk
+class Models:
+ def __init__(self, sdk: "ContinueSDK"):
+ self.sdk = sdk
+
+ async def starcoder(self):
+ api_key = await self.sdk.get_user_secret(
+ 'HUGGING_FACE_TOKEN', 'Please enter your Hugging Face token')
+ return HuggingFaceInferenceAPI(api_key=api_key)
+
+
class ContinueSDK:
"""The SDK provided as parameters to a step"""
llm: LLM
ide: AbstractIdeProtocolServer
steps: ContinueSDKSteps
+ models: Models
__agent: Agent
def __init__(self, agent: Agent, llm: Union[LLM, None] = None):
@@ -33,6 +46,7 @@ class ContinueSDK:
self.ide = agent.ide
self.__agent = agent
self.steps = ContinueSDKSteps(self)
+ self.models = Models(self)
@property
def history(self) -> History:
@@ -80,3 +94,14 @@ class ContinueSDK:
async def delete_directory(self, path: str):
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
+
+ async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ try:
+ val = get_env_var(env_var)
+ if val is not None:
+ return val
+ except:
+ pass
+ val = (await self.run_step(WaitForUserInputStep(prompt=prompt))).text
+ save_env_var(env_var, val)
+ return val
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
new file mode 100644
index 00000000..83852d27
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -0,0 +1,25 @@
+from ..llm import LLM
+import requests
+
+DEFAULT_MAX_TOKENS = 2048
+DEFAULT_MAX_TIME = 120.
+
+
+class HuggingFaceInferenceAPI(LLM):
+ api_key: str
+ model: str = "bigcode/starcoder"
+
+ def complete(self, prompt: str, **kwargs):
+ """Return the completion of the text with the given temperature."""
+ API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
+ headers = {
+ "Authorization": f"Bearer {self.api_key}"}
+
+ response = requests.post(API_URL, headers=headers, json={
+ "inputs": prompt, "parameters": {
+ "max_new_tokens": DEFAULT_MAX_TOKENS,
+ "max_time": DEFAULT_MAX_TIME,
+ "return_full_text": False,
+ }
+ })
+ return response.json()[0]["generated_text"]
diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py
index 4f4f80e3..c8a85800 100644
--- a/continuedev/src/continuedev/libs/steps/main.py
+++ b/continuedev/src/continuedev/libs/steps/main.py
@@ -1,4 +1,6 @@
-from typing import Callable, Coroutine, List, Union
+from typing import Coroutine, List, Union
+
+from pydantic import BaseModel
from ..util.traceback_parsers import parse_python_traceback
from ..llm import LLM
@@ -8,9 +10,10 @@ from ...models.filesystem import RangeInFile, RangeInFileWithContents
from ...core.observation import Observation, TextObservation, TracebackObservation
from ..llm.prompt_utils import MarkdownStyleEncoderDecoder
from textwrap import dedent
-from ...core.main import History, Policy, Step, ContinueSDK, Observation
+from ...core.main import Step
+from ...core.sdk import ContinueSDK
+from ...core.observation import Observation
import subprocess
-import json
from .core.core import EditCodeStep
@@ -36,6 +39,10 @@ class RunCodeStep(Step):
return None
+class Policy(BaseModel):
+ pass
+
+
class RunPolicyUntilDoneStep(Step):
policy: "Policy"
@@ -206,6 +213,50 @@ class FasterEditHighlightedCodeStep(Step):
return None
+class StarCoderEditHighlightedCodeStep(Step):
+ user_input: str
+ hide = True
+ _prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>"
+
+ async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ return "Editing highlighted code"
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ range_in_files = await sdk.ide.getHighlightedCode()
+ if len(range_in_files) == 0:
+ # Get the full contents of all open files
+ files = await sdk.ide.getOpenFiles()
+ contents = {}
+ for file in files:
+ contents[file] = await sdk.ide.readFile(file)
+
+ range_in_files = [RangeInFile.from_entire_file(
+ filepath, content) for filepath, content in contents.items()]
+
+ rif_with_contents = []
+ for range_in_file in range_in_files:
+ file_contents = await sdk.ide.readRangeInFile(range_in_file)
+ rif_with_contents.append(
+ RangeInFileWithContents.from_range_in_file(range_in_file, file_contents))
+
+ rif_dict = {}
+ for rif in rif_with_contents:
+ rif_dict[rif.filepath] = rif.contents
+
+ for rif in rif_with_contents:
+ prompt = self._prompt.format(
+ code=rif.contents, user_request=self.user_input)
+ completion = str((await sdk.models.starcoder()).complete(prompt))
+ eot_token = "<|endoftext|>"
+ if completion.endswith(eot_token):
+ completion = completion[:completion.rindex(eot_token)]
+
+ await sdk.ide.applyFileSystemEdit(
+ FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion))
+ await sdk.ide.saveFile(rif.filepath)
+ await sdk.ide.setFileOpen(rif.filepath)
+
+
class EditHighlightedCodeStep(Step):
user_input: str
hide = True
diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py
index bfd7a09c..c5dcea31 100644
--- a/continuedev/src/continuedev/server/notebook.py
+++ b/continuedev/src/continuedev/server/notebook.py
@@ -12,7 +12,7 @@ from ..libs.steps.nate import ImplementAbstractMethodStep
from ..core.observation import Observation
from ..libs.llm.openai import OpenAI
from .ide_protocol import AbstractIdeProtocolServer
-from ..core.env import openai_api_key
+from ..core.env import get_env_var
import asyncio
import nest_asyncio
nest_asyncio.apply()
@@ -75,7 +75,7 @@ class SessionManager:
def new_session(self, ide: AbstractIdeProtocolServer) -> str:
cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py"
- agent = DemoAgent(llm=OpenAI(api_key=openai_api_key),
+ agent = DemoAgent(llm=OpenAI(api_key=get_env_var("OPENAI_API_KEY")),
policy=DemoPolicy(cmd=cmd), ide=ide)
session_id = str(uuid4())
session = Session(session_id=session_id, agent=agent)
diff --git a/extension/react-app/src/components/StepContainer.tsx b/extension/react-app/src/components/StepContainer.tsx
index 03649b66..36b3d99a 100644
--- a/extension/react-app/src/components/StepContainer.tsx
+++ b/extension/react-app/src/components/StepContainer.tsx
@@ -144,6 +144,9 @@ function StepContainer(props: StepContainerProps) {
onSubmit={(ev) => {
props.onUserInput(ev.currentTarget.value);
}}
+ onClick={(e) => {
+ e.stopPropagation();
+ }}
/>
)}
{props.historyNode.step.name === "Waiting for user confirmation" && (
@@ -165,24 +168,6 @@ function StepContainer(props: StepContainerProps) {
/>
</>
)}
-
- {open && (
- <>
- {/* {props.historyNode.observation && (
- <SubContainer title="Error">
- <CodeBlock>Error Here</CodeBlock>
- </SubContainer>
- )} */}
- {/* {props.iterationContext.suggestedChanges.map((sc) => {
- return (
- <SubContainer title="Suggested Change">
- {sc.filepath}
- <CodeBlock>{sc.replacement}</CodeBlock>
- </SubContainer>
- );
- })} */}
- </>
- )}
</StepContainerDiv>
</GradientBorder>