Diffstat (limited to 'continuedev/src')
23 files changed, 704 insertions, 299 deletions
diff --git a/continuedev/src/continuedev/core/agent.py b/continuedev/src/continuedev/core/agent.py index 6d1f542e..cf5c9781 100644 --- a/continuedev/src/continuedev/core/agent.py +++ b/continuedev/src/continuedev/core/agent.py @@ -10,10 +10,10 @@ from ..models.main import ContinueBaseModel from .main import Policy, History, FullState, Step, HistoryNode from ..libs.steps.core.core import ReversibleStep, ManualEditStep, UserInputStep from .sdk import ContinueSDK +import asyncio class Agent(ContinueBaseModel): - llm: LLM policy: Policy ide: AbstractIdeProtocolServer history: History = History.from_empty() @@ -31,27 +31,25 @@ class Agent(ContinueBaseModel): def get_full_state(self) -> FullState: return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue) - def on_update(self, callback: Callable[["FullState"], None]): + def on_update(self, callback: Coroutine["FullState", None, None]): """Subscribe to changes to state""" self._on_update_callbacks.append(callback) - def update_subscribers(self): + async def update_subscribers(self): full_state = self.get_full_state() for callback in self._on_update_callbacks: - callback(full_state) - - def __get_step_params(self, step: "Step"): - return ContinueSDK(agent=self, llm=self.llm.with_system_message(step.system_message)) + await callback(full_state) def give_user_input(self, input: str, index: int): - self._user_input_queue.post(index, input) + self._user_input_queue.post(str(index), input) async def wait_for_user_input(self) -> str: self._active = False - self.update_subscribers() - await self._user_input_queue.get(self.history.current_index) + await self.update_subscribers() + user_input = await self._user_input_queue.get(str(self.history.current_index)) self._active = True - self.update_subscribers() + await self.update_subscribers() + return user_input _manual_edits_buffer: List[FileEditWithFullContents] = [] @@ -61,9 +59,9 @@ class Agent(ContinueBaseModel): current_step = self.history.get_current().step self.history.step_back() if issubclass(current_step.__class__, ReversibleStep): - await current_step.reverse(self.__get_step_params(current_step)) + await current_step.reverse(ContinueSDK(self)) - self.update_subscribers() + await self.update_subscribers() except Exception as e: print(e) @@ -91,19 +89,24 @@ class Agent(ContinueBaseModel): self.history.add_node(HistoryNode( step=step, observation=None, depth=self._step_depth)) + # Call all subscribed callbacks + await self.update_subscribers() + # Run step self._step_depth += 1 - observation = await step(self.__get_step_params(step)) + observation = await step(ContinueSDK(self)) self._step_depth -= 1 # Add observation to history - self.history.get_current().observation = observation + self.history.get_last_at_depth( + self._step_depth, include_current=True).observation = observation # Update its description - step._set_description(await step.describe(self.llm)) - - # Call all subscribed callbacks - self.update_subscribers() + async def update_description(): + step._set_description(await step.describe(ContinueSDK(self).models)) + # Update subscribers with new description + await self.update_subscribers() + asyncio.create_task(update_description()) return observation @@ -137,7 +140,7 @@ class Agent(ContinueBaseModel): # Doing this so active can make it to the frontend after steps are done. 
But want better state syncing tools for callback in self._on_update_callbacks: - callback(None) + await callback(None) async def run_from_observation(self, observation: Observation): next_step = self.policy.next(self.history) @@ -157,7 +160,7 @@ class Agent(ContinueBaseModel): async def accept_user_input(self, user_input: str): self._main_user_input_queue.append(user_input) - self.update_subscribers() + await self.update_subscribers() if len(self._main_user_input_queue) > 1: return @@ -166,7 +169,7 @@ class Agent(ContinueBaseModel): # Just run the step that takes user input, and # then up to the policy to decide how to deal with it. self._main_user_input_queue.pop(0) - self.update_subscribers() + await self.update_subscribers() await self.run_from_step(UserInputStep(user_input=user_input)) while len(self._main_user_input_queue) > 0: diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py new file mode 100644 index 00000000..e62f0e4f --- /dev/null +++ b/continuedev/src/continuedev/core/config.py @@ -0,0 +1,29 @@ +import json +import os +from pydantic import BaseModel +from typing import List, Optional, Dict +import yaml + + +class ContinueConfig(BaseModel): + """ + A pydantic class for the continue config file. + """ + steps_on_startup: Optional[Dict[str, Dict]] = {} + server_url: Optional[str] = None + + +def load_config(config_file: str) -> ContinueConfig: + """ + Load the config file and return a ContinueConfig object. + """ + _, ext = os.path.splitext(config_file) + if ext == '.yaml': + with open(config_file, 'r') as f: + config_dict = yaml.safe_load(f) + elif ext == '.json': + with open(config_file, 'r') as f: + config_dict = json.load(f) + else: + raise ValueError(f'Unknown config file extension: {ext}') + return ContinueConfig(**config_dict) diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py index d7275b41..2692c348 100644 --- a/continuedev/src/continuedev/core/env.py +++ b/continuedev/src/continuedev/core/env.py @@ -1,7 +1,30 @@ from dotenv import load_dotenv import os -load_dotenv() +def get_env_var(var_name: str): + load_dotenv() + return os.getenv(var_name) -openai_api_key = os.getenv("OPENAI_API_KEY") + +def make_sure_env_exists(): + if not os.path.exists('.env'): + with open('.env', 'w') as f: + f.write('') + + +def save_env_var(var_name: str, var_value: str): + make_sure_env_exists() + + with open('.env', 'r') as f: + lines = f.readlines() + with open('.env', 'w') as f: + values = {} + for line in lines: + key, value = line.split('=') + value = value.replace('"', '') + values[key] = value + + values[var_name] = var_value + for key, value in values.items(): + f.write(f'{key}="{value}"\n') diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 51fcd299..a2336671 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -27,6 +27,17 @@ class History(ContinueBaseModel): return None return self.timeline[self.current_index] + def get_last_at_depth(self, depth: int, include_current: bool = False) -> Union[HistoryNode, None]: + i = self.current_index if include_current else self.current_index - 1 + while i >= 0: + if self.timeline[i].depth == depth and type(self.timeline[i].step).__name__ != "ManualEditStep": + return self.timeline[i] + i -= 1 + return None + + def get_last_at_same_depth(self) -> Union[HistoryNode, None]: + return self.get_last_at_depth(self.get_current().depth) + def 
remove_current_and_substeps(self): self.timeline.pop(self.current_index) while self.get_current() is not None and self.get_current().depth > 0: @@ -51,7 +62,7 @@ class History(ContinueBaseModel): self.current_index -= 1 def last_observation(self) -> Union[Observation, None]: - state = self.get_current() + state = self.get_last_at_same_depth() if state is None: return None return state.observation @@ -72,7 +83,7 @@ class ContinueSDK: pass -class SequentialStep: +class Models: pass @@ -94,7 +105,7 @@ class Step(ContinueBaseModel): class Config: copy_on_model_validation = False - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: if self._description is not None: return self._description return "Running step: " + self.name @@ -135,6 +146,16 @@ class Step(ContinueBaseModel): return SequentialStep(steps=steps) +class SequentialStep(Step): + steps: list[Step] + hide: bool = True + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + for step in self.steps: + observation = await sdk.run_step(step) + return observation + + class ValidatorObservation(Observation): passed: bool observation: Observation diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 5a6652f4..9f68515f 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,22 +1,23 @@ from typing import List, Tuple, Type - -from ..models.main import ContinueBaseModel - -from ..libs.steps.ty import CreatePipelineStep +from ..libs.steps.steps_on_startup import StepsOnStartupStep +from ..libs.steps.draft.dlt import CreatePipelineStep from .main import Step, Validator, History, Policy from .observation import Observation, TracebackObservation, UserInputObservation -from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep +from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, MessageStep, EmptyStep from ..libs.steps.nate import WritePytestsStep, CreateTableStep -from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma +# from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma from ..libs.steps.continue_step import ContinueStepStep class DemoPolicy(Policy): ran_code_last: bool = False - cmd: str def next(self, history: History) -> Step: + # At the very start, run initial Steps spcecified in the config + if history.get_current() is None: + return MessageStep(message="Welcome to Continue!") >> StepsOnStartupStep() + observation = history.last_observation() if observation is not None and isinstance(observation, UserInputObservation): # This could be defined with ObservationTypePolicy. Ergonomics not right though. 
@@ -26,18 +27,15 @@ class DemoPolicy(Policy): return CreatePipelineStep() elif "/table" in observation.user_input: return CreateTableStep(sql_str=" ".join(observation.user_input.split(" ")[1:])) - elif "/ask" in observation.user_input: - return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) - elif "/edit" in observation.user_input: - return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) + # elif "/ask" in observation.user_input: + # return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) + # elif "/edit" in observation.user_input: + # return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) elif "/step" in observation.user_input: return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:])) - return EditHighlightedCodeStep(user_input=observation.user_input) + return StarCoderEditHighlightedCodeStep(user_input=observation.user_input) state = history.get_current() - if state is None or not self.ran_code_last: - self.ran_code_last = True - return RunCodeStep(cmd=self.cmd) if observation is not None and isinstance(observation, TracebackObservation): self.ran_code_last = False diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 3559e9d7..ce0c53fd 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,12 +1,17 @@ import os from typing import Coroutine, Union -from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory + +from .config import ContinueConfig, load_config +from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory from ..models.filesystem import RangeInFile from ..libs.llm import LLM +from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI +from ..libs.llm.openai import OpenAI from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer from .main import History, Step from ..libs.steps.core.core import * +from .env import get_env_var, make_sure_env_exists class Agent: @@ -18,21 +23,33 @@ class ContinueSDKSteps: self.sdk = sdk +class Models: + def __init__(self, sdk: "ContinueSDK"): + self.sdk = sdk + + async def starcoder(self): + api_key = await self.sdk.get_user_secret( + 'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file') + return HuggingFaceInferenceAPI(api_key=api_key) + + async def gpt35(self): + api_key = await self.sdk.get_user_secret( + 'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file') + return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo") + + class ContinueSDK: """The SDK provided as parameters to a step""" - llm: LLM ide: AbstractIdeProtocolServer steps: ContinueSDKSteps + models: Models __agent: Agent - def __init__(self, agent: Agent, llm: Union[LLM, None] = None): - if llm is None: - self.llm = agent.llm - else: - self.llm = llm + def __init__(self, agent: Agent): self.ide = agent.ide self.__agent = agent self.steps = ContinueSDKSteps(self) + self.models = Models(self) @property def history(self) -> History: @@ -57,7 +74,7 @@ class ContinueSDK: async def run(self, commands: List[str] | str, cwd: str = None): commands = commands if isinstance(commands, List) else [commands] - return await self.run_step(ShellCommandsStep(commands=commands, cwd=cwd)) + return await self.run_step(ShellCommandsStep(cmds=commands, cwd=cwd)) async def edit_file(self, filename: 
str, prompt: str): filepath = await self._ensure_absolute_path(filename) @@ -69,6 +86,12 @@ class ContinueSDK: prompt=f'Here is the code before:\n\n{{code}}\n\nHere is the user request:\n\n{prompt}\n\nHere is the code edited to perfectly solve the user request:\n\n' )) + async def append_to_file(self, filename: str, content: str): + filepath = await self._ensure_absolute_path(filename) + previous_content = await self.ide.readFile(filepath) + file_edit = FileEdit.from_append(filepath, previous_content, content) + await self.ide.applyFileSystemEdit(file_edit) + async def add_file(self, filename: str, content: str | None): return await self.run_step(FileSystemEditStep(edit=AddFile(filename=filename, content=content))) @@ -80,3 +103,38 @@ class ContinueSDK: async def delete_directory(self, path: str): return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) + + async def get_user_secret(self, env_var: str, prompt: str) -> str: + make_sure_env_exists() + + val = None + while val is None: + try: + val = get_env_var(env_var) + if val is not None: + return val + except: + pass + server_dir = os.getcwd() + env_path = os.path.join(server_dir, ".env") + await self.ide.setFileOpen(env_path) + await self.append_to_file(env_path, f'\n{env_var}="<ENTER SECRET HERE>"') + await self.run_step(WaitForUserConfirmationStep(prompt=prompt)) + val = get_env_var(env_var) + + return val + + async def get_config(self) -> ContinueConfig: + dir = await self.ide.getWorkspaceDirectory() + yaml_path = os.path.join(dir, 'continue.yaml') + json_path = os.path.join(dir, 'continue.json') + if os.path.exists(yaml_path): + return load_config(yaml_path) + elif os.path.exists(json_path): + return load_config(json_path) + else: + return ContinueConfig() + + def set_loading_message(self, message: str): + # self.__agent.set_loading_message(message) + raise NotImplementedError() diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py new file mode 100644 index 00000000..83852d27 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -0,0 +1,25 @@ +from ..llm import LLM +import requests + +DEFAULT_MAX_TOKENS = 2048 +DEFAULT_MAX_TIME = 120. 
+ + +class HuggingFaceInferenceAPI(LLM): + api_key: str + model: str = "bigcode/starcoder" + + def complete(self, prompt: str, **kwargs): + """Return the completion of the text with the given temperature.""" + API_URL = f"https://api-inference.huggingface.co/models/{self.model}" + headers = { + "Authorization": f"Bearer {self.api_key}"} + + response = requests.post(API_URL, headers=headers, json={ + "inputs": prompt, "parameters": { + "max_new_tokens": DEFAULT_MAX_TOKENS, + "max_time": DEFAULT_MAX_TIME, + "return_full_text": False, + } + }) + return response.json()[0]["generated_text"] diff --git a/continuedev/src/continuedev/libs/steps/chroma.py b/continuedev/src/continuedev/libs/steps/chroma.py index f13a2bab..39424c5c 100644 --- a/continuedev/src/continuedev/libs/steps/chroma.py +++ b/continuedev/src/continuedev/libs/steps/chroma.py @@ -40,7 +40,7 @@ class AnswerQuestionChroma(Step): Here is the answer:""") - answer = sdk.llm.complete(prompt) + answer = (await sdk.models.gpt35()).complete(prompt) print(answer) self._answer = answer diff --git a/continuedev/src/continuedev/libs/steps/core/core.py b/continuedev/src/continuedev/libs/steps/core/core.py index 0338d635..9a5d54f0 100644 --- a/continuedev/src/continuedev/libs/steps/core/core.py +++ b/continuedev/src/continuedev/libs/steps/core/core.py @@ -4,27 +4,18 @@ from textwrap import dedent from typing import Coroutine, List, Union from ...llm.prompt_utils import MarkdownStyleEncoderDecoder -from ...util.traceback_parsers import parse_python_traceback - from ....models.filesystem_edit import EditDiff, FileEditWithFullContents, FileSystemEdit from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ...llm import LLM from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation -from ....core.main import Step +from ....core.main import Step, SequentialStep class ContinueSDK: pass -class SequentialStep(Step): - steps: list[Step] - hide: bool = True - - async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - for step in self.steps: - observation = await sdk.run_step(step) - return observation +class Models: + pass class ReversibleStep(Step): @@ -47,13 +38,14 @@ class FileSystemEditStep(ReversibleStep): # Where and when should file saves happen? 
-def ShellCommandsStep(Step): +class ShellCommandsStep(Step): cmds: List[str] cwd: str | None = None name: str = "Run Shell Commands" - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: - return "\n".join(self.cmds) + async def describe(self, models: Models) -> Coroutine[str, None, None]: + cmds_str = "\n".join(self.cmds) + return (await models.gpt35()).complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd @@ -81,13 +73,13 @@ class EditCodeStep(Step): _prompt: Union[str, None] = None _completion: Union[str, None] = None - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: if self._edit_diffs is None: return "Editing files: " + ", ".join(map(lambda rif: rif.filepath, self.range_in_files)) elif len(self._edit_diffs) == 0: return "No edits made" else: - return llm.complete(dedent(f"""{self._prompt}{self._completion} + return (await models.gpt35()).complete(dedent(f"""{self._prompt}{self._completion} Maximally concise summary of changes in bullet points (can use markdown): """)) @@ -102,7 +94,7 @@ class EditCodeStep(Step): code_string = enc_dec.encode() prompt = self.prompt.format(code=code_string) - completion = sdk.llm.complete(prompt) + completion = (await sdk.models.gpt35()).complete(prompt) # Temporarily doing this to generate description. self._prompt = prompt @@ -127,7 +119,7 @@ class EditFileStep(Step): prompt: str hide: bool = True - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Editing file: " + self.filepath async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -145,7 +137,7 @@ class ManualEditStep(ReversibleStep): hide: bool = True - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Manual edit step" # TODO - only handling FileEdit here, but need all other types of FileSystemEdits # Also requires the merge_file_edit function @@ -181,7 +173,7 @@ class UserInputStep(Step): name: str = "User Input" hide: bool = True - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return self.user_input async def run(self, sdk: ContinueSDK) -> Coroutine[UserInputObservation, None, None]: @@ -194,7 +186,7 @@ class WaitForUserInputStep(Step): _description: Union[str, None] = None - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return self.prompt async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -207,7 +199,7 @@ class WaitForUserConfirmationStep(Step): prompt: str name: str = "Waiting for user confirmation" - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return self.prompt async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: diff --git a/continuedev/src/continuedev/libs/steps/draft/dlt.py b/continuedev/src/continuedev/libs/steps/draft/dlt.py index 5ba5692a..73762327 100644 --- a/continuedev/src/continuedev/libs/steps/draft/dlt.py +++ 
b/continuedev/src/continuedev/libs/steps/draft/dlt.py @@ -1,16 +1,29 @@ from textwrap import dedent + +from ....core.sdk import Models + +from ....core.observation import DictObservation from ....models.filesystem_edit import AddFile from ....core.main import Step from ....core.sdk import ContinueSDK -from ..main import WaitForUserInputStep +from ..core.core import WaitForUserInputStep +from ..main import MessageStep class SetupPipelineStep(Step): + hide: bool = True + name: str = "Setup dlt Pipeline" api_description: str # e.g. "I want to load data from the weatherapi.com API" + async def describe(self, models: Models): + return dedent(f"""\ + This step will create a new dlt pipeline that loads data from an API, as per your request: + {self.api_description} + """) + async def run(self, sdk: ContinueSDK): - source_name = sdk.llm.complete( + source_name = (await sdk.models.gpt35()).complete( f"Write a snake_case name for the data source described by {self.api_description}: ").strip() filename = f'{source_name}.py' @@ -19,7 +32,7 @@ class SetupPipelineStep(Step): 'python3 -m venv env', 'source env/bin/activate', 'pip install dlt', - 'dlt init {source_name} duckdb', + f'dlt init {source_name} duckdb', 'Y', 'pip install -r requirements.txt' ]) @@ -31,15 +44,14 @@ class SetupPipelineStep(Step): ) # wait for user to put API key in secrets.toml - await sdk.ide.setFileOpen(".dlt/secrets.toml") - await sdk.wait_for_user_confirmation("Please add the API key to the `secrets.toml` file and then press `Continue`") - return {"source_name": source_name} + await sdk.ide.setFileOpen(await sdk.ide.getWorkspaceDirectory() + "/.dlt/secrets.toml") + await sdk.wait_for_user_confirmation("If this service requires an API key, please add it to the `secrets.toml` file and then press `Continue`") + return DictObservation(values={"source_name": source_name}) class ValidatePipelineStep(Step): - async def run(self, sdk: ContinueSDK): - source_name = sdk.history.last_observation()["source_name"] + source_name = sdk.history.last_observation().values["source_name"] filename = f'{source_name}.py' # test that the API call works @@ -68,15 +80,27 @@ class ValidatePipelineStep(Step): for row in rows: print(row) ''') - await sdk.apply_filesystem_edit(AddFile(filepath='query.py', content=tables_query_code)) + + query_filename = (await sdk.ide.getWorkspaceDirectory()) + "/query.py" + await sdk.apply_filesystem_edit(AddFile(filepath=query_filename, content=tables_query_code)) await sdk.run('env/bin/python3 query.py') class CreatePipelineStep(Step): + hide: bool = True async def run(self, sdk: ContinueSDK): await sdk.run_step( + MessageStep(message=dedent("""\ + This recipe will walk you through the process of creating a dlt pipeline for your chosen data source. 
With the help of Continue, you will: + - Create a Python virtual environment with dlt installed + - Run `dlt init` to generate a pipeline template + - Write the code to call the API + - Add any required API keys to the `secrets.toml` file + - Test that the API call works + - Load the data into a local DuckDB instance + - Write a query to view the data""")) >> WaitForUserInputStep(prompt="What API do you want to load data from?") >> - SetupPipelineStep() >> + SetupPipelineStep(api_description="WeatherAPI.com API") >> ValidatePipelineStep() ) diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py index f28cb23f..c70d5c2c 100644 --- a/continuedev/src/continuedev/libs/steps/main.py +++ b/continuedev/src/continuedev/libs/steps/main.py @@ -1,4 +1,6 @@ -from typing import Callable, Coroutine, List, Union +from typing import Coroutine, List, Union + +from pydantic import BaseModel from ..util.traceback_parsers import parse_python_traceback from ..llm import LLM @@ -8,16 +10,17 @@ from ...models.filesystem import RangeInFile, RangeInFileWithContents from ...core.observation import Observation, TextObservation, TracebackObservation from ..llm.prompt_utils import MarkdownStyleEncoderDecoder from textwrap import dedent -from ...core.main import History, Policy, Step, ContinueSDK, Observation +from ...core.main import Step +from ...core.sdk import ContinueSDK, Models +from ...core.observation import Observation import subprocess -import json from .core.core import EditCodeStep class RunCodeStep(Step): cmd: str - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return f"Ran command: `{self.cmd}`" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -36,6 +39,10 @@ class RunCodeStep(Step): return None +class Policy(BaseModel): + pass + + class RunPolicyUntilDoneStep(Step): policy: "Policy" @@ -52,7 +59,7 @@ class RunCommandStep(Step): name: str = "Run command" _description: str = None - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: if self._description is not None: return self._description return self.cmd @@ -77,44 +84,48 @@ class FasterEditHighlightedCodeStep(Step): hide = True _completion: str = "Edit Code" _edit_diffs: Union[List[EditDiff], None] = None - _prompt: str = dedent("""Below is the code before changes: - -{code} - -This is the user request: - -{user_input} - -Edit the code to perfectly satifsfy the user request. Format the changes you want to make as a comma-separated array of JSON objects of the form: -{{ - "edits": [{{ - "filepath": <FILEPATH>, - "replace_me": <CODE_TO_REPLACE>, - "replace_with": <CODE_TO_REPLACE_WITH> - }}] -}} - -For example, if you want to replace the code `x = 1` with `x = 2` in main.py, you would write: -{{ - "edits": [{{ - "filepath": "main.py", - "replace_me": "x = 1", - "replace_with": "x = 2" - }}] -}} -If you wanted to delete the code `def sum(a, b):\\n return a + b` in main.py, you would write: -{{ - "edits": [{{ - "filepath": "main.py", - "replace_me": "def sum(a, b):\\n return a + b", - "replace_with": "" - }}] -}} - -Respond with only as many edits as needed, and output only the list of json objects, no other text. + _prompt: str = dedent("""\ + You will be given code to edit in order to perfectly satisfy the user request. 
All the changes you make must be described as replacements, which you should format in the following way: + FILEPATH + <FILE_TO_EDIT> + REPLACE_ME + <CODE_TO_REPLACE> + REPLACE_WITH + <CODE_TO_REPLACE_WITH> + + where <CODE_TO_REPLACE> and <CODE_TO_REPLACE_WITH> can be multiple lines, but should be the mininum needed to make the edit. Be sure to maintain existing whitespace at the start of lines. + + For example, if you want to replace the code `x = 1` with `x = 2` in main.py, you would write: + FILEPATH + main.py + REPLACE_ME + x = 1 + REPLACE_WITH + x = 2 + If you wanted to delete the code + ``` + def sum(a, b): + return a + b + ``` + in main.py, you would write: + FILEPATH + main.py + REPLACE_ME + def sum(a, b): + return a + b + REPLACE_WITH + + You may need to make multiple edits; respond with exactly as many as needed. + + Below is the code before changes: + + {code} + + This is the user request: "{user_input}" + Here is the description of changes to make: """) - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Editing highlighted code" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -143,21 +154,51 @@ Respond with only as many edits as needed, and output only the list of json obje for rif in rif_with_contents: rif_dict[rif.filepath] = rif.contents - completion = sdk.llm.complete(prompt) + completion = (await sdk.models.gpt35()).complete(prompt) # Temporarily doing this to generate description. self._prompt = prompt self._completion = completion + print(completion) # ALTERNATIVE DECODING STEP HERE + raw_file_edits = [] + lines = completion.split("\n") + current_edit = {} + status = "FILEPATH" + for i in range(0, len(lines)): + line = lines[i] + if line == "FILEPATH": + if "FILEPATH" in current_edit: + raw_file_edits.append(current_edit) + current_edit = {} + status = "FILEPATH" + elif line == "REPLACE_ME": + status = "REPLACE_ME" + elif line == "REPLACE_WITH": + status = "REPLACE_WITH" + elif status == "FILEPATH": + current_edit["filepath"] = line + elif status == "REPLACE_ME": + if "replace_me" in current_edit: + current_edit["replace_me"] += "\n" + line + else: + current_edit["replace_me"] = line + elif status == "REPLACE_WITH": + if "replace_with" in current_edit: + current_edit["replace_with"] += "\n" + line + else: + current_edit["replace_with"] = line + if "filepath" in current_edit: + raw_file_edits.append(current_edit) + file_edits = [] - obj = json.loads(completion.strip()) - for edit in obj["edits"]: + for edit in raw_file_edits: filepath = edit["filepath"] replace_me = edit["replace_me"] replace_with = edit["replace_with"] file_edits.append( - FileEdit(filepath=filepath, range=Range.from_snippet_in_file(content=rif_dict[filepath], snippet=replace_me), replacement=replace_with)) + FileEdit(filepath=filepath, range=Range.from_lines_snippet_in_file(content=rif_dict[filepath], snippet=replace_me), replacement=replace_with)) # ------------------------------ self._edit_diffs = [] @@ -172,6 +213,50 @@ Respond with only as many edits as needed, and output only the list of json obje return None +class StarCoderEditHighlightedCodeStep(Step): + user_input: str + hide = False + _prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>" + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return "Editing highlighted code" + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + 
range_in_files = await sdk.ide.getHighlightedCode() + if len(range_in_files) == 0: + # Get the full contents of all open files + files = await sdk.ide.getOpenFiles() + contents = {} + for file in files: + contents[file] = await sdk.ide.readFile(file) + + range_in_files = [RangeInFile.from_entire_file( + filepath, content) for filepath, content in contents.items()] + + rif_with_contents = [] + for range_in_file in range_in_files: + file_contents = await sdk.ide.readRangeInFile(range_in_file) + rif_with_contents.append( + RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) + + rif_dict = {} + for rif in rif_with_contents: + rif_dict[rif.filepath] = rif.contents + + for rif in rif_with_contents: + prompt = self._prompt.format( + code=rif.contents, user_request=self.user_input) + completion = str((await sdk.models.starcoder()).complete(prompt)) + eot_token = "<|endoftext|>" + if completion.endswith(eot_token): + completion = completion[:completion.rindex(eot_token)] + + await sdk.ide.applyFileSystemEdit( + FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion)) + await sdk.ide.saveFile(rif.filepath) + await sdk.ide.setFileOpen(rif.filepath) + + class EditHighlightedCodeStep(Step): user_input: str hide = True @@ -186,7 +271,7 @@ This is the user request: This is the code after being changed to perfectly satisfy the user request: """) - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Editing highlighted code" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -208,7 +293,7 @@ This is the code after being changed to perfectly satisfy the user request: class FindCodeStep(Step): prompt: str - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Finding code" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -222,7 +307,7 @@ class UserInputStep(Step): class SolveTracebackStep(Step): traceback: Traceback - async def describe(self, llm: LLM) -> Coroutine[str, None, None]: + async def describe(self, models: Models) -> Coroutine[str, None, None]: return f"```\n{self.traceback.full_traceback}\n```" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -246,3 +331,24 @@ class SolveTracebackStep(Step): await sdk.run_step(EditCodeStep( range_in_files=range_in_files, prompt=prompt)) return None + + +class MessageStep(Step): + name: str = "Message" + message: str + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return self.message + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + return TextObservation(text=self.message) + + +class EmptyStep(Step): + hide: bool = True + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return "" + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + pass diff --git a/continuedev/src/continuedev/libs/steps/migration.py b/continuedev/src/continuedev/libs/steps/migration.py index f044a60f..7b70422d 100644 --- a/continuedev/src/continuedev/libs/steps/migration.py +++ b/continuedev/src/continuedev/libs/steps/migration.py @@ -15,7 +15,7 @@ class MigrationStep(Step): recent_edits = await sdk.ide.get_recent_edits(self.edited_file) recent_edits_string = "\n\n".join( map(lambda x: x.to_string(), recent_edits)) - description = await 
sdk.llm.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") + description = await (await sdk.models.gpt35()).complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") await sdk.run_step(RunCommandStep(cmd=f"cd libs && poetry run alembic revision --autogenerate -m {description}")) migration_file = f"libs/alembic/versions/{?}.py" contents = await sdk.ide.readFile(migration_file) diff --git a/continuedev/src/continuedev/libs/steps/nate.py b/continuedev/src/continuedev/libs/steps/nate.py index a0e728e5..2f84e9d7 100644 --- a/continuedev/src/continuedev/libs/steps/nate.py +++ b/continuedev/src/continuedev/libs/steps/nate.py @@ -45,7 +45,7 @@ Here are additional instructions: Here is a complete set of pytest unit tests: """) - # tests = sdk.llm.complete(prompt) + # tests = (await sdk.models.gpt35()).complete(prompt) tests = ''' import pytest @@ -169,9 +169,9 @@ export class Order { tracking_number: string; }''' time.sleep(2) - # orm_entity = sdk.llm.complete( + # orm_entity = (await sdk.models.gpt35()).complete( # f"{self.sql_str}\n\nWrite a TypeORM entity called {entity_name} for this table, importing as necessary:") - # sdk.llm.complete("What is the name of the entity?") + # (await sdk.models.gpt35()).complete("What is the name of the entity?") await sdk.apply_filesystem_edit(AddFile(filepath=f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", content=orm_entity)) await sdk.ide.setFileOpen(f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", True) diff --git a/continuedev/src/continuedev/libs/steps/pytest.py b/continuedev/src/continuedev/libs/steps/pytest.py index b4e6dfd2..2e83ae2d 100644 --- a/continuedev/src/continuedev/libs/steps/pytest.py +++ b/continuedev/src/continuedev/libs/steps/pytest.py @@ -33,5 +33,5 @@ class WritePytestsStep(Step): Here is a complete set of pytest unit tests: """) - tests = sdk.llm.complete(prompt) + tests = (await sdk.models.gpt35()).complete(prompt) await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/libs/steps/steps_on_startup.py b/continuedev/src/continuedev/libs/steps/steps_on_startup.py new file mode 100644 index 00000000..fd1eb8f0 --- /dev/null +++ b/continuedev/src/continuedev/libs/steps/steps_on_startup.py @@ -0,0 +1,30 @@ + + +from ...core.main import ContinueSDK, Models, Step +from .main import UserInputStep +from .draft.dlt import CreatePipelineStep + + +step_name_to_step_class = { + "UserInputStep": UserInputStep, + "CreatePipelineStep": CreatePipelineStep +} + + +class StepsOnStartupStep(Step): + hide: bool = True + + async def describe(self, models: Models): + return "Running steps on startup" + + async def run(self, sdk: ContinueSDK): + steps_descriptions = (await sdk.get_config()).steps_on_startup + + for step_name, step_params in steps_descriptions.items(): + try: + step = step_name_to_step_class[step_name](**step_params) + except: + print( + f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}") + continue + await sdk.run_step(step) diff --git a/continuedev/src/continuedev/libs/steps/ty.py b/continuedev/src/continuedev/libs/steps/ty.py index 5ff03f04..9dde7c86 100644 --- a/continuedev/src/continuedev/libs/steps/ty.py +++ b/continuedev/src/continuedev/libs/steps/ty.py @@ -18,7 +18,7 @@ class SetupPipelineStep(Step): api_description: str # e.g. 
"I want to load data from the weatherapi.com API" async def run(self, sdk: ContinueSDK): - # source_name = sdk.llm.complete( + # source_name = (await sdk.models.gpt35()).complete( # f"Write a snake_case name for the data source described by {self.api_description}: ").strip() filename = f'/Users/natesesti/Desktop/continue/extension/examples/python/{source_name}.py' diff --git a/continuedev/src/continuedev/models/filesystem_edit.py b/continuedev/src/continuedev/models/filesystem_edit.py index 7526d4c9..8e74b819 100644 --- a/continuedev/src/continuedev/models/filesystem_edit.py +++ b/continuedev/src/continuedev/models/filesystem_edit.py @@ -37,6 +37,10 @@ class FileEdit(AtomicFileSystemEdit): def from_insertion(filepath: str, position: Position, content: str) -> "FileEdit": return FileEdit(filepath=filepath, range=Range.from_shorthand(position.line, position.character, position.line, position.character), replacement=content) + @staticmethod + def from_append(filepath: str, previous_content: str, appended_content: str) -> "FileEdit": + return FileEdit(filepath=filepath, range=Range.from_position(Position.from_end_of_file(previous_content)), replacement=appended_content) + class FileEditWithFullContents(BaseModel): fileEdit: FileEdit diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index 7986b30c..02c44aae 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -32,13 +32,17 @@ class Position(BaseModel): def from_index(string: str, index: int) -> "Position": """Convert index in string to line and character""" line = string.count("\n", 0, index) - if line == 1: + if line == 0: character = index else: character = index - string.rindex("\n", 0, index) - 1 return Position(line=line, character=character) + @staticmethod + def from_end_of_file(contents: str) -> "Position": + return Position.from_index(contents, len(contents)) + class Range(BaseModel): """A range in a file. 
0-indexed.""" @@ -93,6 +97,34 @@ class Range(BaseModel): end_index = start_index + len(snippet) return Range.from_indices(content, start_index, end_index) + @staticmethod + def from_lines_snippet_in_file(content: str, snippet: str) -> "Range": + # lines is a substring of the content modulo whitespace on each line + content_lines = content.splitlines() + snippet_lines = snippet.splitlines() + + start_line = -1 + end_line = -1 + looking_for_line = 0 + for i in range(len(content_lines)): + if content_lines[i].strip() == snippet_lines[looking_for_line].strip(): + if looking_for_line == len(snippet_lines) - 1: + start_line = i - len(snippet_lines) + 1 + end_line = i + break + looking_for_line += 1 + else: + looking_for_line = 0 + + if start_line == -1 or end_line == -1: + raise ValueError("Snippet not found in content") + + return Range.from_shorthand(start_line, 0, end_line, len(content_lines[end_line]) - 1) + + @staticmethod + def from_position(position: Position) -> "Range": + return Range(start=position, end=position) + class AbstractModel(ABC, BaseModel): @root_validator(pre=True) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index dd1dc463..50296841 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -1,5 +1,6 @@ # This is a separate server from server/main.py import asyncio +import json import os from typing import Any, Dict, List, Type, TypeVar, Union import uuid @@ -90,31 +91,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def __init__(self, session_manager: SessionManager): self.session_manager = session_manager - async def _send_json(self, data: Any): - await self.websocket.send_json(data) + async def _send_json(self, message_type: str, data: Any): + await self.websocket.send_json({ + "messageType": message_type, + "data": data + }) async def _receive_json(self, message_type: str) -> Any: return await self.sub_queue.get(message_type) async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: - await self._send_json(data) + await self._send_json(message_type, data) resp = await self._receive_json(message_type) return resp_model.parse_obj(resp) - async def handle_json(self, data: Any): - t = data["messageType"] - if t == "openNotebook": + async def handle_json(self, message_type: str, data: Any): + if message_type == "openNotebook": await self.openNotebook() - elif t == "setFileOpen": + elif message_type == "setFileOpen": await self.setFileOpen(data["filepath"], data["open"]) - elif t == "fileEdits": + elif message_type == "fileEdits": fileEdits = list( map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"])) self.onFileEdits(fileEdits) - elif t in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]: - self.sub_queue.post(t, data) + elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]: + self.sub_queue.post(message_type, data) else: - raise ValueError("Unknown message type", t) + raise ValueError("Unknown message type", message_type) # ------------------------------- # # Request actions in IDE, doesn't matter which Session @@ -123,24 +126,21 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def setFileOpen(self, filepath: str, open: bool = True): # Agent needs access to this. 
- await self.websocket.send_json({ - "messageType": "setFileOpen", + await self._send_json("setFileOpen", { "filepath": filepath, "open": open }) async def openNotebook(self): session_id = self.session_manager.new_session(self) - await self._send_json({ - "messageType": "openNotebook", + await self._send_json("openNotebook", { "sessionId": session_id }) async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: ids = [str(uuid.uuid4()) for _ in suggestions] for i in range(len(suggestions)): - self._send_json({ - "messageType": "showSuggestion", + self._send_json("showSuggestion", { "suggestion": suggestions[i], "suggestionId": ids[i] }) @@ -210,8 +210,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def saveFile(self, filepath: str): """Save a file""" - await self._send_json({ - "messageType": "saveFile", + await self._send_json("saveFile", { "filepath": filepath }) @@ -293,10 +292,17 @@ ideProtocolServer = IdeProtocolServer(session_manager) async def websocket_endpoint(websocket: WebSocket): await websocket.accept() print("Accepted websocket connection from, ", websocket.client) - await websocket.send_json({"messageType": "connected"}) + await websocket.send_json({"messageType": "connected", "data": {}}) ideProtocolServer.websocket = websocket while True: - data = await websocket.receive_json() - await ideProtocolServer.handle_json(data) + message = await websocket.receive_text() + message = json.loads(message) + + if "messageType" not in message or "data" not in message: + continue + message_type = message["messageType"] + data = message["data"] + + await ideProtocolServer.handle_json(message_type, data) await websocket.close() diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 11ad1d8f..1ffe1450 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,3 +1,4 @@ +import os from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router @@ -32,7 +33,11 @@ args = parser.parse_args() def run_server(): - uvicorn.run(app, host="0.0.0.0", port=args.port, log_config="logging.ini") + if os.path.exists("logging.yaml"): + uvicorn.run(app, host="0.0.0.0", port=args.port, + log_config="logging.yaml") + else: + uvicorn.run(app, host="0.0.0.0", port=args.port) if __name__ == "__main__": diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py index bfd7a09c..9ca510dd 100644 --- a/continuedev/src/continuedev/server/notebook.py +++ b/continuedev/src/continuedev/server/notebook.py @@ -1,18 +1,12 @@ -from fastapi import FastAPI, Depends, Header, WebSocket, APIRouter -from typing import Any, Dict, List, Union -from uuid import uuid4 +import json +from fastapi import Depends, Header, WebSocket, APIRouter +from typing import Any, Type, TypeVar, Union from pydantic import BaseModel from uvicorn.main import Server -from ..models.filesystem_edit import FileEditWithFullContents -from ..core.policy import DemoPolicy -from ..core.main import FullState, History, Step -from ..core.agent import Agent -from ..libs.steps.nate import ImplementAbstractMethodStep -from ..core.observation import Observation -from ..libs.llm.openai import OpenAI -from .ide_protocol import AbstractIdeProtocolServer -from ..core.env import openai_api_key +from .session_manager import SessionManager, session_manager, Session +from .notebook_protocol import AbstractNotebookProtocolServer +from 
..libs.util.queue import AsyncSubscriptionQueue import asyncio import nest_asyncio nest_asyncio.apply() @@ -36,160 +30,101 @@ class AppStatus: Server.handle_exit = AppStatus.handle_exit -class Session: - session_id: str - agent: Agent - ws: Union[WebSocket, None] - - def __init__(self, session_id: str, agent: Agent): - self.session_id = session_id - self.agent = agent - self.ws = None - - -class DemoAgent(Agent): - first_seen: bool = False - cumulative_edit_string = "" - - def handle_manual_edits(self, edits: List[FileEditWithFullContents]): - for edit in edits: - self.cumulative_edit_string += edit.fileEdit.replacement - self._manual_edits_buffer.append(edit) - # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. - # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) - # FOR DEMO PURPOSES - if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement: - self.cumulative_edit_string = "" - asyncio.create_task(self.run_from_step( - ImplementAbstractMethodStep())) - - -class SessionManager: - sessions: Dict[str, Session] = {} - _event_loop: Union[asyncio.BaseEventLoop, None] = None - - def get_session(self, session_id: str) -> Session: - if session_id not in self.sessions: - raise KeyError("Session ID not recognized") - return self.sessions[session_id] - - def new_session(self, ide: AbstractIdeProtocolServer) -> str: - cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py" - agent = DemoAgent(llm=OpenAI(api_key=openai_api_key), - policy=DemoPolicy(cmd=cmd), ide=ide) - session_id = str(uuid4()) - session = Session(session_id=session_id, agent=agent) - self.sessions[session_id] = session +def session(x_continue_session_id: str = Header("anonymous")) -> Session: + return session_manager.get_session(x_continue_session_id) - def on_update(state: FullState): - session_manager.send_ws_data(session_id, { - "messageType": "state", - "state": agent.get_full_state().dict() - }) - agent.on_update(on_update) - asyncio.create_task(agent.run_policy()) - return session_id +def websocket_session(session_id: str) -> Session: + return session_manager.get_session(session_id) - def remove_session(self, session_id: str): - del self.sessions[session_id] - def register_websocket(self, session_id: str, ws: WebSocket): - self.sessions[session_id].ws = ws - print("Registered websocket for session", session_id) +T = TypeVar("T", bound=BaseModel) - def send_ws_data(self, session_id: str, data: Any): - if self.sessions[session_id].ws is None: - print(f"Session {session_id} has no websocket") - return +# You should probably abstract away the websocket stuff into a separate class - async def a(): - await self.sessions[session_id].ws.send_json(data) - # Run coroutine in background - if self._event_loop is None or self._event_loop.is_closed(): - self._event_loop = asyncio.new_event_loop() - self._event_loop.run_until_complete(a()) - self._event_loop.close() - else: - self._event_loop.run_until_complete(a()) - self._event_loop.close() +class NotebookProtocolServer(AbstractNotebookProtocolServer): + websocket: WebSocket + session: Session + sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() + def __init__(self, session: Session): + self.session = session -session_manager = SessionManager() + async def _send_json(self, message_type: str, data: Any): + await self.websocket.send_json({ + "messageType": message_type, + "data": data + }) + 
async def _receive_json(self, message_type: str) -> Any: + return await self.sub_queue.get(message_type) -def session(x_continue_session_id: str = Header("anonymous")) -> Session: - return session_manager.get_session(x_continue_session_id) + async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: + await self._send_json(message_type, data) + resp = await self._receive_json(message_type) + return resp_model.parse_obj(resp) + def handle_json(self, message_type: str, data: Any): + try: + if message_type == "main_input": + self.on_main_input(data["input"]) + elif message_type == "step_user_input": + self.on_step_user_input(data["input"], data["index"]) + elif message_type == "refinement_input": + self.on_refinement_input(data["input"], data["index"]) + elif message_type == "reverse_to_index": + self.on_reverse_to_index(data["index"]) + except Exception as e: + print(e) -def websocket_session(session_id: str) -> Session: - return session_manager.get_session(session_id) + async def send_state_update(self): + state = self.session.agent.get_full_state().dict() + await self._send_json("state_update", { + "state": state + }) + def on_main_input(self, input: str): + # Do something with user input + asyncio.create_task(self.session.agent.accept_user_input(input)) -class StartSessionBody(BaseModel): - config_file_path: Union[str, None] + def on_reverse_to_index(self, index: int): + # Reverse the history to the given index + asyncio.create_task(self.session.agent.reverse_to_index(index)) + def on_step_user_input(self, input: str, index: int): + asyncio.create_task( + self.session.agent.give_user_input(input, index)) -class StartSessionResp(BaseModel): - session_id: str + def on_refinement_input(self, input: str, index: int): + asyncio.create_task( + self.session.agent.accept_refinement_input(input, index)) @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): await websocket.accept() + print("Session started") session_manager.register_websocket(session.session_id, websocket) - data = await websocket.receive_text() + protocol = NotebookProtocolServer(session) + protocol.websocket = websocket + # Update any history that may have happened before connection - await websocket.send_json({ - "messageType": "state", - "state": session_manager.get_session(session.session_id).agent.get_full_state().dict() - }) - print("Session started", data) + await protocol.send_state_update() + while AppStatus.should_exit is False: - data = await websocket.receive_json() - print("Received data", data) + message = await websocket.receive_text() + print("Received message", message) + if type(message) is str: + message = json.loads(message) - if "messageType" not in data: + if "messageType" not in message or "data" not in message: continue - messageType = data["messageType"] + message_type = message["messageType"] + data = message["data"] - try: - if messageType == "main_input": - # Do something with user input - asyncio.create_task( - session.agent.accept_user_input(data["value"])) - elif messageType == "step_user_input": - asyncio.create_task( - session.agent.give_user_input(data["value"], data["index"])) - elif messageType == "refinement_input": - asyncio.create_task( - session.agent.accept_refinement_input(data["value"], data["index"])) - elif messageType == "reverse": - # Reverse the history to the given index - asyncio.create_task( - session.agent.reverse_to_index(data["index"])) - except Exception as e: - print(e) 
+ protocol.handle_json(message_type, data) print("Closing websocket") await websocket.close() - - -@router.post("/run") -def request_run(step: Step, session=Depends(session)): - """Tell an agent to take a specific action.""" - asyncio.create_task(session.agent.run_from_step(step)) - return "Success" - - -@router.get("/history") -def get_history(session=Depends(session)) -> History: - return session.agent.history - - -@router.post("/observation") -def post_observation(observation: Observation, session=Depends(session)): - asyncio.create_task(session.agent.run_from_observation(observation)) - return "Success" diff --git a/continuedev/src/continuedev/server/notebook_protocol.py b/continuedev/src/continuedev/server/notebook_protocol.py new file mode 100644 index 00000000..c2be82e0 --- /dev/null +++ b/continuedev/src/continuedev/server/notebook_protocol.py @@ -0,0 +1,28 @@ +from typing import Any +from abc import ABC, abstractmethod + + +class AbstractNotebookProtocolServer(ABC): + @abstractmethod + async def handle_json(self, data: Any): + """Handle a json message""" + + @abstractmethod + def on_main_input(self, input: str): + """Called when the user inputs something""" + + @abstractmethod + def on_reverse_to_index(self, index: int): + """Called when the user requests reverse to a previous index""" + + @abstractmethod + def on_refinement_input(self, input: str, index: int): + """Called when the user inputs a refinement""" + + @abstractmethod + def on_step_user_input(self, input: str, index: int): + """Called when the user inputs a step""" + + @abstractmethod + async def send_state_update(self, state: dict): + """Send a state update to the client""" diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py new file mode 100644 index 00000000..c5715034 --- /dev/null +++ b/continuedev/src/continuedev/server/session_manager.py @@ -0,0 +1,86 @@ +from fastapi import WebSocket +from typing import Any, Dict, List, Union +from uuid import uuid4 + +from ..models.filesystem_edit import FileEditWithFullContents +from ..core.policy import DemoPolicy +from ..core.main import FullState +from ..core.agent import Agent +from ..libs.steps.nate import ImplementAbstractMethodStep +from .ide_protocol import AbstractIdeProtocolServer +import asyncio +import nest_asyncio +nest_asyncio.apply() + + +class Session: + session_id: str + agent: Agent + ws: Union[WebSocket, None] + + def __init__(self, session_id: str, agent: Agent): + self.session_id = session_id + self.agent = agent + self.ws = None + + +class DemoAgent(Agent): + first_seen: bool = False + cumulative_edit_string = "" + + def handle_manual_edits(self, edits: List[FileEditWithFullContents]): + for edit in edits: + self.cumulative_edit_string += edit.fileEdit.replacement + self._manual_edits_buffer.append(edit) + # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. 
+ # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) + # FOR DEMO PURPOSES + if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement: + self.cumulative_edit_string = "" + asyncio.create_task(self.run_from_step( + ImplementAbstractMethodStep())) + + +class SessionManager: + sessions: Dict[str, Session] = {} + _event_loop: Union[asyncio.BaseEventLoop, None] = None + + def get_session(self, session_id: str) -> Session: + if session_id not in self.sessions: + raise KeyError("Session ID not recognized") + return self.sessions[session_id] + + def new_session(self, ide: AbstractIdeProtocolServer) -> str: + agent = DemoAgent(policy=DemoPolicy(), ide=ide) + session_id = str(uuid4()) + session = Session(session_id=session_id, agent=agent) + self.sessions[session_id] = session + + async def on_update(state: FullState): + await session_manager.send_ws_data(session_id, "state_update", { + "state": agent.get_full_state().dict() + }) + + agent.on_update(on_update) + asyncio.create_task(agent.run_policy()) + return session_id + + def remove_session(self, session_id: str): + del self.sessions[session_id] + + def register_websocket(self, session_id: str, ws: WebSocket): + self.sessions[session_id].ws = ws + print("Registered websocket for session", session_id) + + async def send_ws_data(self, session_id: str, message_type: str, data: Any): + if self.sessions[session_id].ws is None: + print(f"Session {session_id} has no websocket") + return + + await self.sessions[session_id].ws.send_json({ + "messageType": message_type, + "data": data + }) + + +session_manager = SessionManager()
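
A note on the recurring change in this diff: `Agent.llm` and `ContinueSDK.llm` are removed, `Step.describe` now receives a `Models` object, and steps obtain models on demand via `sdk.models.gpt35()` or `sdk.models.starcoder()` (which resolve API keys through `get_user_secret`). Below is a minimal sketch of a custom step written against the new interface; the step itself, its name, and its prompt are illustrative and not part of this commit.

```python
# Hypothetical step module under continuedev/src/continuedev/libs/steps/,
# using the same relative imports as the other step modules touched in this diff.
from ...core.main import Step
from ...core.observation import Observation, TextObservation
from ...core.sdk import ContinueSDK, Models


class ExplainHighlightedCodeStep(Step):
    """Illustrative only -- not a step added by this diff."""
    name: str = "Explain highlighted code"

    async def describe(self, models: Models) -> str:
        # describe() now takes the Models facade instead of a single LLM instance
        return "Explaining the highlighted code"

    async def run(self, sdk: ContinueSDK) -> Observation:
        range_in_files = await sdk.ide.getHighlightedCode()
        contents = [await sdk.ide.readRangeInFile(rif) for rif in range_in_files]
        code = "\n\n".join(contents)
        # Models are resolved per call; gpt35() fetches OPENAI_API_KEY via get_user_secret
        gpt35 = await sdk.models.gpt35()
        explanation = gpt35.complete(
            f"Explain what the following code does, using markdown bullet points:\n\n{code}")
        return TextObservation(text=explanation)
```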
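The new `ContinueConfig`, `ContinueSDK.get_config()`, and `StepsOnStartupStep` read a `continue.yaml` or `continue.json` from the workspace root, where `steps_on_startup` maps step class names (only `UserInputStep` and `CreatePipelineStep` are registered in `step_name_to_step_class` so far) to their constructor parameters. A rough sketch of loading such a config, assuming the package is importable as `continuedev` and using made-up file contents:

```python
from continuedev.core.config import load_config  # assumes the package is on PYTHONPATH

# Assumed example contents for a workspace-level continue.yaml
example_yaml = """\
steps_on_startup:
  CreatePipelineStep: {}
"""

with open("continue.yaml", "w") as f:
    f.write(example_yaml)

config = load_config("continue.yaml")
print(config.steps_on_startup)   # {'CreatePipelineStep': {}}
print(config.server_url)         # None (optional field)
```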
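Both websocket servers in this diff now wrap every message in a `{"messageType": ..., "data": ...}` envelope; the notebook protocol handles `main_input`, `step_user_input`, `refinement_input`, and `reverse_to_index`, and pushes `state_update` messages containing the agent's full state. A rough client-side sketch follows; the `websockets` dependency, URL, port, and session id are assumptions for illustration, not part of this commit.

```python
import asyncio
import json

import websockets  # third-party client library, assumed for this example


async def drive_notebook_session():
    # URL, port, and session id are placeholders
    uri = "ws://localhost:8000/notebook/ws?session_id=demo"
    async with websockets.connect(uri) as ws:
        # Every message uses the messageType/data envelope expected by the server
        await ws.send(json.dumps({
            "messageType": "main_input",
            "data": {"input": "Write unit tests for the selected file"},
        }))
        while True:
            message = json.loads(await ws.recv())
            if message.get("messageType") == "state_update":
                print(message["data"]["state"])


asyncio.run(drive_notebook_session())
```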