summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-06-01 00:09:19 -0400
committerNate Sesti <sestinj@gmail.com>2023-06-01 00:09:19 -0400
commitea5d50af9ba84242c25e82069d86c08ac039e543 (patch)
tree29b99530ca6df54f5bc641be667c6820ff7e4014 /continuedev/src
parente6dded34c26fd17ede17776755cc41c26782a045 (diff)
downloadsncontinue-ea5d50af9ba84242c25e82069d86c08ac039e543.tar.gz
sncontinue-ea5d50af9ba84242c25e82069d86c08ac039e543.tar.bz2
sncontinue-ea5d50af9ba84242c25e82069d86c08ac039e543.zip
Polishing for dlt codespace and !config!
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/config.py29
-rw-r--r--continuedev/src/continuedev/core/policy.py9
-rw-r--r--continuedev/src/continuedev/core/sdk.py17
-rw-r--r--continuedev/src/continuedev/libs/steps/draft/dlt.py4
-rw-r--r--continuedev/src/continuedev/libs/steps/main.py20
-rw-r--r--continuedev/src/continuedev/libs/steps/steps_on_startup.py30
-rw-r--r--continuedev/src/continuedev/server/main.py8
-rw-r--r--continuedev/src/continuedev/server/notebook.py2
8 files changed, 112 insertions, 7 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
new file mode 100644
index 00000000..e62f0e4f
--- /dev/null
+++ b/continuedev/src/continuedev/core/config.py
@@ -0,0 +1,29 @@
+import json
+import os
+from pydantic import BaseModel
+from typing import List, Optional, Dict
+import yaml
+
+
+class ContinueConfig(BaseModel):
+ """
+ A pydantic class for the continue config file.
+ """
+ steps_on_startup: Optional[Dict[str, Dict]] = {}
+ server_url: Optional[str] = None
+
+
+def load_config(config_file: str) -> ContinueConfig:
+ """
+ Load the config file and return a ContinueConfig object.
+ """
+ _, ext = os.path.splitext(config_file)
+ if ext == '.yaml':
+ with open(config_file, 'r') as f:
+ config_dict = yaml.safe_load(f)
+ elif ext == '.json':
+ with open(config_file, 'r') as f:
+ config_dict = json.load(f)
+ else:
+ raise ValueError(f'Unknown config file extension: {ext}')
+ return ContinueConfig(**config_dict)
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index c0ba0f4f..9f68515f 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -1,9 +1,10 @@
from typing import List, Tuple, Type
-from ..libs.steps.ty import CreatePipelineStep
+from ..libs.steps.steps_on_startup import StepsOnStartupStep
+from ..libs.steps.draft.dlt import CreatePipelineStep
from .main import Step, Validator, History, Policy
from .observation import Observation, TracebackObservation, UserInputObservation
-from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep
+from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, MessageStep, EmptyStep
from ..libs.steps.nate import WritePytestsStep, CreateTableStep
# from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma
from ..libs.steps.continue_step import ContinueStepStep
@@ -13,6 +14,10 @@ class DemoPolicy(Policy):
ran_code_last: bool = False
def next(self, history: History) -> Step:
+ # At the very start, run initial Steps spcecified in the config
+ if history.get_current() is None:
+ return MessageStep(message="Welcome to Continue!") >> StepsOnStartupStep()
+
observation = history.last_observation()
if observation is not None and isinstance(observation, UserInputObservation):
# This could be defined with ObservationTypePolicy. Ergonomics not right though.
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 6ae0be04..4d82a1ae 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,5 +1,7 @@
import os
from typing import Coroutine, Union
+
+from .config import ContinueConfig, load_config
from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
from ..models.filesystem import RangeInFile
from ..libs.llm import LLM
@@ -106,3 +108,18 @@ class ContinueSDK:
val = (await self.run_step(WaitForUserInputStep(prompt=prompt))).text
save_env_var(env_var, val)
return val
+
+ async def get_config(self) -> ContinueConfig:
+ dir = await self.ide.getWorkspaceDirectory()
+ yaml_path = os.path.join(dir, 'continue.yaml')
+ json_path = os.path.join(dir, 'continue.json')
+ if os.path.exists(yaml_path):
+ return load_config(yaml_path)
+ elif os.path.exists(json_path):
+ return load_config(json_path)
+ else:
+ return ContinueConfig()
+
+ def set_loading_message(self, message: str):
+ # self.__agent.set_loading_message(message)
+ raise NotImplementedError()
diff --git a/continuedev/src/continuedev/libs/steps/draft/dlt.py b/continuedev/src/continuedev/libs/steps/draft/dlt.py
index 460aa0cc..778ced1d 100644
--- a/continuedev/src/continuedev/libs/steps/draft/dlt.py
+++ b/continuedev/src/continuedev/libs/steps/draft/dlt.py
@@ -2,7 +2,7 @@ from textwrap import dedent
from ....models.filesystem_edit import AddFile
from ....core.main import Step
from ....core.sdk import ContinueSDK
-from ..main import WaitForUserInputStep
+from ..core.core import WaitForUserInputStep
class SetupPipelineStep(Step):
@@ -77,6 +77,6 @@ class CreatePipelineStep(Step):
async def run(self, sdk: ContinueSDK):
await sdk.run_step(
WaitForUserInputStep(prompt="What API do you want to load data from?") >>
- SetupPipelineStep() >>
+ SetupPipelineStep(api_description="WeatherAPI.com API") >>
ValidatePipelineStep()
)
diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py
index 70c0d4b8..d31db0eb 100644
--- a/continuedev/src/continuedev/libs/steps/main.py
+++ b/continuedev/src/continuedev/libs/steps/main.py
@@ -331,3 +331,23 @@ class SolveTracebackStep(Step):
await sdk.run_step(EditCodeStep(
range_in_files=range_in_files, prompt=prompt))
return None
+
+
+class MessageStep(Step):
+ message: str
+
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
+ return self.message
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ pass
+
+
+class EmptyStep(Step):
+ hide: bool = True
+
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
+ return ""
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ pass
diff --git a/continuedev/src/continuedev/libs/steps/steps_on_startup.py b/continuedev/src/continuedev/libs/steps/steps_on_startup.py
new file mode 100644
index 00000000..fd1eb8f0
--- /dev/null
+++ b/continuedev/src/continuedev/libs/steps/steps_on_startup.py
@@ -0,0 +1,30 @@
+
+
+from ...core.main import ContinueSDK, Models, Step
+from .main import UserInputStep
+from .draft.dlt import CreatePipelineStep
+
+
+step_name_to_step_class = {
+ "UserInputStep": UserInputStep,
+ "CreatePipelineStep": CreatePipelineStep
+}
+
+
+class StepsOnStartupStep(Step):
+ hide: bool = True
+
+ async def describe(self, models: Models):
+ return "Running steps on startup"
+
+ async def run(self, sdk: ContinueSDK):
+ steps_descriptions = (await sdk.get_config()).steps_on_startup
+
+ for step_name, step_params in steps_descriptions.items():
+ try:
+ step = step_name_to_step_class[step_name](**step_params)
+ except:
+ print(
+ f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}")
+ continue
+ await sdk.run_step(step)
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index 1977bfdd..1ffe1450 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -1,3 +1,4 @@
+import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .ide import router as ide_router
@@ -32,8 +33,11 @@ args = parser.parse_args()
def run_server():
- uvicorn.run(app, host="0.0.0.0", port=args.port,
- log_config="logging.yaml")
+ if os.path.exists("logging.yaml"):
+ uvicorn.run(app, host="0.0.0.0", port=args.port,
+ log_config="logging.yaml")
+ else:
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
if __name__ == "__main__":
diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py
index ab9211a8..9ca510dd 100644
--- a/continuedev/src/continuedev/server/notebook.py
+++ b/continuedev/src/continuedev/server/notebook.py
@@ -53,7 +53,7 @@ class NotebookProtocolServer(AbstractNotebookProtocolServer):
async def _send_json(self, message_type: str, data: Any):
await self.websocket.send_json({
- "message_type": message_type,
+ "messageType": message_type,
"data": data
})