diff options
Diffstat (limited to 'continuedev')
18 files changed, 301 insertions, 238 deletions
| diff --git a/continuedev/src/continuedev/core/agent.py b/continuedev/src/continuedev/core/agent.py index 329e3d4c..7f7466a2 100644 --- a/continuedev/src/continuedev/core/agent.py +++ b/continuedev/src/continuedev/core/agent.py @@ -13,7 +13,6 @@ from .sdk import ContinueSDK  class Agent(ContinueBaseModel): -    llm: LLM      policy: Policy      ide: AbstractIdeProtocolServer      history: History = History.from_empty() @@ -31,27 +30,24 @@ class Agent(ContinueBaseModel):      def get_full_state(self) -> FullState:          return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue) -    def on_update(self, callback: Callable[["FullState"], None]): +    def on_update(self, callback: Coroutine["FullState", None, None]):          """Subscribe to changes to state"""          self._on_update_callbacks.append(callback) -    def update_subscribers(self): +    async def update_subscribers(self):          full_state = self.get_full_state()          for callback in self._on_update_callbacks: -            callback(full_state) - -    def __get_step_params(self, step: "Step"): -        return ContinueSDK(agent=self, llm=self.llm.with_system_message(step.system_message)) +            await callback(full_state)      def give_user_input(self, input: str, index: int): -        self._user_input_queue.post(index, input) +        self._user_input_queue.post(str(index), input)      async def wait_for_user_input(self) -> str:          self._active = False -        self.update_subscribers() -        user_input = await self._user_input_queue.get(self.history.current_index) +        await self.update_subscribers() +        user_input = await self._user_input_queue.get(str(self.history.current_index))          self._active = True -        self.update_subscribers() +        await self.update_subscribers()          return user_input      _manual_edits_buffer: List[FileEditWithFullContents] = [] @@ -62,9 +58,9 @@ class Agent(ContinueBaseModel):                  current_step = self.history.get_current().step                  self.history.step_back()                  if issubclass(current_step.__class__, ReversibleStep): -                    await current_step.reverse(self.__get_step_params(current_step)) +                    await current_step.reverse(ContinueSDK(self)) -                self.update_subscribers() +                await self.update_subscribers()          except Exception as e:              print(e) @@ -94,17 +90,17 @@ class Agent(ContinueBaseModel):          # Run step          self._step_depth += 1 -        observation = await step(self.__get_step_params(step)) +        observation = await step(ContinueSDK(self))          self._step_depth -= 1          # Add observation to history          self.history.get_current().observation = observation          # Update its description -        step._set_description(await step.describe(self.llm)) +        step._set_description(await step.describe(ContinueSDK(self)))          # Call all subscribed callbacks -        self.update_subscribers() +        await self.update_subscribers()          return observation @@ -138,7 +134,7 @@ class Agent(ContinueBaseModel):          # Doing this so active can make it to the frontend after steps are done. But want better state syncing tools          for callback in self._on_update_callbacks: -            callback(None) +            await callback(None)      async def run_from_observation(self, observation: Observation):          next_step = self.policy.next(self.history) @@ -158,7 +154,7 @@ class Agent(ContinueBaseModel):      async def accept_user_input(self, user_input: str):          self._main_user_input_queue.append(user_input) -        self.update_subscribers() +        await self.update_subscribers()          if len(self._main_user_input_queue) > 1:              return @@ -167,7 +163,7 @@ class Agent(ContinueBaseModel):          # Just run the step that takes user input, and          # then up to the policy to decide how to deal with it.          self._main_user_input_queue.pop(0) -        self.update_subscribers() +        await self.update_subscribers()          await self.run_from_step(UserInputStep(user_input=user_input))          while len(self._main_user_input_queue) > 0: diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py index 6267ed60..edd3297c 100644 --- a/continuedev/src/continuedev/core/env.py +++ b/continuedev/src/continuedev/core/env.py @@ -8,6 +8,10 @@ def get_env_var(var_name: str):  def save_env_var(var_name: str, var_value: str): +    if not os.path.exists('.env'): +        with open('.env', 'w') as f: +            f.write(f'{var_name}="{var_value}"\n') +        return      with open('.env', 'r') as f:          lines = f.readlines()      with open('.env', 'w') as f: diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 51fcd299..6be5139b 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -72,7 +72,7 @@ class ContinueSDK:      pass -class SequentialStep: +class Models:      pass @@ -94,7 +94,7 @@ class Step(ContinueBaseModel):      class Config:          copy_on_model_validation = False -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          if self._description is not None:              return self._description          return "Running step: " + self.name @@ -135,6 +135,16 @@ class Step(ContinueBaseModel):          return SequentialStep(steps=steps) +class SequentialStep(Step): +    steps: list[Step] +    hide: bool = True + +    async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: +        for step in self.steps: +            observation = await sdk.run_step(step) +        return observation + +  class ValidatorObservation(Observation):      passed: bool      observation: Observation diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 07101576..c0ba0f4f 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,20 +1,16 @@  from typing import List, Tuple, Type - -from ..models.main import ContinueBaseModel -  from ..libs.steps.ty import CreatePipelineStep  from .main import Step, Validator, History, Policy  from .observation import Observation, TracebackObservation, UserInputObservation  from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep  from ..libs.steps.nate import WritePytestsStep, CreateTableStep -from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma +# from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma  from ..libs.steps.continue_step import ContinueStepStep  class DemoPolicy(Policy):      ran_code_last: bool = False -    cmd: str      def next(self, history: History) -> Step:          observation = history.last_observation() @@ -26,18 +22,15 @@ class DemoPolicy(Policy):                  return CreatePipelineStep()              elif "/table" in observation.user_input:                  return CreateTableStep(sql_str=" ".join(observation.user_input.split(" ")[1:])) -            elif "/ask" in observation.user_input: -                return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) -            elif "/edit" in observation.user_input: -                return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) +            # elif "/ask" in observation.user_input: +            #     return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) +            # elif "/edit" in observation.user_input: +            #     return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))              elif "/step" in observation.user_input:                  return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:]))              return StarCoderEditHighlightedCodeStep(user_input=observation.user_input)          state = history.get_current() -        if state is None or not self.ran_code_last: -            self.ran_code_last = True -            return RunCodeStep(cmd=self.cmd)          if observation is not None and isinstance(observation, TracebackObservation):              self.ran_code_last = False diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 750b335d..6ae0be04 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -4,6 +4,7 @@ from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDir  from ..models.filesystem import RangeInFile  from ..libs.llm import LLM  from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI +from ..libs.llm.openai import OpenAI  from .observation import Observation  from ..server.ide_protocol import AbstractIdeProtocolServer  from .main import History, Step @@ -29,20 +30,20 @@ class Models:              'HUGGING_FACE_TOKEN', 'Please enter your Hugging Face token')          return HuggingFaceInferenceAPI(api_key=api_key) +    async def gpt35(self): +        api_key = await self.sdk.get_user_secret( +            'OPENAI_API_KEY', 'Please enter your OpenAI API key') +        return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo") +  class ContinueSDK:      """The SDK provided as parameters to a step""" -    llm: LLM      ide: AbstractIdeProtocolServer      steps: ContinueSDKSteps      models: Models      __agent: Agent -    def __init__(self, agent: Agent, llm: Union[LLM, None] = None): -        if llm is None: -            self.llm = agent.llm -        else: -            self.llm = llm +    def __init__(self, agent: Agent):          self.ide = agent.ide          self.__agent = agent          self.steps = ContinueSDKSteps(self) diff --git a/continuedev/src/continuedev/libs/steps/chroma.py b/continuedev/src/continuedev/libs/steps/chroma.py index f13a2bab..39424c5c 100644 --- a/continuedev/src/continuedev/libs/steps/chroma.py +++ b/continuedev/src/continuedev/libs/steps/chroma.py @@ -40,7 +40,7 @@ class AnswerQuestionChroma(Step):              Here is the answer:""") -        answer = sdk.llm.complete(prompt) +        answer = (await sdk.models.gpt35()).complete(prompt)          print(answer)          self._answer = answer diff --git a/continuedev/src/continuedev/libs/steps/core/core.py b/continuedev/src/continuedev/libs/steps/core/core.py index 0338d635..14b3cb80 100644 --- a/continuedev/src/continuedev/libs/steps/core/core.py +++ b/continuedev/src/continuedev/libs/steps/core/core.py @@ -4,27 +4,18 @@ from textwrap import dedent  from typing import Coroutine, List, Union  from ...llm.prompt_utils import MarkdownStyleEncoderDecoder -from ...util.traceback_parsers import parse_python_traceback -  from ....models.filesystem_edit import EditDiff, FileEditWithFullContents, FileSystemEdit  from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ...llm import LLM  from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation -from ....core.main import Step +from ....core.main import Step, SequentialStep  class ContinueSDK:      pass -class SequentialStep(Step): -    steps: list[Step] -    hide: bool = True - -    async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        for step in self.steps: -            observation = await sdk.run_step(step) -        return observation +class Models: +    pass  class ReversibleStep(Step): @@ -52,7 +43,7 @@ def ShellCommandsStep(Step):      cwd: str | None = None      name: str = "Run Shell Commands" -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "\n".join(self.cmds)      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -81,13 +72,13 @@ class EditCodeStep(Step):      _prompt: Union[str, None] = None      _completion: Union[str, None] = None -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          if self._edit_diffs is None:              return "Editing files: " + ", ".join(map(lambda rif: rif.filepath, self.range_in_files))          elif len(self._edit_diffs) == 0:              return "No edits made"          else: -            return llm.complete(dedent(f"""{self._prompt}{self._completion} +            return (await models.gpt35()).complete(dedent(f"""{self._prompt}{self._completion}                  Maximally concise summary of changes in bullet points (can use markdown):              """)) @@ -102,7 +93,7 @@ class EditCodeStep(Step):          code_string = enc_dec.encode()          prompt = self.prompt.format(code=code_string) -        completion = sdk.llm.complete(prompt) +        completion = (await sdk.models.gpt35()).complete(prompt)          # Temporarily doing this to generate description.          self._prompt = prompt @@ -127,7 +118,7 @@ class EditFileStep(Step):      prompt: str      hide: bool = True -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Editing file: " + self.filepath      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -145,7 +136,7 @@ class ManualEditStep(ReversibleStep):      hide: bool = True -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Manual edit step"          # TODO - only handling FileEdit here, but need all other types of FileSystemEdits          # Also requires the merge_file_edit function @@ -181,7 +172,7 @@ class UserInputStep(Step):      name: str = "User Input"      hide: bool = True -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return self.user_input      async def run(self, sdk: ContinueSDK) -> Coroutine[UserInputObservation, None, None]: @@ -194,7 +185,7 @@ class WaitForUserInputStep(Step):      _description: Union[str, None] = None -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return self.prompt      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -207,7 +198,7 @@ class WaitForUserConfirmationStep(Step):      prompt: str      name: str = "Waiting for user confirmation" -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return self.prompt      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: diff --git a/continuedev/src/continuedev/libs/steps/draft/dlt.py b/continuedev/src/continuedev/libs/steps/draft/dlt.py index 5ba5692a..460aa0cc 100644 --- a/continuedev/src/continuedev/libs/steps/draft/dlt.py +++ b/continuedev/src/continuedev/libs/steps/draft/dlt.py @@ -10,7 +10,7 @@ class SetupPipelineStep(Step):      api_description: str  # e.g. "I want to load data from the weatherapi.com API"      async def run(self, sdk: ContinueSDK): -        source_name = sdk.llm.complete( +        source_name = (await sdk.models.gpt35()).complete(              f"Write a snake_case name for the data source described by {self.api_description}: ").strip()          filename = f'{source_name}.py' diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py index c8a85800..70c0d4b8 100644 --- a/continuedev/src/continuedev/libs/steps/main.py +++ b/continuedev/src/continuedev/libs/steps/main.py @@ -11,7 +11,7 @@ from ...core.observation import Observation, TextObservation, TracebackObservati  from ..llm.prompt_utils import MarkdownStyleEncoderDecoder  from textwrap import dedent  from ...core.main import Step -from ...core.sdk import ContinueSDK +from ...core.sdk import ContinueSDK, Models  from ...core.observation import Observation  import subprocess  from .core.core import EditCodeStep @@ -20,7 +20,7 @@ from .core.core import EditCodeStep  class RunCodeStep(Step):      cmd: str -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return f"Ran command: `{self.cmd}`"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -59,7 +59,7 @@ class RunCommandStep(Step):      name: str = "Run command"      _description: str = None -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          if self._description is not None:              return self._description          return self.cmd @@ -125,7 +125,7 @@ class FasterEditHighlightedCodeStep(Step):          Here is the description of changes to make:  """) -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Editing highlighted code"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -154,7 +154,7 @@ class FasterEditHighlightedCodeStep(Step):          for rif in rif_with_contents:              rif_dict[rif.filepath] = rif.contents -        completion = sdk.llm.complete(prompt) +        completion = (await sdk.models.gpt35()).complete(prompt)          # Temporarily doing this to generate description.          self._prompt = prompt @@ -215,10 +215,10 @@ class FasterEditHighlightedCodeStep(Step):  class StarCoderEditHighlightedCodeStep(Step):      user_input: str -    hide = True +    hide = False      _prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>" -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Editing highlighted code"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -271,7 +271,7 @@ This is the user request:  This is the code after being changed to perfectly satisfy the user request:      """) -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Editing highlighted code"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -293,7 +293,7 @@ This is the code after being changed to perfectly satisfy the user request:  class FindCodeStep(Step):      prompt: str -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Finding code"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -307,7 +307,7 @@ class UserInputStep(Step):  class SolveTracebackStep(Step):      traceback: Traceback -    async def describe(self, llm: LLM) -> Coroutine[str, None, None]: +    async def describe(self, models: Models) -> Coroutine[str, None, None]:          return f"```\n{self.traceback.full_traceback}\n```"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: diff --git a/continuedev/src/continuedev/libs/steps/migration.py b/continuedev/src/continuedev/libs/steps/migration.py index f044a60f..7b70422d 100644 --- a/continuedev/src/continuedev/libs/steps/migration.py +++ b/continuedev/src/continuedev/libs/steps/migration.py @@ -15,7 +15,7 @@ class MigrationStep(Step):          recent_edits = await sdk.ide.get_recent_edits(self.edited_file)          recent_edits_string = "\n\n".join(              map(lambda x: x.to_string(), recent_edits)) -        description = await sdk.llm.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") +        description = await (await sdk.models.gpt35()).complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")          await sdk.run_step(RunCommandStep(cmd=f"cd libs && poetry run alembic revision --autogenerate -m {description}"))          migration_file = f"libs/alembic/versions/{?}.py"          contents = await sdk.ide.readFile(migration_file) diff --git a/continuedev/src/continuedev/libs/steps/nate.py b/continuedev/src/continuedev/libs/steps/nate.py index a0e728e5..2f84e9d7 100644 --- a/continuedev/src/continuedev/libs/steps/nate.py +++ b/continuedev/src/continuedev/libs/steps/nate.py @@ -45,7 +45,7 @@ Here are additional instructions:  Here is a complete set of pytest unit tests:          """) -        # tests = sdk.llm.complete(prompt) +        # tests = (await sdk.models.gpt35()).complete(prompt)          tests = '''  import pytest @@ -169,9 +169,9 @@ export class Order {    tracking_number: string;  }'''          time.sleep(2) -        # orm_entity = sdk.llm.complete( +        # orm_entity = (await sdk.models.gpt35()).complete(          #     f"{self.sql_str}\n\nWrite a TypeORM entity called {entity_name} for this table, importing as necessary:") -        # sdk.llm.complete("What is the name of the entity?") +        # (await sdk.models.gpt35()).complete("What is the name of the entity?")          await sdk.apply_filesystem_edit(AddFile(filepath=f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", content=orm_entity))          await sdk.ide.setFileOpen(f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", True) diff --git a/continuedev/src/continuedev/libs/steps/pytest.py b/continuedev/src/continuedev/libs/steps/pytest.py index b4e6dfd2..2e83ae2d 100644 --- a/continuedev/src/continuedev/libs/steps/pytest.py +++ b/continuedev/src/continuedev/libs/steps/pytest.py @@ -33,5 +33,5 @@ class WritePytestsStep(Step):              Here is a complete set of pytest unit tests:          """) -        tests = sdk.llm.complete(prompt) +        tests = (await sdk.models.gpt35()).complete(prompt)          await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/libs/steps/ty.py b/continuedev/src/continuedev/libs/steps/ty.py index 5ff03f04..9dde7c86 100644 --- a/continuedev/src/continuedev/libs/steps/ty.py +++ b/continuedev/src/continuedev/libs/steps/ty.py @@ -18,7 +18,7 @@ class SetupPipelineStep(Step):      api_description: str  # e.g. "I want to load data from the weatherapi.com API"      async def run(self, sdk: ContinueSDK): -        # source_name = sdk.llm.complete( +        # source_name = (await sdk.models.gpt35()).complete(          #     f"Write a snake_case name for the data source described by {self.api_description}: ").strip()          filename = f'/Users/natesesti/Desktop/continue/extension/examples/python/{source_name}.py' diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index dd1dc463..50296841 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -1,5 +1,6 @@  # This is a separate server from server/main.py  import asyncio +import json  import os  from typing import Any, Dict, List, Type, TypeVar, Union  import uuid @@ -90,31 +91,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      def __init__(self, session_manager: SessionManager):          self.session_manager = session_manager -    async def _send_json(self, data: Any): -        await self.websocket.send_json(data) +    async def _send_json(self, message_type: str, data: Any): +        await self.websocket.send_json({ +            "messageType": message_type, +            "data": data +        })      async def _receive_json(self, message_type: str) -> Any:          return await self.sub_queue.get(message_type)      async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: -        await self._send_json(data) +        await self._send_json(message_type, data)          resp = await self._receive_json(message_type)          return resp_model.parse_obj(resp) -    async def handle_json(self, data: Any): -        t = data["messageType"] -        if t == "openNotebook": +    async def handle_json(self, message_type: str, data: Any): +        if message_type == "openNotebook":              await self.openNotebook() -        elif t == "setFileOpen": +        elif message_type == "setFileOpen":              await self.setFileOpen(data["filepath"], data["open"]) -        elif t == "fileEdits": +        elif message_type == "fileEdits":              fileEdits = list(                  map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"]))              self.onFileEdits(fileEdits) -        elif t in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]: -            self.sub_queue.post(t, data) +        elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]: +            self.sub_queue.post(message_type, data)          else: -            raise ValueError("Unknown message type", t) +            raise ValueError("Unknown message type", message_type)      # ------------------------------- #      # Request actions in IDE, doesn't matter which Session @@ -123,24 +126,21 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      async def setFileOpen(self, filepath: str, open: bool = True):          # Agent needs access to this. -        await self.websocket.send_json({ -            "messageType": "setFileOpen", +        await self._send_json("setFileOpen", {              "filepath": filepath,              "open": open          })      async def openNotebook(self):          session_id = self.session_manager.new_session(self) -        await self._send_json({ -            "messageType": "openNotebook", +        await self._send_json("openNotebook", {              "sessionId": session_id          })      async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:          ids = [str(uuid.uuid4()) for _ in suggestions]          for i in range(len(suggestions)): -            self._send_json({ -                "messageType": "showSuggestion", +            self._send_json("showSuggestion", {                  "suggestion": suggestions[i],                  "suggestionId": ids[i]              }) @@ -210,8 +210,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):      async def saveFile(self, filepath: str):          """Save a file""" -        await self._send_json({ -            "messageType": "saveFile", +        await self._send_json("saveFile", {              "filepath": filepath          }) @@ -293,10 +292,17 @@ ideProtocolServer = IdeProtocolServer(session_manager)  async def websocket_endpoint(websocket: WebSocket):      await websocket.accept()      print("Accepted websocket connection from, ", websocket.client) -    await websocket.send_json({"messageType": "connected"}) +    await websocket.send_json({"messageType": "connected", "data": {}})      ideProtocolServer.websocket = websocket      while True: -        data = await websocket.receive_json() -        await ideProtocolServer.handle_json(data) +        message = await websocket.receive_text() +        message = json.loads(message) + +        if "messageType" not in message or "data" not in message: +            continue +        message_type = message["messageType"] +        data = message["data"] + +        await ideProtocolServer.handle_json(message_type, data)      await websocket.close() diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 11ad1d8f..e87d5fa9 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -32,7 +32,7 @@ args = parser.parse_args()  def run_server(): -    uvicorn.run(app, host="0.0.0.0", port=args.port, log_config="logging.ini") +    uvicorn.run(app, host="0.0.0.0", port=args.port)  if __name__ == "__main__": diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py index c5dcea31..edb61a45 100644 --- a/continuedev/src/continuedev/server/notebook.py +++ b/continuedev/src/continuedev/server/notebook.py @@ -1,18 +1,12 @@ -from fastapi import FastAPI, Depends, Header, WebSocket, APIRouter -from typing import Any, Dict, List, Union -from uuid import uuid4 +import json +from fastapi import Depends, Header, WebSocket, APIRouter +from typing import Any, Type, TypeVar, Union  from pydantic import BaseModel  from uvicorn.main import Server -from ..models.filesystem_edit import FileEditWithFullContents -from ..core.policy import DemoPolicy -from ..core.main import FullState, History, Step -from ..core.agent import Agent -from ..libs.steps.nate import ImplementAbstractMethodStep -from ..core.observation import Observation -from ..libs.llm.openai import OpenAI -from .ide_protocol import AbstractIdeProtocolServer -from ..core.env import get_env_var +from .session_manager import SessionManager, session_manager, Session +from .notebook_protocol import AbstractNotebookProtocolServer +from ..libs.util.queue import AsyncSubscriptionQueue  import asyncio  import nest_asyncio  nest_asyncio.apply() @@ -36,160 +30,99 @@ class AppStatus:  Server.handle_exit = AppStatus.handle_exit -class Session: -    session_id: str -    agent: Agent -    ws: Union[WebSocket, None] - -    def __init__(self, session_id: str, agent: Agent): -        self.session_id = session_id -        self.agent = agent -        self.ws = None - - -class DemoAgent(Agent): -    first_seen: bool = False -    cumulative_edit_string = "" - -    def handle_manual_edits(self, edits: List[FileEditWithFullContents]): -        for edit in edits: -            self.cumulative_edit_string += edit.fileEdit.replacement -            self._manual_edits_buffer.append(edit) -            # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. -            # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) -            # FOR DEMO PURPOSES -            if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement: -                self.cumulative_edit_string = "" -                asyncio.create_task(self.run_from_step( -                    ImplementAbstractMethodStep())) - - -class SessionManager: -    sessions: Dict[str, Session] = {} -    _event_loop: Union[asyncio.BaseEventLoop, None] = None - -    def get_session(self, session_id: str) -> Session: -        if session_id not in self.sessions: -            raise KeyError("Session ID not recognized") -        return self.sessions[session_id] - -    def new_session(self, ide: AbstractIdeProtocolServer) -> str: -        cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py" -        agent = DemoAgent(llm=OpenAI(api_key=get_env_var("OPENAI_API_KEY")), -                          policy=DemoPolicy(cmd=cmd), ide=ide) -        session_id = str(uuid4()) -        session = Session(session_id=session_id, agent=agent) -        self.sessions[session_id] = session +def session(x_continue_session_id: str = Header("anonymous")) -> Session: +    return session_manager.get_session(x_continue_session_id) -        def on_update(state: FullState): -            session_manager.send_ws_data(session_id, { -                "messageType": "state", -                "state": agent.get_full_state().dict() -            }) -        agent.on_update(on_update) -        asyncio.create_task(agent.run_policy()) -        return session_id +def websocket_session(session_id: str) -> Session: +    return session_manager.get_session(session_id) -    def remove_session(self, session_id: str): -        del self.sessions[session_id] -    def register_websocket(self, session_id: str, ws: WebSocket): -        self.sessions[session_id].ws = ws -        print("Registered websocket for session", session_id) +T = TypeVar("T", bound=BaseModel) -    def send_ws_data(self, session_id: str, data: Any): -        if self.sessions[session_id].ws is None: -            print(f"Session {session_id} has no websocket") -            return +# You should probably abstract away the websocket stuff into a separate class -        async def a(): -            await self.sessions[session_id].ws.send_json(data) -        # Run coroutine in background -        if self._event_loop is None or self._event_loop.is_closed(): -            self._event_loop = asyncio.new_event_loop() -            self._event_loop.run_until_complete(a()) -            self._event_loop.close() -        else: -            self._event_loop.run_until_complete(a()) -            self._event_loop.close() +class NotebookProtocolServer(AbstractNotebookProtocolServer): +    websocket: WebSocket +    session: Session +    sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() +    def __init__(self, session: Session): +        self.session = session -session_manager = SessionManager() +    async def _send_json(self, data: Any): +        await self.websocket.send_json(data) +    async def _receive_json(self, message_type: str) -> Any: +        return await self.sub_queue.get(message_type) -def session(x_continue_session_id: str = Header("anonymous")) -> Session: -    return session_manager.get_session(x_continue_session_id) +    async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: +        await self._send_json(data) +        resp = await self._receive_json(message_type) +        return resp_model.parse_obj(resp) +    def handle_json(self, message_type: str, data: Any): +        try: +            if message_type == "main_input": +                self.on_main_input(data["input"]) +            elif message_type == "step_user_input": +                self.on_step_user_input(data["input"], data["index"]) +            elif message_type == "refinement_input": +                self.on_refinement_input(data["input"], data["index"]) +            elif message_type == "reverse_to_index": +                self.on_reverse_to_index(data["index"]) +        except Exception as e: +            print(e) -def websocket_session(session_id: str) -> Session: -    return session_manager.get_session(session_id) +    async def send_state_update(self): +        state = self.session.agent.get_full_state().dict() +        await self._send_json({ +            "messageType": "state_update", +            "state": state +        }) +    def on_main_input(self, input: str): +        # Do something with user input +        asyncio.create_task(self.session.agent.accept_user_input(input)) -class StartSessionBody(BaseModel): -    config_file_path: Union[str, None] +    def on_reverse_to_index(self, index: int): +        # Reverse the history to the given index +        asyncio.create_task(self.session.agent.reverse_to_index(index)) +    def on_step_user_input(self, input: str, index: int): +        asyncio.create_task( +            self.session.agent.give_user_input(input, index)) -class StartSessionResp(BaseModel): -    session_id: str +    def on_refinement_input(self, input: str, index: int): +        asyncio.create_task( +            self.session.agent.accept_refinement_input(input, index))  @router.websocket("/ws")  async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):      await websocket.accept() +    print("Session started")      session_manager.register_websocket(session.session_id, websocket) -    data = await websocket.receive_text() +    protocol = NotebookProtocolServer(session) +    protocol.websocket = websocket +      # Update any history that may have happened before connection -    await websocket.send_json({ -        "messageType": "state", -        "state": session_manager.get_session(session.session_id).agent.get_full_state().dict() -    }) -    print("Session started", data) +    await protocol.send_state_update() +      while AppStatus.should_exit is False: -        data = await websocket.receive_json() -        print("Received data", data) +        message = await websocket.receive_json() +        print("Received message", message) +        if type(message) is str: +            message = json.loads(message) -        if "messageType" not in data: +        if "messageType" not in message or "data" not in message:              continue -        messageType = data["messageType"] +        message_type = message["messageType"] +        data = message["data"] -        try: -            if messageType == "main_input": -                # Do something with user input -                asyncio.create_task( -                    session.agent.accept_user_input(data["value"])) -            elif messageType == "step_user_input": -                asyncio.create_task( -                    session.agent.give_user_input(data["value"], data["index"])) -            elif messageType == "refinement_input": -                asyncio.create_task( -                    session.agent.accept_refinement_input(data["value"], data["index"])) -            elif messageType == "reverse": -                # Reverse the history to the given index -                asyncio.create_task( -                    session.agent.reverse_to_index(data["index"])) -        except Exception as e: -            print(e) +        protocol.handle_json(message_type, data)      print("Closing websocket")      await websocket.close() - - -@router.post("/run") -def request_run(step: Step, session=Depends(session)): -    """Tell an agent to take a specific action.""" -    asyncio.create_task(session.agent.run_from_step(step)) -    return "Success" - - -@router.get("/history") -def get_history(session=Depends(session)) -> History: -    return session.agent.history - - -@router.post("/observation") -def post_observation(observation: Observation, session=Depends(session)): -    asyncio.create_task(session.agent.run_from_observation(observation)) -    return "Success" diff --git a/continuedev/src/continuedev/server/notebook_protocol.py b/continuedev/src/continuedev/server/notebook_protocol.py new file mode 100644 index 00000000..c2be82e0 --- /dev/null +++ b/continuedev/src/continuedev/server/notebook_protocol.py @@ -0,0 +1,28 @@ +from typing import Any +from abc import ABC, abstractmethod + + +class AbstractNotebookProtocolServer(ABC): +    @abstractmethod +    async def handle_json(self, data: Any): +        """Handle a json message""" + +    @abstractmethod +    def on_main_input(self, input: str): +        """Called when the user inputs something""" + +    @abstractmethod +    def on_reverse_to_index(self, index: int): +        """Called when the user requests reverse to a previous index""" + +    @abstractmethod +    def on_refinement_input(self, input: str, index: int): +        """Called when the user inputs a refinement""" + +    @abstractmethod +    def on_step_user_input(self, input: str, index: int): +        """Called when the user inputs a step""" + +    @abstractmethod +    async def send_state_update(self, state: dict): +        """Send a state update to the client""" diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py new file mode 100644 index 00000000..b48c21b7 --- /dev/null +++ b/continuedev/src/continuedev/server/session_manager.py @@ -0,0 +1,101 @@ +from fastapi import WebSocket +from typing import Any, Dict, List, Union +from uuid import uuid4 + +from ..models.filesystem_edit import FileEditWithFullContents +from ..core.policy import DemoPolicy +from ..core.main import FullState +from ..core.agent import Agent +from ..libs.steps.nate import ImplementAbstractMethodStep +from .ide_protocol import AbstractIdeProtocolServer +import asyncio +import nest_asyncio +nest_asyncio.apply() + + +class Session: +    session_id: str +    agent: Agent +    ws: Union[WebSocket, None] + +    def __init__(self, session_id: str, agent: Agent): +        self.session_id = session_id +        self.agent = agent +        self.ws = None + + +class DemoAgent(Agent): +    first_seen: bool = False +    cumulative_edit_string = "" + +    def handle_manual_edits(self, edits: List[FileEditWithFullContents]): +        for edit in edits: +            self.cumulative_edit_string += edit.fileEdit.replacement +            self._manual_edits_buffer.append(edit) +            # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. +            # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) +            # FOR DEMO PURPOSES +            if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement: +                self.cumulative_edit_string = "" +                asyncio.create_task(self.run_from_step( +                    ImplementAbstractMethodStep())) + + +class SessionManager: +    sessions: Dict[str, Session] = {} +    _event_loop: Union[asyncio.BaseEventLoop, None] = None + +    def get_session(self, session_id: str) -> Session: +        if session_id not in self.sessions: +            raise KeyError("Session ID not recognized") +        return self.sessions[session_id] + +    def new_session(self, ide: AbstractIdeProtocolServer) -> str: +        agent = DemoAgent(policy=DemoPolicy(), ide=ide) +        session_id = str(uuid4()) +        session = Session(session_id=session_id, agent=agent) +        self.sessions[session_id] = session + +        async def on_update(state: FullState): +            await session_manager.send_ws_data(session_id, "state_update", { +                "state": agent.get_full_state().dict() +            }) + +        agent.on_update(on_update) +        asyncio.create_task(agent.run_policy()) +        return session_id + +    def remove_session(self, session_id: str): +        del self.sessions[session_id] + +    def register_websocket(self, session_id: str, ws: WebSocket): +        self.sessions[session_id].ws = ws +        print("Registered websocket for session", session_id) + +    async def send_ws_data(self, session_id: str, message_type: str, data: Any): +        if self.sessions[session_id].ws is None: +            print(f"Session {session_id} has no websocket") +            return + +        async def a(): +            await self.sessions[session_id].ws.send_json({ +                "messageType": message_type, +                "data": data +            }) + +        # Run coroutine in background +        await self.sessions[session_id].ws.send_json({ +            "messageType": message_type, +            "data": data +        }) +        return +        if self._event_loop is None or self._event_loop.is_closed(): +            self._event_loop = asyncio.new_event_loop() +            self._event_loop.run_until_complete(a()) +            self._event_loop.close() +        else: +            self._event_loop.run_until_complete(a()) +            self._event_loop.close() + + +session_manager = SessionManager() | 
