summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/agent.py34
-rw-r--r--continuedev/src/continuedev/core/env.py4
-rw-r--r--continuedev/src/continuedev/core/main.py14
-rw-r--r--continuedev/src/continuedev/core/policy.py17
-rw-r--r--continuedev/src/continuedev/core/sdk.py13
-rw-r--r--continuedev/src/continuedev/libs/steps/chroma.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/core/core.py33
-rw-r--r--continuedev/src/continuedev/libs/steps/draft/dlt.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/main.py20
-rw-r--r--continuedev/src/continuedev/libs/steps/migration.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/nate.py6
-rw-r--r--continuedev/src/continuedev/libs/steps/pytest.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/ty.py2
-rw-r--r--continuedev/src/continuedev/server/ide.py50
-rw-r--r--continuedev/src/continuedev/server/main.py2
-rw-r--r--continuedev/src/continuedev/server/notebook.py207
-rw-r--r--continuedev/src/continuedev/server/notebook_protocol.py28
-rw-r--r--continuedev/src/continuedev/server/session_manager.py101
18 files changed, 301 insertions, 238 deletions
diff --git a/continuedev/src/continuedev/core/agent.py b/continuedev/src/continuedev/core/agent.py
index 329e3d4c..7f7466a2 100644
--- a/continuedev/src/continuedev/core/agent.py
+++ b/continuedev/src/continuedev/core/agent.py
@@ -13,7 +13,6 @@ from .sdk import ContinueSDK
class Agent(ContinueBaseModel):
- llm: LLM
policy: Policy
ide: AbstractIdeProtocolServer
history: History = History.from_empty()
@@ -31,27 +30,24 @@ class Agent(ContinueBaseModel):
def get_full_state(self) -> FullState:
return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue)
- def on_update(self, callback: Callable[["FullState"], None]):
+ def on_update(self, callback: Coroutine["FullState", None, None]):
"""Subscribe to changes to state"""
self._on_update_callbacks.append(callback)
- def update_subscribers(self):
+ async def update_subscribers(self):
full_state = self.get_full_state()
for callback in self._on_update_callbacks:
- callback(full_state)
-
- def __get_step_params(self, step: "Step"):
- return ContinueSDK(agent=self, llm=self.llm.with_system_message(step.system_message))
+ await callback(full_state)
def give_user_input(self, input: str, index: int):
- self._user_input_queue.post(index, input)
+ self._user_input_queue.post(str(index), input)
async def wait_for_user_input(self) -> str:
self._active = False
- self.update_subscribers()
- user_input = await self._user_input_queue.get(self.history.current_index)
+ await self.update_subscribers()
+ user_input = await self._user_input_queue.get(str(self.history.current_index))
self._active = True
- self.update_subscribers()
+ await self.update_subscribers()
return user_input
_manual_edits_buffer: List[FileEditWithFullContents] = []
@@ -62,9 +58,9 @@ class Agent(ContinueBaseModel):
current_step = self.history.get_current().step
self.history.step_back()
if issubclass(current_step.__class__, ReversibleStep):
- await current_step.reverse(self.__get_step_params(current_step))
+ await current_step.reverse(ContinueSDK(self))
- self.update_subscribers()
+ await self.update_subscribers()
except Exception as e:
print(e)
@@ -94,17 +90,17 @@ class Agent(ContinueBaseModel):
# Run step
self._step_depth += 1
- observation = await step(self.__get_step_params(step))
+ observation = await step(ContinueSDK(self))
self._step_depth -= 1
# Add observation to history
self.history.get_current().observation = observation
# Update its description
- step._set_description(await step.describe(self.llm))
+ step._set_description(await step.describe(ContinueSDK(self)))
# Call all subscribed callbacks
- self.update_subscribers()
+ await self.update_subscribers()
return observation
@@ -138,7 +134,7 @@ class Agent(ContinueBaseModel):
# Doing this so active can make it to the frontend after steps are done. But want better state syncing tools
for callback in self._on_update_callbacks:
- callback(None)
+ await callback(None)
async def run_from_observation(self, observation: Observation):
next_step = self.policy.next(self.history)
@@ -158,7 +154,7 @@ class Agent(ContinueBaseModel):
async def accept_user_input(self, user_input: str):
self._main_user_input_queue.append(user_input)
- self.update_subscribers()
+ await self.update_subscribers()
if len(self._main_user_input_queue) > 1:
return
@@ -167,7 +163,7 @@ class Agent(ContinueBaseModel):
# Just run the step that takes user input, and
# then up to the policy to decide how to deal with it.
self._main_user_input_queue.pop(0)
- self.update_subscribers()
+ await self.update_subscribers()
await self.run_from_step(UserInputStep(user_input=user_input))
while len(self._main_user_input_queue) > 0:
diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py
index 6267ed60..edd3297c 100644
--- a/continuedev/src/continuedev/core/env.py
+++ b/continuedev/src/continuedev/core/env.py
@@ -8,6 +8,10 @@ def get_env_var(var_name: str):
def save_env_var(var_name: str, var_value: str):
+ if not os.path.exists('.env'):
+ with open('.env', 'w') as f:
+ f.write(f'{var_name}="{var_value}"\n')
+ return
with open('.env', 'r') as f:
lines = f.readlines()
with open('.env', 'w') as f:
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 51fcd299..6be5139b 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -72,7 +72,7 @@ class ContinueSDK:
pass
-class SequentialStep:
+class Models:
pass
@@ -94,7 +94,7 @@ class Step(ContinueBaseModel):
class Config:
copy_on_model_validation = False
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
if self._description is not None:
return self._description
return "Running step: " + self.name
@@ -135,6 +135,16 @@ class Step(ContinueBaseModel):
return SequentialStep(steps=steps)
+class SequentialStep(Step):
+ steps: list[Step]
+ hide: bool = True
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ for step in self.steps:
+ observation = await sdk.run_step(step)
+ return observation
+
+
class ValidatorObservation(Observation):
passed: bool
observation: Observation
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index 07101576..c0ba0f4f 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -1,20 +1,16 @@
from typing import List, Tuple, Type
-
-from ..models.main import ContinueBaseModel
-
from ..libs.steps.ty import CreatePipelineStep
from .main import Step, Validator, History, Policy
from .observation import Observation, TracebackObservation, UserInputObservation
from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep
from ..libs.steps.nate import WritePytestsStep, CreateTableStep
-from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma
+# from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma
from ..libs.steps.continue_step import ContinueStepStep
class DemoPolicy(Policy):
ran_code_last: bool = False
- cmd: str
def next(self, history: History) -> Step:
observation = history.last_observation()
@@ -26,18 +22,15 @@ class DemoPolicy(Policy):
return CreatePipelineStep()
elif "/table" in observation.user_input:
return CreateTableStep(sql_str=" ".join(observation.user_input.split(" ")[1:]))
- elif "/ask" in observation.user_input:
- return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:]))
- elif "/edit" in observation.user_input:
- return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
+ # elif "/ask" in observation.user_input:
+ # return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:]))
+ # elif "/edit" in observation.user_input:
+ # return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
elif "/step" in observation.user_input:
return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:]))
return StarCoderEditHighlightedCodeStep(user_input=observation.user_input)
state = history.get_current()
- if state is None or not self.ran_code_last:
- self.ran_code_last = True
- return RunCodeStep(cmd=self.cmd)
if observation is not None and isinstance(observation, TracebackObservation):
self.ran_code_last = False
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 750b335d..6ae0be04 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -4,6 +4,7 @@ from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDir
from ..models.filesystem import RangeInFile
from ..libs.llm import LLM
from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
+from ..libs.llm.openai import OpenAI
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import History, Step
@@ -29,20 +30,20 @@ class Models:
'HUGGING_FACE_TOKEN', 'Please enter your Hugging Face token')
return HuggingFaceInferenceAPI(api_key=api_key)
+ async def gpt35(self):
+ api_key = await self.sdk.get_user_secret(
+ 'OPENAI_API_KEY', 'Please enter your OpenAI API key')
+ return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo")
+
class ContinueSDK:
"""The SDK provided as parameters to a step"""
- llm: LLM
ide: AbstractIdeProtocolServer
steps: ContinueSDKSteps
models: Models
__agent: Agent
- def __init__(self, agent: Agent, llm: Union[LLM, None] = None):
- if llm is None:
- self.llm = agent.llm
- else:
- self.llm = llm
+ def __init__(self, agent: Agent):
self.ide = agent.ide
self.__agent = agent
self.steps = ContinueSDKSteps(self)
diff --git a/continuedev/src/continuedev/libs/steps/chroma.py b/continuedev/src/continuedev/libs/steps/chroma.py
index f13a2bab..39424c5c 100644
--- a/continuedev/src/continuedev/libs/steps/chroma.py
+++ b/continuedev/src/continuedev/libs/steps/chroma.py
@@ -40,7 +40,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = sdk.llm.complete(prompt)
+ answer = (await sdk.models.gpt35()).complete(prompt)
print(answer)
self._answer = answer
diff --git a/continuedev/src/continuedev/libs/steps/core/core.py b/continuedev/src/continuedev/libs/steps/core/core.py
index 0338d635..14b3cb80 100644
--- a/continuedev/src/continuedev/libs/steps/core/core.py
+++ b/continuedev/src/continuedev/libs/steps/core/core.py
@@ -4,27 +4,18 @@ from textwrap import dedent
from typing import Coroutine, List, Union
from ...llm.prompt_utils import MarkdownStyleEncoderDecoder
-from ...util.traceback_parsers import parse_python_traceback
-
from ....models.filesystem_edit import EditDiff, FileEditWithFullContents, FileSystemEdit
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
-from ...llm import LLM
from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
-from ....core.main import Step
+from ....core.main import Step, SequentialStep
class ContinueSDK:
pass
-class SequentialStep(Step):
- steps: list[Step]
- hide: bool = True
-
- async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- for step in self.steps:
- observation = await sdk.run_step(step)
- return observation
+class Models:
+ pass
class ReversibleStep(Step):
@@ -52,7 +43,7 @@ def ShellCommandsStep(Step):
cwd: str | None = None
name: str = "Run Shell Commands"
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "\n".join(self.cmds)
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -81,13 +72,13 @@ class EditCodeStep(Step):
_prompt: Union[str, None] = None
_completion: Union[str, None] = None
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
if self._edit_diffs is None:
return "Editing files: " + ", ".join(map(lambda rif: rif.filepath, self.range_in_files))
elif len(self._edit_diffs) == 0:
return "No edits made"
else:
- return llm.complete(dedent(f"""{self._prompt}{self._completion}
+ return (await models.gpt35()).complete(dedent(f"""{self._prompt}{self._completion}
Maximally concise summary of changes in bullet points (can use markdown):
"""))
@@ -102,7 +93,7 @@ class EditCodeStep(Step):
code_string = enc_dec.encode()
prompt = self.prompt.format(code=code_string)
- completion = sdk.llm.complete(prompt)
+ completion = (await sdk.models.gpt35()).complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -127,7 +118,7 @@ class EditFileStep(Step):
prompt: str
hide: bool = True
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing file: " + self.filepath
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -145,7 +136,7 @@ class ManualEditStep(ReversibleStep):
hide: bool = True
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Manual edit step"
# TODO - only handling FileEdit here, but need all other types of FileSystemEdits
# Also requires the merge_file_edit function
@@ -181,7 +172,7 @@ class UserInputStep(Step):
name: str = "User Input"
hide: bool = True
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return self.user_input
async def run(self, sdk: ContinueSDK) -> Coroutine[UserInputObservation, None, None]:
@@ -194,7 +185,7 @@ class WaitForUserInputStep(Step):
_description: Union[str, None] = None
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return self.prompt
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -207,7 +198,7 @@ class WaitForUserConfirmationStep(Step):
prompt: str
name: str = "Waiting for user confirmation"
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return self.prompt
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
diff --git a/continuedev/src/continuedev/libs/steps/draft/dlt.py b/continuedev/src/continuedev/libs/steps/draft/dlt.py
index 5ba5692a..460aa0cc 100644
--- a/continuedev/src/continuedev/libs/steps/draft/dlt.py
+++ b/continuedev/src/continuedev/libs/steps/draft/dlt.py
@@ -10,7 +10,7 @@ class SetupPipelineStep(Step):
api_description: str # e.g. "I want to load data from the weatherapi.com API"
async def run(self, sdk: ContinueSDK):
- source_name = sdk.llm.complete(
+ source_name = (await sdk.models.gpt35()).complete(
f"Write a snake_case name for the data source described by {self.api_description}: ").strip()
filename = f'{source_name}.py'
diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py
index c8a85800..70c0d4b8 100644
--- a/continuedev/src/continuedev/libs/steps/main.py
+++ b/continuedev/src/continuedev/libs/steps/main.py
@@ -11,7 +11,7 @@ from ...core.observation import Observation, TextObservation, TracebackObservati
from ..llm.prompt_utils import MarkdownStyleEncoderDecoder
from textwrap import dedent
from ...core.main import Step
-from ...core.sdk import ContinueSDK
+from ...core.sdk import ContinueSDK, Models
from ...core.observation import Observation
import subprocess
from .core.core import EditCodeStep
@@ -20,7 +20,7 @@ from .core.core import EditCodeStep
class RunCodeStep(Step):
cmd: str
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return f"Ran command: `{self.cmd}`"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -59,7 +59,7 @@ class RunCommandStep(Step):
name: str = "Run command"
_description: str = None
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
if self._description is not None:
return self._description
return self.cmd
@@ -125,7 +125,7 @@ class FasterEditHighlightedCodeStep(Step):
Here is the description of changes to make:
""")
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing highlighted code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -154,7 +154,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in rif_with_contents:
rif_dict[rif.filepath] = rif.contents
- completion = sdk.llm.complete(prompt)
+ completion = (await sdk.models.gpt35()).complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -215,10 +215,10 @@ class FasterEditHighlightedCodeStep(Step):
class StarCoderEditHighlightedCodeStep(Step):
user_input: str
- hide = True
+ hide = False
_prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>"
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing highlighted code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -271,7 +271,7 @@ This is the user request:
This is the code after being changed to perfectly satisfy the user request:
""")
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing highlighted code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -293,7 +293,7 @@ This is the code after being changed to perfectly satisfy the user request:
class FindCodeStep(Step):
prompt: str
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Finding code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -307,7 +307,7 @@ class UserInputStep(Step):
class SolveTracebackStep(Step):
traceback: Traceback
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return f"```\n{self.traceback.full_traceback}\n```"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
diff --git a/continuedev/src/continuedev/libs/steps/migration.py b/continuedev/src/continuedev/libs/steps/migration.py
index f044a60f..7b70422d 100644
--- a/continuedev/src/continuedev/libs/steps/migration.py
+++ b/continuedev/src/continuedev/libs/steps/migration.py
@@ -15,7 +15,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = await sdk.llm.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await (await sdk.models.gpt35()).complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run_step(RunCommandStep(cmd=f"cd libs && poetry run alembic revision --autogenerate -m {description}"))
migration_file = f"libs/alembic/versions/{?}.py"
contents = await sdk.ide.readFile(migration_file)
diff --git a/continuedev/src/continuedev/libs/steps/nate.py b/continuedev/src/continuedev/libs/steps/nate.py
index a0e728e5..2f84e9d7 100644
--- a/continuedev/src/continuedev/libs/steps/nate.py
+++ b/continuedev/src/continuedev/libs/steps/nate.py
@@ -45,7 +45,7 @@ Here are additional instructions:
Here is a complete set of pytest unit tests:
""")
- # tests = sdk.llm.complete(prompt)
+ # tests = (await sdk.models.gpt35()).complete(prompt)
tests = '''
import pytest
@@ -169,9 +169,9 @@ export class Order {
tracking_number: string;
}'''
time.sleep(2)
- # orm_entity = sdk.llm.complete(
+ # orm_entity = (await sdk.models.gpt35()).complete(
# f"{self.sql_str}\n\nWrite a TypeORM entity called {entity_name} for this table, importing as necessary:")
- # sdk.llm.complete("What is the name of the entity?")
+ # (await sdk.models.gpt35()).complete("What is the name of the entity?")
await sdk.apply_filesystem_edit(AddFile(filepath=f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", content=orm_entity))
await sdk.ide.setFileOpen(f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", True)
diff --git a/continuedev/src/continuedev/libs/steps/pytest.py b/continuedev/src/continuedev/libs/steps/pytest.py
index b4e6dfd2..2e83ae2d 100644
--- a/continuedev/src/continuedev/libs/steps/pytest.py
+++ b/continuedev/src/continuedev/libs/steps/pytest.py
@@ -33,5 +33,5 @@ class WritePytestsStep(Step):
Here is a complete set of pytest unit tests:
""")
- tests = sdk.llm.complete(prompt)
+ tests = (await sdk.models.gpt35()).complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/libs/steps/ty.py b/continuedev/src/continuedev/libs/steps/ty.py
index 5ff03f04..9dde7c86 100644
--- a/continuedev/src/continuedev/libs/steps/ty.py
+++ b/continuedev/src/continuedev/libs/steps/ty.py
@@ -18,7 +18,7 @@ class SetupPipelineStep(Step):
api_description: str # e.g. "I want to load data from the weatherapi.com API"
async def run(self, sdk: ContinueSDK):
- # source_name = sdk.llm.complete(
+ # source_name = (await sdk.models.gpt35()).complete(
# f"Write a snake_case name for the data source described by {self.api_description}: ").strip()
filename = f'/Users/natesesti/Desktop/continue/extension/examples/python/{source_name}.py'
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index dd1dc463..50296841 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -1,5 +1,6 @@
# This is a separate server from server/main.py
import asyncio
+import json
import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
@@ -90,31 +91,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def __init__(self, session_manager: SessionManager):
self.session_manager = session_manager
- async def _send_json(self, data: Any):
- await self.websocket.send_json(data)
+ async def _send_json(self, message_type: str, data: Any):
+ await self.websocket.send_json({
+ "messageType": message_type,
+ "data": data
+ })
async def _receive_json(self, message_type: str) -> Any:
return await self.sub_queue.get(message_type)
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
- await self._send_json(data)
+ await self._send_json(message_type, data)
resp = await self._receive_json(message_type)
return resp_model.parse_obj(resp)
- async def handle_json(self, data: Any):
- t = data["messageType"]
- if t == "openNotebook":
+ async def handle_json(self, message_type: str, data: Any):
+ if message_type == "openNotebook":
await self.openNotebook()
- elif t == "setFileOpen":
+ elif message_type == "setFileOpen":
await self.setFileOpen(data["filepath"], data["open"])
- elif t == "fileEdits":
+ elif message_type == "fileEdits":
fileEdits = list(
map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"]))
self.onFileEdits(fileEdits)
- elif t in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]:
- self.sub_queue.post(t, data)
+ elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]:
+ self.sub_queue.post(message_type, data)
else:
- raise ValueError("Unknown message type", t)
+ raise ValueError("Unknown message type", message_type)
# ------------------------------- #
# Request actions in IDE, doesn't matter which Session
@@ -123,24 +126,21 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def setFileOpen(self, filepath: str, open: bool = True):
# Agent needs access to this.
- await self.websocket.send_json({
- "messageType": "setFileOpen",
+ await self._send_json("setFileOpen", {
"filepath": filepath,
"open": open
})
async def openNotebook(self):
session_id = self.session_manager.new_session(self)
- await self._send_json({
- "messageType": "openNotebook",
+ await self._send_json("openNotebook", {
"sessionId": session_id
})
async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
ids = [str(uuid.uuid4()) for _ in suggestions]
for i in range(len(suggestions)):
- self._send_json({
- "messageType": "showSuggestion",
+ self._send_json("showSuggestion", {
"suggestion": suggestions[i],
"suggestionId": ids[i]
})
@@ -210,8 +210,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def saveFile(self, filepath: str):
"""Save a file"""
- await self._send_json({
- "messageType": "saveFile",
+ await self._send_json("saveFile", {
"filepath": filepath
})
@@ -293,10 +292,17 @@ ideProtocolServer = IdeProtocolServer(session_manager)
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
- await websocket.send_json({"messageType": "connected"})
+ await websocket.send_json({"messageType": "connected", "data": {}})
ideProtocolServer.websocket = websocket
while True:
- data = await websocket.receive_json()
- await ideProtocolServer.handle_json(data)
+ message = await websocket.receive_text()
+ message = json.loads(message)
+
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+
+ await ideProtocolServer.handle_json(message_type, data)
await websocket.close()
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index 11ad1d8f..e87d5fa9 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -32,7 +32,7 @@ args = parser.parse_args()
def run_server():
- uvicorn.run(app, host="0.0.0.0", port=args.port, log_config="logging.ini")
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
if __name__ == "__main__":
diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py
index c5dcea31..edb61a45 100644
--- a/continuedev/src/continuedev/server/notebook.py
+++ b/continuedev/src/continuedev/server/notebook.py
@@ -1,18 +1,12 @@
-from fastapi import FastAPI, Depends, Header, WebSocket, APIRouter
-from typing import Any, Dict, List, Union
-from uuid import uuid4
+import json
+from fastapi import Depends, Header, WebSocket, APIRouter
+from typing import Any, Type, TypeVar, Union
from pydantic import BaseModel
from uvicorn.main import Server
-from ..models.filesystem_edit import FileEditWithFullContents
-from ..core.policy import DemoPolicy
-from ..core.main import FullState, History, Step
-from ..core.agent import Agent
-from ..libs.steps.nate import ImplementAbstractMethodStep
-from ..core.observation import Observation
-from ..libs.llm.openai import OpenAI
-from .ide_protocol import AbstractIdeProtocolServer
-from ..core.env import get_env_var
+from .session_manager import SessionManager, session_manager, Session
+from .notebook_protocol import AbstractNotebookProtocolServer
+from ..libs.util.queue import AsyncSubscriptionQueue
import asyncio
import nest_asyncio
nest_asyncio.apply()
@@ -36,160 +30,99 @@ class AppStatus:
Server.handle_exit = AppStatus.handle_exit
-class Session:
- session_id: str
- agent: Agent
- ws: Union[WebSocket, None]
-
- def __init__(self, session_id: str, agent: Agent):
- self.session_id = session_id
- self.agent = agent
- self.ws = None
-
-
-class DemoAgent(Agent):
- first_seen: bool = False
- cumulative_edit_string = ""
-
- def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
- for edit in edits:
- self.cumulative_edit_string += edit.fileEdit.replacement
- self._manual_edits_buffer.append(edit)
- # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
- # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
- # FOR DEMO PURPOSES
- if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement:
- self.cumulative_edit_string = ""
- asyncio.create_task(self.run_from_step(
- ImplementAbstractMethodStep()))
-
-
-class SessionManager:
- sessions: Dict[str, Session] = {}
- _event_loop: Union[asyncio.BaseEventLoop, None] = None
-
- def get_session(self, session_id: str) -> Session:
- if session_id not in self.sessions:
- raise KeyError("Session ID not recognized")
- return self.sessions[session_id]
-
- def new_session(self, ide: AbstractIdeProtocolServer) -> str:
- cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py"
- agent = DemoAgent(llm=OpenAI(api_key=get_env_var("OPENAI_API_KEY")),
- policy=DemoPolicy(cmd=cmd), ide=ide)
- session_id = str(uuid4())
- session = Session(session_id=session_id, agent=agent)
- self.sessions[session_id] = session
+def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+ return session_manager.get_session(x_continue_session_id)
- def on_update(state: FullState):
- session_manager.send_ws_data(session_id, {
- "messageType": "state",
- "state": agent.get_full_state().dict()
- })
- agent.on_update(on_update)
- asyncio.create_task(agent.run_policy())
- return session_id
+def websocket_session(session_id: str) -> Session:
+ return session_manager.get_session(session_id)
- def remove_session(self, session_id: str):
- del self.sessions[session_id]
- def register_websocket(self, session_id: str, ws: WebSocket):
- self.sessions[session_id].ws = ws
- print("Registered websocket for session", session_id)
+T = TypeVar("T", bound=BaseModel)
- def send_ws_data(self, session_id: str, data: Any):
- if self.sessions[session_id].ws is None:
- print(f"Session {session_id} has no websocket")
- return
+# You should probably abstract away the websocket stuff into a separate class
- async def a():
- await self.sessions[session_id].ws.send_json(data)
- # Run coroutine in background
- if self._event_loop is None or self._event_loop.is_closed():
- self._event_loop = asyncio.new_event_loop()
- self._event_loop.run_until_complete(a())
- self._event_loop.close()
- else:
- self._event_loop.run_until_complete(a())
- self._event_loop.close()
+class NotebookProtocolServer(AbstractNotebookProtocolServer):
+ websocket: WebSocket
+ session: Session
+ sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
+ def __init__(self, session: Session):
+ self.session = session
-session_manager = SessionManager()
+ async def _send_json(self, data: Any):
+ await self.websocket.send_json(data)
+ async def _receive_json(self, message_type: str) -> Any:
+ return await self.sub_queue.get(message_type)
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return session_manager.get_session(x_continue_session_id)
+ async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
+ await self._send_json(data)
+ resp = await self._receive_json(message_type)
+ return resp_model.parse_obj(resp)
+ def handle_json(self, message_type: str, data: Any):
+ try:
+ if message_type == "main_input":
+ self.on_main_input(data["input"])
+ elif message_type == "step_user_input":
+ self.on_step_user_input(data["input"], data["index"])
+ elif message_type == "refinement_input":
+ self.on_refinement_input(data["input"], data["index"])
+ elif message_type == "reverse_to_index":
+ self.on_reverse_to_index(data["index"])
+ except Exception as e:
+ print(e)
-def websocket_session(session_id: str) -> Session:
- return session_manager.get_session(session_id)
+ async def send_state_update(self):
+ state = self.session.agent.get_full_state().dict()
+ await self._send_json({
+ "messageType": "state_update",
+ "state": state
+ })
+ def on_main_input(self, input: str):
+ # Do something with user input
+ asyncio.create_task(self.session.agent.accept_user_input(input))
-class StartSessionBody(BaseModel):
- config_file_path: Union[str, None]
+ def on_reverse_to_index(self, index: int):
+ # Reverse the history to the given index
+ asyncio.create_task(self.session.agent.reverse_to_index(index))
+ def on_step_user_input(self, input: str, index: int):
+ asyncio.create_task(
+ self.session.agent.give_user_input(input, index))
-class StartSessionResp(BaseModel):
- session_id: str
+ def on_refinement_input(self, input: str, index: int):
+ asyncio.create_task(
+ self.session.agent.accept_refinement_input(input, index))
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
await websocket.accept()
+ print("Session started")
session_manager.register_websocket(session.session_id, websocket)
- data = await websocket.receive_text()
+ protocol = NotebookProtocolServer(session)
+ protocol.websocket = websocket
+
# Update any history that may have happened before connection
- await websocket.send_json({
- "messageType": "state",
- "state": session_manager.get_session(session.session_id).agent.get_full_state().dict()
- })
- print("Session started", data)
+ await protocol.send_state_update()
+
while AppStatus.should_exit is False:
- data = await websocket.receive_json()
- print("Received data", data)
+ message = await websocket.receive_json()
+ print("Received message", message)
+ if type(message) is str:
+ message = json.loads(message)
- if "messageType" not in data:
+ if "messageType" not in message or "data" not in message:
continue
- messageType = data["messageType"]
+ message_type = message["messageType"]
+ data = message["data"]
- try:
- if messageType == "main_input":
- # Do something with user input
- asyncio.create_task(
- session.agent.accept_user_input(data["value"]))
- elif messageType == "step_user_input":
- asyncio.create_task(
- session.agent.give_user_input(data["value"], data["index"]))
- elif messageType == "refinement_input":
- asyncio.create_task(
- session.agent.accept_refinement_input(data["value"], data["index"]))
- elif messageType == "reverse":
- # Reverse the history to the given index
- asyncio.create_task(
- session.agent.reverse_to_index(data["index"]))
- except Exception as e:
- print(e)
+ protocol.handle_json(message_type, data)
print("Closing websocket")
await websocket.close()
-
-
-@router.post("/run")
-def request_run(step: Step, session=Depends(session)):
- """Tell an agent to take a specific action."""
- asyncio.create_task(session.agent.run_from_step(step))
- return "Success"
-
-
-@router.get("/history")
-def get_history(session=Depends(session)) -> History:
- return session.agent.history
-
-
-@router.post("/observation")
-def post_observation(observation: Observation, session=Depends(session)):
- asyncio.create_task(session.agent.run_from_observation(observation))
- return "Success"
diff --git a/continuedev/src/continuedev/server/notebook_protocol.py b/continuedev/src/continuedev/server/notebook_protocol.py
new file mode 100644
index 00000000..c2be82e0
--- /dev/null
+++ b/continuedev/src/continuedev/server/notebook_protocol.py
@@ -0,0 +1,28 @@
+from typing import Any
+from abc import ABC, abstractmethod
+
+
+class AbstractNotebookProtocolServer(ABC):
+ @abstractmethod
+ async def handle_json(self, data: Any):
+ """Handle a json message"""
+
+ @abstractmethod
+ def on_main_input(self, input: str):
+ """Called when the user inputs something"""
+
+ @abstractmethod
+ def on_reverse_to_index(self, index: int):
+ """Called when the user requests reverse to a previous index"""
+
+ @abstractmethod
+ def on_refinement_input(self, input: str, index: int):
+ """Called when the user inputs a refinement"""
+
+ @abstractmethod
+ def on_step_user_input(self, input: str, index: int):
+ """Called when the user inputs a step"""
+
+ @abstractmethod
+ async def send_state_update(self, state: dict):
+ """Send a state update to the client"""
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
new file mode 100644
index 00000000..b48c21b7
--- /dev/null
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -0,0 +1,101 @@
+from fastapi import WebSocket
+from typing import Any, Dict, List, Union
+from uuid import uuid4
+
+from ..models.filesystem_edit import FileEditWithFullContents
+from ..core.policy import DemoPolicy
+from ..core.main import FullState
+from ..core.agent import Agent
+from ..libs.steps.nate import ImplementAbstractMethodStep
+from .ide_protocol import AbstractIdeProtocolServer
+import asyncio
+import nest_asyncio
+nest_asyncio.apply()
+
+
+class Session:
+ session_id: str
+ agent: Agent
+ ws: Union[WebSocket, None]
+
+ def __init__(self, session_id: str, agent: Agent):
+ self.session_id = session_id
+ self.agent = agent
+ self.ws = None
+
+
+class DemoAgent(Agent):
+ first_seen: bool = False
+ cumulative_edit_string = ""
+
+ def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
+ for edit in edits:
+ self.cumulative_edit_string += edit.fileEdit.replacement
+ self._manual_edits_buffer.append(edit)
+ # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
+ # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
+ # FOR DEMO PURPOSES
+ if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement:
+ self.cumulative_edit_string = ""
+ asyncio.create_task(self.run_from_step(
+ ImplementAbstractMethodStep()))
+
+
+class SessionManager:
+ sessions: Dict[str, Session] = {}
+ _event_loop: Union[asyncio.BaseEventLoop, None] = None
+
+ def get_session(self, session_id: str) -> Session:
+ if session_id not in self.sessions:
+ raise KeyError("Session ID not recognized")
+ return self.sessions[session_id]
+
+ def new_session(self, ide: AbstractIdeProtocolServer) -> str:
+ agent = DemoAgent(policy=DemoPolicy(), ide=ide)
+ session_id = str(uuid4())
+ session = Session(session_id=session_id, agent=agent)
+ self.sessions[session_id] = session
+
+ async def on_update(state: FullState):
+ await session_manager.send_ws_data(session_id, "state_update", {
+ "state": agent.get_full_state().dict()
+ })
+
+ agent.on_update(on_update)
+ asyncio.create_task(agent.run_policy())
+ return session_id
+
+ def remove_session(self, session_id: str):
+ del self.sessions[session_id]
+
+ def register_websocket(self, session_id: str, ws: WebSocket):
+ self.sessions[session_id].ws = ws
+ print("Registered websocket for session", session_id)
+
+ async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if self.sessions[session_id].ws is None:
+ print(f"Session {session_id} has no websocket")
+ return
+
+ async def a():
+ await self.sessions[session_id].ws.send_json({
+ "messageType": message_type,
+ "data": data
+ })
+
+ # Run coroutine in background
+ await self.sessions[session_id].ws.send_json({
+ "messageType": message_type,
+ "data": data
+ })
+ return
+ if self._event_loop is None or self._event_loop.is_closed():
+ self._event_loop = asyncio.new_event_loop()
+ self._event_loop.run_until_complete(a())
+ self._event_loop.close()
+ else:
+ self._event_loop.run_until_complete(a())
+ self._event_loop.close()
+
+
+session_manager = SessionManager()