diff options
| author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-06-12 21:36:11 -0700 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-06-12 21:36:11 -0700 | 
| commit | 297511b5e3ba67e3c37e5cb4c038186093da6a95 (patch) | |
| tree | c850a433c3a7d594fdeabf11cabd505bb5388d40 /continuedev | |
| parent | 52ffaa321ee24d2a930ac4e6ff083aaa37be79e8 (diff) | |
| parent | 01cfbc179a33c99d55acdc989dbafd554db16a92 (diff) | |
| download | sncontinue-297511b5e3ba67e3c37e5cb4c038186093da6a95.tar.gz sncontinue-297511b5e3ba67e3c37e5cb4c038186093da6a95.tar.bz2 sncontinue-297511b5e3ba67e3c37e5cb4c038186093da6a95.zip | |
Merge pull request #76 from continuedev/superset-of-chat
Superset of chat
Diffstat (limited to 'continuedev')
27 files changed, 526 insertions, 193 deletions
| diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock index 857a7c99..4aedce87 100644 --- a/continuedev/poetry.lock +++ b/continuedev/poetry.lock @@ -345,6 +345,21 @@ typing-inspect = ">=0.4.0"  dev = ["flake8", "hypothesis", "ipython", "mypy (>=0.710)", "portray", "pytest (>=6.2.3)", "simplejson", "types-dataclasses"]  [[package]] +name = "diff-match-patch" +version = "20230430" +description = "Diff Match and Patch" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ +    {file = "diff-match-patch-20230430.tar.gz", hash = "sha256:953019cdb9c9d2c9e47b5b12bcff3cf4746fc4598eb406076fa1fc27e6a1f15c"}, +    {file = "diff_match_patch-20230430-py3-none-any.whl", hash = "sha256:dce43505fb7b1b317de7195579388df0746d90db07015ed47a85e5e44930ef93"}, +] + +[package.extras] +dev = ["attribution (==1.6.2)", "black (==23.3.0)", "flit (==3.8.0)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"] + +[[package]]  name = "fastapi"  version = "0.95.1"  description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" @@ -1252,23 +1267,6 @@ socks = ["PySocks (>=1.5.6,!=1.5.7)"]  use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]  [[package]] -name = "setuptools" -version = "67.7.2" -description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "main" -optional = false -python-versions = ">=3.7" -files = [ -    {file = "setuptools-67.7.2-py3-none-any.whl", hash = "sha256:23aaf86b85ca52ceb801d32703f12d77517b2556af839621c641fca11287952b"}, -    {file = "setuptools-67.7.2.tar.gz", hash = "sha256:f104fa03692a2602fa0fec6c6a9e63b6c8a968de13e17c026957dd1f53d80990"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] - -[[package]]  name = "six"  version = "1.16.0"  description = "Python 2 and 3 compatibility utilities" @@ -1739,4 +1737,4 @@ multidict = ">=4.0"  [metadata]  lock-version = "2.0"  python-versions = "^3.9" -content-hash = "9f9254c954b7948c49debba86bc81a4a9c3f50694424f5940d0058725b1bf0fb" +content-hash = "0f5f759bac0e44a1fbcc9babeccdea8688ea2226a4bae7a13858542ae03a3228" diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 631742ec..7315e79d 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -7,6 +7,7 @@ readme = "README.md"  [tool.poetry.dependencies]  python = "^3.9" +diff-match-patch = "^20230430"  fastapi = "^0.95.1"  typer = "^0.7.0"  openai = "^0.27.5" diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 417971cd..0658f1b8 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -76,8 +76,8 @@ class AbstractContinueSDK(ABC):      async def get_user_secret(self, env_var: str, prompt: str) -> str:          pass -    @abstractmethod -    async def get_config(self) -> ContinueConfig: +    @abstractproperty +    def config(self) -> ContinueConfig:          pass      @abstractmethod @@ -88,6 +88,6 @@ class AbstractContinueSDK(ABC):      def add_chat_context(self, content: str, role: ChatMessageRole = "assistent"):          pass -    @abstractproperty -    def chat_context(self) -> List[ChatMessage]: +    @abstractmethod +    async def get_chat_context(self) -> List[ChatMessage]:          pass diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index c979d53a..0874bbc5 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -40,6 +40,15 @@ class Autopilot(ContinueBaseModel):      def get_full_state(self) -> FullState:          return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue) +    async def get_available_slash_commands(self) -> List[Dict]: +        return list(map(lambda x: {"name": x.name, "description": x.description}, self.continue_sdk.config.slash_commands)) or [] + +    async def clear_history(self): +        self.history = History.from_empty() +        self._main_user_input_queue = [] +        self._active = False +        await self.update_subscribers() +      def on_update(self, callback: Coroutine["FullState", None, None]):          """Subscribe to changes to state"""          self._on_update_callbacks.append(callback) @@ -88,6 +97,10 @@ class Autopilot(ContinueBaseModel):      async def retry_at_index(self, index: int):          self._retry_queue.post(str(index), None) +    async def delete_at_index(self, index: int): +        self.history.timeline[index].step.hide = True +        await self.update_subscribers() +      async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:          capture_event(              'step run', {'step_name': step.name, 'params': step.dict()}) @@ -192,7 +205,7 @@ class Autopilot(ContinueBaseModel):              await self._run_singular_step(next_step, is_future_step) -            if next_step := self.policy.next(self.history): +            if next_step := self.policy.next(self.continue_sdk.config, self.history):                  is_future_step = False              elif next_step := self.history.take_next_step():                  is_future_step = True @@ -205,11 +218,11 @@ class Autopilot(ContinueBaseModel):          await self.update_subscribers()      async def run_from_observation(self, observation: Observation): -        next_step = self.policy.next(self.history) +        next_step = self.policy.next(self.continue_sdk.config, self.history)          await self.run_from_step(next_step)      async def run_policy(self): -        first_step = self.policy.next(self.history) +        first_step = self.policy.next(self.continue_sdk.config, self.history)          await self.run_from_step(first_step)      async def _request_halt(self): diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 8ed41a82..cf723984 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,9 +1,18 @@  import json  import os -from pydantic import BaseModel +from pydantic import BaseModel, validator  from typing import List, Optional, Dict  import yaml +from .main import Step + + +class SlashCommand(BaseModel): +    name: str +    description: str +    step_name: str +    params: Optional[Dict] = {} +  class ContinueConfig(BaseModel):      """ @@ -12,6 +21,29 @@ class ContinueConfig(BaseModel):      steps_on_startup: Optional[Dict[str, Dict]] = {}      server_url: Optional[str] = None      allow_anonymous_telemetry: Optional[bool] = True +    slash_commands: Optional[List[SlashCommand]] = [ +        # SlashCommand( +        #     name="pytest", +        #     description="Write pytest unit tests for the current file", +        #     step_name="WritePytestsRecipe", +        #     params=??) + +        SlashCommand( +            name="dlt", +            description="Create a dlt pipeline", +            step_name="CreatePipelineRecipe", +        ), +        SlashCommand( +            name="ddtobq", +            description="Create a dlt pipeline to load data from a data source into BigQuery", +            step_name="DDtoBQRecipe", +        ), +        SlashCommand( +            name="deployairflow", +            description="Deploy a dlt pipeline to Airflow", +            step_name="DeployPipelineAirflowRecipe", +        ), +    ]  def load_config(config_file: str) -> ContinueConfig: diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 81aaaf2e..f6b26d69 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -118,11 +118,15 @@ class Models:      pass +class ContinueConfig: +    pass + +  class Policy(ContinueBaseModel):      """A rule that determines which step to take next"""      # Note that history is mutable, kinda sus -    def next(self, history: History = History.from_empty()) -> "Step": +    def next(self, config: ContinueConfig, history: History = History.from_empty()) -> "Step":          raise NotImplementedError diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 8e43bf55..37a10e36 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,5 +1,6 @@  from typing import List, Tuple, Type +from .config import ContinueConfig  from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma  from ..steps.steps_on_startup import StepsOnStartupStep  from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe @@ -11,14 +12,17 @@ from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeSte  from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe  from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep  from ..steps.comment_code import CommentCodeStep +from ..steps.react import NLDecisionStep +from ..steps.chat import SimpleChatStep  from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe  from ..steps.core.core import MessageStep +from ..libs.util.step_name_to_steps import get_step_from_name  class DemoPolicy(Policy):      ran_code_last: bool = False -    def next(self, history: History) -> Step: +    def next(self, config: ContinueConfig, history: History) -> Step:          # At the very start, run initial Steps spcecified in the config          if history.get_current() is None:              return ( @@ -30,25 +34,36 @@ class DemoPolicy(Policy):          observation = history.get_current().observation          if observation is not None and isinstance(observation, UserInputObservation):              # This could be defined with ObservationTypePolicy. Ergonomics not right though. +            user_input = observation.user_input + +            if user_input.startswith("/"): +                command_name = user_input.split(" ")[0] +                after_command = " ".join(user_input.split(" ")[1:]) +                for slash_command in config.slash_commands: +                    if slash_command.name == command_name[1:]: +                        return get_step_from_name(slash_command.step_name, slash_command.params) + +            if "/pytest" in user_input.lower(): +                return WritePytestsRecipe(instructions=user_input)              if "/pytest" in observation.user_input.lower():                  return WritePytestsRecipe(instructions=observation.user_input) -            elif "/dlt" in observation.user_input.lower(): -                return CreatePipelineRecipe() -            elif "/ddtobq" in observation.user_input.lower(): -                return DDtoBQRecipe() -            elif "/airflow" in observation.user_input.lower(): -                return DeployPipelineAirflowRecipe() -            elif "/transform" in observation.user_input.lower(): -                return AddTransformRecipe()              elif "/comment" in observation.user_input.lower():                  return CommentCodeStep() -            elif "/ask" in observation.user_input: -                return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) -            elif "/edit" in observation.user_input: -                return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) -            elif "/step" in observation.user_input: -                return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:])) -            return EditHighlightedCodeStep(user_input=observation.user_input) +            elif "/ask" in user_input: +                return AnswerQuestionChroma(question=" ".join(user_input.split(" ")[1:])) +            elif "/edit" in user_input: +                return EditFileChroma(request=" ".join(user_input.split(" ")[1:])) +            elif "/step" in user_input: +                return ContinueStepStep(prompt=" ".join(user_input.split(" ")[1:])) +            # return EditHighlightedCodeStep(user_input=user_input) +            return NLDecisionStep(user_input=user_input, steps=[ +                (EditHighlightedCodeStep(user_input=user_input), +                 "Edit the highlighted code"), +                # AnswerQuestionChroma(question=user_input), +                # EditFileChroma(request=user_input), +                (SimpleChatStep(user_input=user_input), +                 "Respond to the user with a chat message"), +            ], default_step=EditHighlightedCodeStep(user_input=user_input))          state = history.get_current() @@ -57,54 +72,3 @@ class DemoPolicy(Policy):              return SolveTracebackStep(traceback=observation.traceback)          else:              return None - - -class ObservationTypePolicy(Policy): -    def __init__(self, base_policy: Policy, observation_type: Type[Observation], step_type: Type[Step]): -        self.observation_type = observation_type -        self.step_type = step_type -        self.base_policy = base_policy - -    def next(self, history: History) -> Step: -        observation = history.last_observation() -        if observation is not None and isinstance(observation, self.observation_type): -            return self.step_type(observation) -        return self.base_policy.next(history) - - -class PolicyWrappedWithValidators(Policy): -    """Default is to stop, unless the validator tells what to do next""" -    index: int -    stage: int - -    def __init__(self, base_policy: Policy, pairs: List[Tuple[Validator, Type[Step]]]): -        # Want to pass Type[Validator], or just the Validator? Question of where params are coming from. -        self.pairs = pairs -        self.index = len(pairs) -        self.validating = 0 -        self.base_policy = base_policy - -    def next(self, history: History) -> Step: -        if self.index == len(self.pairs): -            self.index = 0 -            return self.base_policy.next(history) - -        if self.stage == 0: -            # Running the validator at the current index for the first time -            validator, step = self.pairs[self.index] -            self.stage = 1 -            return validator -        elif self.stage == 1: -            # Previously ran the validator at the current index, now receiving its ValidatorObservation -            observation = history.last_observation() -            if observation.passed: -                self.stage = 0 -                self.index += 1 -                if self.index == len(self.pairs): -                    self.index = 0 -                    return self.base_policy.next(history) -                else: -                    return self.pairs[self.index][0] -            else: -                _, step_type = self.pairs[self.index] -                return step_type(observation) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 59bfc0f2..1f4cdfb2 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod  import asyncio  from functools import cached_property  from typing import Coroutine, Union @@ -119,8 +118,9 @@ class ContinueSDK(AbstractContinueSDK):      async def get_user_secret(self, env_var: str, prompt: str) -> str:          return await self.ide.getUserSecret(env_var) -    async def get_config(self) -> ContinueConfig: -        dir = await self.ide.getWorkspaceDirectory() +    @property +    def config(self) -> ContinueConfig: +        dir = self.ide.workspace_directory          yaml_path = os.path.join(dir, '.continue', 'config.yaml')          json_path = os.path.join(dir, '.continue', 'config.json')          if os.path.exists(yaml_path): @@ -141,6 +141,14 @@ class ContinueSDK(AbstractContinueSDK):          self.history.timeline[self.history.current_index].step.chat_context.append(              ChatMessage(content=content, role=role)) -    @property -    def chat_context(self) -> List[ChatMessage]: -        return self.history.to_chat_history() +    async def get_chat_context(self) -> List[ChatMessage]: +        history_context = self.history.to_chat_history() +        highlighted_code = await self.ide.getHighlightedCode() +        for rif in highlighted_code: +            code = await self.ide.readRangeInFile(rif) +            history_context.append(ChatMessage( +                content=f"The following code is highlighted:\n```\n{code}\n```", role="user")) +        return history_context + +    async def update_ui(self): +        await self.__autopilot.update_subscribers() diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 6a537afd..9b8d3447 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -23,7 +23,7 @@ class OpenAI(LLM):      def with_system_message(self, system_message: Union[str, None]):          return OpenAI(api_key=self.api_key, system_message=system_message) -    def stream_chat(self, messages, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          self.completion_count += 1          args = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,                  "frequency_penalty": 0, "presence_penalty": 0} | kwargs @@ -31,7 +31,7 @@ class OpenAI(LLM):          args["model"] = "gpt-3.5-turbo"          for chunk in openai.ChatCompletion.create( -            messages=messages, +            messages=self.compile_chat_messages(with_history, prompt),              **args,          ):              if "content" in chunk.choices[0].delta: @@ -39,7 +39,21 @@ class OpenAI(LLM):              else:                  continue -    def stream_complete(self, prompt: str, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]: +        history = [] +        if self.system_message: +            history.append({ +                "role": "system", +                "content": self.system_message +            }) +        history += [msg.dict() for msg in msgs] +        history.append({ +            "role": "user", +            "content": prompt +        }) +        return history + +    def stream_complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          self.completion_count += 1          args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5,                  "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, "suffix": None} | kwargs @@ -47,10 +61,7 @@ class OpenAI(LLM):          if args["model"] == "gpt-3.5-turbo":              generator = openai.ChatCompletion.create( -                messages=[{ -                    "role": "user", -                    "content": prompt -                }], +                messages=self.compile_chat_messages(with_history, prompt),                  **args,              )              for chunk in generator: @@ -71,19 +82,8 @@ class OpenAI(LLM):                  "frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs          if args["model"] == "gpt-3.5-turbo": -            messages = [] -            if self.system_message: -                messages.append({ -                    "role": "system", -                    "content": self.system_message -                }) -            messages += [msg.dict() for msg in with_history] -            messages.append({ -                "role": "user", -                "content": prompt -            })              resp = openai.ChatCompletion.create( -                messages=messages, +                messages=self.compile_chat_messages(with_history, prompt),                  **args,              ).choices[0].message.content          else: diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py new file mode 100644 index 00000000..ff0a135f --- /dev/null +++ b/continuedev/src/continuedev/libs/util/calculate_diff.py @@ -0,0 +1,166 @@ +import difflib +from typing import List +from ...models.main import Position, Range +from ...models.filesystem import FileEdit +from diff_match_patch import diff_match_patch + + +def calculate_diff_match_patch(filepath: str, original: str, updated: str) -> List[FileEdit]: +    dmp = diff_match_patch() +    diffs = dmp.diff_main(original, updated) +    dmp.diff_cleanupSemantic(diffs) + +    replacements = [] + +    current_index = 0 +    deleted_length = 0 + +    for diff in diffs: +        if diff[0] == diff_match_patch.DIFF_EQUAL: +            current_index += len(diff[1]) +            deleted_length = 0 +        elif diff[0] == diff_match_patch.DIFF_INSERT: +            current_index += deleted_length +            replacements.append((current_index, current_index, diff[1])) +            current_index += len(diff[1]) +            deleted_length = 0 +        elif diff[0] == diff_match_patch.DIFF_DELETE: +            replacements.append( +                (current_index, current_index + len(diff[1]), '')) +            deleted_length += len(diff[1]) +        elif diff[0] == diff_match_patch.DIFF_REPLACE: +            replacements.append( +                (current_index, current_index + len(diff[1]), '')) +            current_index += deleted_length +            replacements.append((current_index, current_index, diff[2])) +            current_index += len(diff[2]) +            deleted_length = 0 + +    return [FileEdit(filepath=filepath, range=Range.from_indices(original, r[0], r[1]), replacement=r[2]) for r in replacements] + + +def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]: +    s = difflib.SequenceMatcher(None, original, updated) +    offset = 0  # The indices are offset by previous deletions/insertions +    edits = [] +    for tag, i1, i2, j1, j2 in s.get_opcodes(): +        i1, i2, j1, j2 = i1 + offset, i2 + offset, j1 + offset, j2 + offset +        replacement = updated[j1:j2] +        if tag == "equal": +            pass +        elif tag == "delete": +            edits.append(FileEdit.from_deletion( +                filepath, Range.from_indices(original, i1, i2))) +            offset -= i2 - i1 +        elif tag == "insert": +            edits.append(FileEdit.from_insertion( +                filepath, Position.from_index(original, i1), replacement)) +            offset += j2 - j1 +        elif tag == "replace": +            edits.append(FileEdit(filepath=filepath, range=Range.from_indices( +                original, i1, i2), replacement=replacement)) +            offset += (j2 - j1) - (i2 - i1) +        else: +            raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag) + +    return edits + + +def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit]: +    # original_lines = original.splitlines() +    # updated_lines = updated.splitlines() +    # offset = 0 +    # while len(original_lines) and len(updated_lines) and original_lines[0] == updated_lines[0]: +    #     original_lines = original_lines[1:] +    #     updated_lines = updated_lines[1:] + +    # while len(original_lines) and len(updated_lines) and original_lines[-1] == updated_lines[-1]: +    #     original_lines = original_lines[:-1] +    #     updated_lines = updated_lines[:-1] + +    # original = "\n".join(original_lines) +    # updated = "\n".join(updated_lines) + +    edits = [] +    max_iterations = 1000 +    i = 0 +    while not original == updated: +        # TODO - For some reason it can't handle a single newline at the end of the file? +        s = difflib.SequenceMatcher(None, original, updated) +        opcodes = s.get_opcodes() +        for edit_index in range(len(opcodes)): +            tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index] +            replacement = updated[j1:j2] +            if tag == "equal": +                continue +            elif tag == "delete": +                edits.append(FileEdit.from_deletion( +                    filepath, Range.from_indices(original, i1, i2))) +            elif tag == "insert": +                edits.append(FileEdit.from_insertion( +                    filepath, Position.from_index(original, i1), replacement)) +            elif tag == "replace": +                edits.append(FileEdit(filepath=filepath, range=Range.from_indices( +                    original, i1, i2), replacement=replacement)) +            else: +                raise Exception( +                    "Unexpected difflib.SequenceMatcher tag: " + tag) +            break + +        original = apply_edit_to_str(original, edits[-1]) + +        i += 1 +        if i > max_iterations: +            raise Exception("Max iterations reached") + +    return edits + + +def read_range_in_str(s: str, r: Range) -> str: +    lines = s.splitlines()[r.start.line:r.end.line + 1] +    if len(lines) == 0: +        return "" + +    lines[0] = lines[0][r.start.character:] +    lines[-1] = lines[-1][:r.end.character + 1] +    return "\n".join(lines) + + +def apply_edit_to_str(s: str, edit: FileEdit) -> str: +    original = read_range_in_str(s, edit.range) + +    # Split lines and deal with some edge cases (could obviously be nicer) +    lines = s.splitlines() +    if s.startswith("\n"): +        lines.insert(0, "") +    if s.endswith("\n"): +        lines.append("") + +    if len(lines) == 0: +        lines = [""] + +    end = Position(line=edit.range.end.line, +                   character=edit.range.end.character) +    if edit.range.end.line == len(lines) and edit.range.end.character == 0: +        end = Position(line=edit.range.end.line - 1, +                       character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)])) + +    before_lines = lines[:edit.range.start.line] +    after_lines = lines[end.line + 1:] +    between_str = lines[min(len(lines) - 1, edit.range.start.line)][:edit.range.start.character] + \ +        edit.replacement + \ +        lines[min(len(lines) - 1, end.line)][end.character + 1:] + +    new_range = Range( +        start=edit.range.start, +        end=Position( +            line=edit.range.start.line + +            len(edit.replacement.splitlines()) - 1, +            character=edit.range.start.character + +            len(edit.replacement.splitlines() +                [-1]) if edit.replacement != "" else 0 +        ) +    ) + +    lines = before_lines + between_str.splitlines() + after_lines +    return "\n".join(lines) diff --git a/continuedev/src/continuedev/libs/util/copy_codebase.py b/continuedev/src/continuedev/libs/util/copy_codebase.py index af957a34..97143faf 100644 --- a/continuedev/src/continuedev/libs/util/copy_codebase.py +++ b/continuedev/src/continuedev/libs/util/copy_codebase.py @@ -3,13 +3,12 @@ from pathlib import Path  from typing import Iterable, List, Union  from watchdog.observers import Observer  from watchdog.events import PatternMatchingEventHandler -from ..models.main import FileEdit, DeleteDirectory, DeleteFile, AddDirectory, AddFile, FileSystemEdit, Position, Range, RenameFile, RenameDirectory, SequentialFileSystemEdit -from ..models.filesystem import FileSystem -from ..libs.main import Autopilot -from ..libs.map_path import map_path -from ..libs.steps.main import ManualEditAction +from ...models.main import FileEdit, DeleteDirectory, DeleteFile, AddDirectory, AddFile, FileSystemEdit, RenameFile, RenameDirectory, SequentialFileSystemEdit +from ...models.filesystem import FileSystem +from ...core.autopilot import Autopilot +from .map_path import map_path +from ...core.sdk import ManualEditStep  import shutil -import difflib  def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = []): @@ -36,33 +35,6 @@ def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = [                  os.symlink(child, map_path(child)) -def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]: -    s = difflib.SequenceMatcher(None, original, updated) -    offset = 0  # The indices are offset by previous deletions/insertions -    edits = [] -    for tag, i1, i2, j1, j2 in s.get_opcodes(): -        i1, i2, j1, j2 = i1 + offset, i2 + offset, j1 + offset, j2 + offset -        replacement = updated[j1:j2] -        if tag == "equal": -            pass -        elif tag == "delete": -            edits.append(FileEdit.from_deletion( -                filepath, Range.from_indices(original, i1, i2))) -            offset -= i2 - i1 -        elif tag == "insert": -            edits.append(FileEdit.from_insertion( -                filepath, Position.from_index(original, i1), replacement)) -            offset += j2 - j1 -        elif tag == "replace": -            edits.append(FileEdit(filepath, Range.from_indices( -                original, i1, i2), replacement)) -            offset += (j2 - j1) - (i2 + i1) -        else: -            raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag) - -    return edits - -  # The whole usage of watchdog here should only be specific to RealFileSystem, you want to have a different "Observer" class for VirtualFileSystem, which would depend on being sent notifications  class CopyCodebaseEventHandler(PatternMatchingEventHandler):      def __init__(self, ignore_directories: List[str], ignore_patterns: List[str], autopilot: Autopilot, orig_root: str, copy_root: str, filesystem: FileSystem): diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py new file mode 100644 index 00000000..4023b73b --- /dev/null +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -0,0 +1,27 @@ +from typing import Dict + +from ...core.main import Step +from ...steps.core.core import UserInputStep +from ...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe +from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe +from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe +from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe +from ...recipes.AddTransformRecipe.main import AddTransformRecipe + +step_name_to_step_class = { +    "UserInputStep": UserInputStep, +    "CreatePipelineRecipe": CreatePipelineRecipe, +    "DDtoBQRecipe": DDtoBQRecipe, +    "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, +    "AddTransformRecipe": AddTransformRecipe, +    "DDtoBQRecipe": DDtoBQRecipe +} + + +def get_step_from_name(step_name: str, params: Dict) -> Step: +    try: +        return step_name_to_step_class[step_name](**params) +    except: +        print( +            f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") +        raise diff --git a/continuedev/src/continuedev/models/filesystem_edit.py b/continuedev/src/continuedev/models/filesystem_edit.py index 8e74b819..b06ca2b3 100644 --- a/continuedev/src/continuedev/models/filesystem_edit.py +++ b/continuedev/src/continuedev/models/filesystem_edit.py @@ -30,8 +30,8 @@ class FileEdit(AtomicFileSystemEdit):          return FileEdit(map_path(self.filepath, orig_root, copy_root), self.range, self.replacement)      @staticmethod -    def from_deletion(filepath: str, start: Position, end: Position) -> "FileEdit": -        return FileEdit(filepath, Range(start, end), "") +    def from_deletion(filepath: str, range: Range) -> "FileEdit": +        return FileEdit(filepath=filepath, range=range, replacement="")      @staticmethod      def from_insertion(filepath: str, position: Position, content: str) -> "FileEdit": diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py index 818168ba..92bddc98 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py +++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py @@ -1,7 +1,7 @@  from textwrap import dedent -from ...core.main import Step  from ...core.sdk import ContinueSDK +from ...core.main import Step  from ...steps.core.core import WaitForUserInputStep  from ...steps.core.core import MessageStep  from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py index e59cc51c..ea4607da 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py @@ -6,11 +6,10 @@ import time  from ...models.main import Range  from ...models.filesystem import RangeInFile  from ...steps.core.core import MessageStep -from ...core.sdk import Models  from ...core.observation import DictObservation, InternalErrorObservation  from ...models.filesystem_edit import AddFile, FileEdit  from ...core.main import Step -from ...core.sdk import ContinueSDK +from ...core.sdk import ContinueSDK, Models  AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index b873a88f..cf046734 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -77,6 +77,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):                  self.on_reverse_to_index(data["index"])              elif message_type == "retry_at_index":                  self.on_retry_at_index(data["index"]) +            elif message_type == "clear_history": +                self.on_clear_history() +            elif message_type == "delete_at_index": +                self.on_delete_at_index(data["index"])          except Exception as e:              print(e) @@ -86,6 +90,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              "state": state          }) +    async def send_available_slash_commands(self): +        commands = await self.session.autopilot.get_available_slash_commands() +        await self._send_json("available_slash_commands", { +            "commands": commands +        }) +      def on_main_input(self, input: str):          # Do something with user input          asyncio.create_task(self.session.autopilot.accept_user_input(input)) @@ -106,6 +116,12 @@ class GUIProtocolServer(AbstractGUIProtocolServer):          asyncio.create_task(              self.session.autopilot.retry_at_index(index)) +    def on_clear_history(self): +        asyncio.create_task(self.session.autopilot.clear_history()) + +    def on_delete_at_index(self, index: int): +        asyncio.create_task(self.session.autopilot.delete_at_index(index)) +  @router.websocket("/ws")  async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): @@ -117,6 +133,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we      protocol.websocket = websocket      # Update any history that may have happened before connection +    await protocol.send_available_slash_commands()      await protocol.send_state_update()      while AppStatus.should_exit is False: diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py index 287f9e3b..d9506c6f 100644 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict, List  from abc import ABC, abstractmethod @@ -28,5 +28,17 @@ class AbstractGUIProtocolServer(ABC):          """Send a state update to the client"""      @abstractmethod +    async def send_available_slash_commands(self, commands: List[Dict]): +        """Send a list of available slash commands to the client""" + +    @abstractmethod      def on_retry_at_index(self, index: int):          """Called when the user requests a retry at a previous index""" + +    @abstractmethod +    def on_clear_history(self): +        """Called when the user requests to clear the history""" + +    @abstractmethod +    def on_delete_at_index(self, index: int): +        """Called when the user requests to delete a step at a given index""" diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 0dbfaf38..ebea08a5 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -28,6 +28,7 @@ class DemoAutopilot(Autopilot):      cumulative_edit_string = ""      def handle_manual_edits(self, edits: List[FileEditWithFullContents]): +        return          for edit in edits:              self.cumulative_edit_string += edit.fileEdit.replacement              self._manual_edits_buffer.append(edit) diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py new file mode 100644 index 00000000..c9bd760b --- /dev/null +++ b/continuedev/src/continuedev/server/state_manager.py @@ -0,0 +1,21 @@ +from typing import Any, List, Tuple, Union +from fastapi import WebSocket +from pydantic import BaseModel +from ..core.main import FullState + +# State updates represented as (path, replacement) pairs +StateUpdate = Tuple[List[Union[str, int]], Any] + + +class StateManager: +    """ +    A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client. +    """ + +    def __init__(self, ws: WebSocket): +        self.ws = ws + +    def _send_update(self, updates: List[StateUpdate]): +        self.ws.send_json( +            [update.dict() for update in updates] +        ) diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py new file mode 100644 index 00000000..aadcfa8e --- /dev/null +++ b/continuedev/src/continuedev/steps/chat.py @@ -0,0 +1,19 @@ +from textwrap import dedent +from typing import List +from ..core.main import Step +from ..core.sdk import ContinueSDK +from .core.core import MessageStep + + +class SimpleChatStep(Step): +    user_input: str +    name: str = "Chat" + +    async def run(self, sdk: ContinueSDK): +        self.description = "" +        for chunk in sdk.models.gpt35.stream_chat(self.user_input, with_history=await sdk.get_chat_context()): +            self.description += chunk +            await sdk.update_ui() + +        self.name = sdk.models.gpt35.complete( +            f"Write a short title for the following chat message: {self.description}").strip() diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 392339c6..8dc2478b 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -3,8 +3,10 @@ import os  import subprocess  from textwrap import dedent  from typing import Coroutine, List, Union -from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder +from ...models.main import Range +from ...libs.util.calculate_diff import calculate_diff2, apply_edit_to_str +from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder  from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit  from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents  from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation @@ -85,7 +87,7 @@ class ShellCommandsStep(Step):                      {output}                      ``` -                    This is a brief summary of the error followed by a suggestion on how it can be fixed:"""), with_history=sdk.chat_context) +                    This is a brief summary of the error followed by a suggestion on how it can be fixed:"""), with_history=await sdk.get_chat_context())                  sdk.raise_exception(                      title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.") @@ -149,7 +151,11 @@ class Gpt35EditCodeStep(Step):      _prompt_and_completion: str = ""      async def describe(self, models: Models) -> Coroutine[str, None, None]: -        return models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") +        description = models.gpt35.complete( +            f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points. Be concise and only mention changes made to the commit before, not prefix or suffix:") +        self.name = models.gpt35.complete( +            f"Write a short title for this description: {description}") +        return description      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          rif_with_contents = [] @@ -174,11 +180,40 @@ class Gpt35EditCodeStep(Step):              self._prompt_and_completion += prompt + completion -            await sdk.ide.applyFileSystemEdit( -                FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion)) -            await sdk.ide.saveFile(rif.filepath) +            # Calculate diff, open file, apply edits, and highlight changed lines +            edits = calculate_diff2( +                rif.filepath, rif.contents, completion.removesuffix("\n")) +              await sdk.ide.setFileOpen(rif.filepath) +            lines_to_highlight = set() +            for edit in edits: +                edit.range.start.line += rif.range.start.line +                edit.range.start.character += rif.range.start.character +                edit.range.end.line += rif.range.start.line +                edit.range.end.character += rif.range.start.character if edit.range.end.line == 0 else 0 + +                for line in range(edit.range.start.line, edit.range.end.line + 1 + len(edit.replacement.splitlines()) - (edit.range.end.line - edit.range.start.line + 1)): +                    lines_to_highlight.add(line) + +                await sdk.ide.applyFileSystemEdit(edit) + +            current_start = None +            last_line = None +            for line in sorted(list(lines_to_highlight)): +                if current_start is None: +                    current_start = line +                elif line != last_line + 1: +                    await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0))) +                    current_start = line + +                last_line = line + +            if current_start is not None: +                await sdk.ide.highlightCode(RangeInFile(filepath=edit.filepath, range=Range.from_shorthand(current_start, 0, last_line, 0))) + +            await sdk.ide.saveFile(rif.filepath) +  class EditFileStep(Step):      filepath: str diff --git a/continuedev/src/continuedev/steps/draft/migration.py b/continuedev/src/continuedev/steps/draft/migration.py index f3b36b5e..7c4b7eb5 100644 --- a/continuedev/src/continuedev/steps/draft/migration.py +++ b/continuedev/src/continuedev/steps/draft/migration.py @@ -13,7 +13,7 @@ class MigrationStep(Step):          recent_edits = await sdk.ide.get_recent_edits(self.edited_file)          recent_edits_string = "\n\n".join(              map(lambda x: x.to_string(), recent_edits)) -        description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") +        description = sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")          await sdk.run([              "cd libs",              "poetry run alembic revision --autogenerate -m " + description, diff --git a/continuedev/src/continuedev/steps/find_and_replace.py b/continuedev/src/continuedev/steps/find_and_replace.py index fec33997..690872c0 100644 --- a/continuedev/src/continuedev/steps/find_and_replace.py +++ b/continuedev/src/continuedev/steps/find_and_replace.py @@ -10,7 +10,7 @@ class FindAndReplaceStep(Step):      replacement: str      async def describe(self, models: Models): -        return f"Replace all instances of `{self.pattern}` with `{self.replacement}` in `{self.filepath}`" +        return f"Replaced all instances of `{self.pattern}` with `{self.replacement}` in `{self.filepath}`"      async def run(self, sdk: ContinueSDK):          file_content = await sdk.ide.readFile(self.filepath) diff --git a/continuedev/src/continuedev/steps/input/nl_multiselect.py b/continuedev/src/continuedev/steps/input/nl_multiselect.py index c3c832f5..36c489c7 100644 --- a/continuedev/src/continuedev/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/steps/input/nl_multiselect.py @@ -23,5 +23,6 @@ class NLMultiselectStep(Step):          if first_try is not None:              return first_try -        gpt_parsed = await sdk.models.gpt35.complete(f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:") +        gpt_parsed = sdk.models.gpt35.complete( +            f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")          return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py index 24335b4f..36e4f519 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/steps/main.py @@ -16,6 +16,7 @@ from ..core.sdk import ContinueSDK, Models  from ..core.observation import Observation  import subprocess  from .core.core import Gpt35EditCodeStep +from ..libs.util.calculate_diff import calculate_diff2  class SetupContinueWorkspaceStep(Step): @@ -62,10 +63,10 @@ class RunPolicyUntilDoneStep(Step):      policy: "Policy"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        next_step = self.policy.next(sdk.history) +        next_step = self.policy.next(sdk.config, sdk.history)          while next_step is not None:              observation = await sdk.run_step(next_step) -            next_step = self.policy.next(sdk.history) +            next_step = self.policy.next(sdk.config, sdk.history)          return observation @@ -216,7 +217,8 @@ class StarCoderEditHighlightedCodeStep(Step):      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          range_in_files = await sdk.ide.getHighlightedCode() -        if len(range_in_files) == 0: +        found_highlighted_code = len(range_in_files) > 0 +        if not found_highlighted_code:              # Get the full contents of all open files              files = await sdk.ide.getOpenFiles()              contents = {} @@ -239,15 +241,29 @@ class StarCoderEditHighlightedCodeStep(Step):          for rif in rif_with_contents:              prompt = self._prompt.format(                  code=rif.contents, user_request=self.user_input) -            completion = str(sdk.models.starcoder.complete(prompt)) + +            if found_highlighted_code: +                full_file_contents = await sdk.ide.readFile(rif.filepath) +                segs = full_file_contents.split(rif.contents) +                prompt = f"<file_prefix>{segs[0]}<file_suffix>{segs[1]}" + prompt + +            completion = str((await sdk.models.starcoder()).complete(prompt))              eot_token = "<|endoftext|>" -            if completion.endswith(eot_token): -                completion = completion[:completion.rindex(eot_token)] +            completion = completion.removesuffix(eot_token) + +            if found_highlighted_code: +                rif.contents = segs[0] + rif.contents + segs[1] +                completion = segs[0] + completion + segs[1]              self._prompt_and_completion += prompt + completion -            await sdk.ide.applyFileSystemEdit( -                FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion)) +            edits = calculate_diff2( +                rif.filepath, rif.contents, completion.removesuffix("\n")) +            for edit in edits: +                await sdk.ide.applyFileSystemEdit(edit) + +            # await sdk.ide.applyFileSystemEdit( +            #     FileEdit(filepath=rif.filepath, range=rif.range, replacement=completion))              await sdk.ide.saveFile(rif.filepath)              await sdk.ide.setFileOpen(rif.filepath) diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/steps/react.py new file mode 100644 index 00000000..d825d424 --- /dev/null +++ b/continuedev/src/continuedev/steps/react.py @@ -0,0 +1,39 @@ +from textwrap import dedent +from typing import List, Union, Tuple +from ..core.main import Step +from ..core.sdk import ContinueSDK +from .core.core import MessageStep + + +class NLDecisionStep(Step): +    user_input: str +    default_step: Union[Step, None] = None +    steps: List[Tuple[Step, str]] + +    hide: bool = True + +    async def run(self, sdk: ContinueSDK): +        step_descriptions = "\n".join([ +            f"- {step[0].name}: {step[1]}" +            for step in self.steps +        ]) +        prompt = dedent(f"""\ +            The following steps are available, in the format "- [step name]: [step description]": +            {step_descriptions} +             +            The user gave the following input: +             +            {self.user_input} +             +            Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""") + +        resp = sdk.models.gpt35.complete(prompt).lower() + +        step_to_run = None +        for step in self.steps: +            if step[0].name.lower() in resp: +                step_to_run = step[0] + +        step_to_run = step_to_run or self.default_step or self.steps[0] + +        await sdk.run_step(step_to_run) diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py index eae8b558..365cbe1a 100644 --- a/continuedev/src/continuedev/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/steps/steps_on_startup.py @@ -1,19 +1,12 @@ -from ..core.main import ContinueSDK, Models, Step +from ..core.main import Step +from ..core.sdk import Models, ContinueSDK  from .main import UserInputStep  from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe  from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe  from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe  from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe  from ..recipes.AddTransformRecipe.main import AddTransformRecipe - -step_name_to_step_class = { -    "UserInputStep": UserInputStep, -    "CreatePipelineRecipe": CreatePipelineRecipe, -    "DDtoBQRecipe": DDtoBQRecipe, -    "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, -    "AddTransformRecipe": AddTransformRecipe, -    "DDtoBQRecipe": DDtoBQRecipe -} +from ..libs.util.step_name_to_steps import get_step_from_name  class StepsOnStartupStep(Step): @@ -23,13 +16,8 @@ class StepsOnStartupStep(Step):          return "Running steps on startup"      async def run(self, sdk: ContinueSDK): -        steps_descriptions = (await sdk.get_config()).steps_on_startup +        steps_on_startup = sdk.config.steps_on_startup -        for step_name, step_params in steps_descriptions.items(): -            try: -                step = step_name_to_step_class[step_name](**step_params) -            except: -                print( -                    f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}") -                continue +        for step_name, step_params in steps_on_startup.items(): +            step = get_step_from_name(step_name, step_params)              await sdk.run_step(step) | 
