diff options
Diffstat (limited to 'continuedev/src')
84 files changed, 2336 insertions, 1235 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index e048f877..98730d38 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -1,11 +1,10 @@ -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod  from typing import Coroutine, List, Union -from .config import ContinueConfig  from ..models.filesystem_edit import FileSystemEdit +from .config import ContinueConfig +from .main import ChatMessage, History, Step  from .observation import Observation -from .main import ChatMessage, History, Step, ChatMessageRole -  """  [[Generate]] diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py index 2692c348..60b86538 100644 --- a/continuedev/src/continuedev/core/env.py +++ b/continuedev/src/continuedev/core/env.py @@ -1,6 +1,7 @@ -from dotenv import load_dotenv  import os +from dotenv import load_dotenv +  def get_env_var(var_name: str):      load_dotenv() @@ -8,21 +9,21 @@ def get_env_var(var_name: str):  def make_sure_env_exists(): -    if not os.path.exists('.env'): -        with open('.env', 'w') as f: -            f.write('') +    if not os.path.exists(".env"): +        with open(".env", "w") as f: +            f.write("")  def save_env_var(var_name: str, var_value: str):      make_sure_env_exists() -    with open('.env', 'r') as f: +    with open(".env", "r") as f:          lines = f.readlines() -    with open('.env', 'w') as f: +    with open(".env", "w") as f:          values = {}          for line in lines: -            key, value = line.split('=') -            value = value.replace('"', '') +            key, value = line.split("=") +            value = value.replace('"', "")              values[key] = value          values[var_name] = var_value diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index a33d777e..53440dae 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,10 +1,10 @@  import json  from typing import Coroutine, Dict, List, Literal, Optional, Union -from pydantic.schema import schema +from pydantic import BaseModel, validator +from pydantic.schema import schema  from ..models.main import ContinueBaseModel -from pydantic import BaseModel, validator  from .observation import Observation  ChatMessageRole = Literal["assistant", "user", "system", "function"] @@ -27,8 +27,7 @@ class ChatMessage(ContinueBaseModel):          d = self.dict()          del d["summary"]          if d["function_call"] is not None: -            d["function_call"]["name"] = d["function_call"]["name"].replace( -                " ", "") +            d["function_call"]["name"] = d["function_call"]["name"].replace(" ", "")          if d["content"] is None:              d["content"] = "" @@ -49,9 +48,9 @@ class ChatMessage(ContinueBaseModel):  def resolve_refs(schema_data):      def traverse(obj):          if isinstance(obj, dict): -            if '$ref' in obj: -                ref = obj['$ref'] -                parts = ref.split('/') +            if "$ref" in obj: +                ref = obj["$ref"] +                parts = ref.split("/")                  ref_obj = schema_data                  for part in parts[1:]:                      ref_obj = ref_obj[part] @@ -67,8 +66,14 @@ def resolve_refs(schema_data):      return traverse(schema_data) -unincluded_parameters = ["system_message", "chat_context", -                         "manage_own_chat_context", "hide", "name", "description"] +unincluded_parameters = [ +    "system_message", +    "chat_context", +    "manage_own_chat_context", +    "hide", +    "name", +    "description", +]  def step_to_json_schema(step) -> str: @@ -82,7 +87,7 @@ def step_to_json_schema(step) -> str:      return {          "name": step.name.replace(" ", ""),          "description": step.description or "", -        "parameters": parameters +        "parameters": parameters,      } @@ -96,6 +101,7 @@ def step_to_fn_call_arguments(step: "Step") -> str:  class HistoryNode(ContinueBaseModel):      """A point in history, a list of which make up History""" +      step: "Step"      observation: Union[Observation, None]      depth: int @@ -111,12 +117,14 @@ class HistoryNode(ContinueBaseModel):                  role="assistant",                  name=self.step.__class__.__name__,                  content=self.step.description or f"Ran function {self.step.name}", -                summary=f"Called function {self.step.name}" -            )] +                summary=f"Called function {self.step.name}", +            ) +        ]  class History(ContinueBaseModel):      """A history of steps taken and their results""" +      timeline: List[HistoryNode]      current_index: int @@ -128,7 +136,7 @@ class History(ContinueBaseModel):          return msgs      def add_node(self, node: HistoryNode) -> int: -        """ Add node and return the index where it was added """ +        """Add node and return the index where it was added"""          self.timeline.insert(self.current_index + 1, node)          self.current_index += 1          return self.current_index @@ -138,10 +146,15 @@ class History(ContinueBaseModel):              return None          return self.timeline[self.current_index] -    def get_last_at_depth(self, depth: int, include_current: bool = False) -> Union[HistoryNode, None]: +    def get_last_at_depth( +        self, depth: int, include_current: bool = False +    ) -> Union[HistoryNode, None]:          i = self.current_index if include_current else self.current_index - 1          while i >= 0: -            if self.timeline[i].depth == depth and type(self.timeline[i].step).__name__ != "ManualEditStep": +            if ( +                self.timeline[i].depth == depth +                and type(self.timeline[i].step).__name__ != "ManualEditStep" +            ):                  return self.timeline[i]              i -= 1          return None @@ -204,24 +217,27 @@ class ContextItemId(BaseModel):      """      A ContextItemId is a unique identifier for a ContextItem.      """ +      provider_title: str      item_id: str -    @validator('provider_title', 'item_id') +    @validator("provider_title", "item_id")      def must_be_valid_id(cls, v):          import re -        if not re.match(r'^[0-9a-zA-Z_-]*$', v): + +        if not re.match(r"^[0-9a-zA-Z_-]*$", v):              raise ValueError( -                "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _") +                "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _" +            )          return v      def to_string(self) -> str:          return f"{self.provider_title}-{self.item_id}"      @staticmethod -    def from_string(string: str) -> 'ContextItemId': -        provider_title, *rest = string.split('-') -        item_id = '-'.join(rest) +    def from_string(string: str) -> "ContextItemId": +        provider_title, *rest = string.split("-") +        item_id = "-".join(rest)          return ContextItemId(provider_title=provider_title, item_id=item_id) @@ -231,22 +247,24 @@ class ContextItemDescription(BaseModel):      The id can be used to retrieve the ContextItem from the ContextManager.      """ +      name: str      description: str -    id:  ContextItemId +    id: ContextItemId  class ContextItem(BaseModel):      """      A ContextItem is a single item that is stored in the ContextManager.      """ +      description: ContextItemDescription      content: str -    @validator('content', pre=True) +    @validator("content", pre=True)      def content_must_be_string(cls, v):          if v is None: -            return '' +            return ""          return v      editing: bool = False @@ -261,6 +279,7 @@ class SessionInfo(ContinueBaseModel):  class FullState(ContinueBaseModel):      """A full state of the program, including the history""" +      history: History      active: bool      user_input_queue: List[str] @@ -286,7 +305,9 @@ class Policy(ContinueBaseModel):      """A rule that determines which step to take next"""      # Note that history is mutable, kinda sus -    def next(self, config: ContinueConfig, history: History = History.from_empty()) -> "Step": +    def next( +        self, config: ContinueConfig, history: History = History.from_empty() +    ) -> "Step":          raise NotImplementedError @@ -373,7 +394,12 @@ class ContinueCustomException(Exception):      message: str      with_step: Union[Step, None] -    def __init__(self, message: str, title: str = "Error while running step:", with_step: Union[Step, None] = None): +    def __init__( +        self, +        message: str, +        title: str = "Error while running step:", +        with_step: Union[Step, None] = None, +    ):          self.message = message          self.title = title          self.with_step = with_step diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index 900762b6..52a52b1d 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -1,10 +1,13 @@ -from typing import Optional, Any -from pydantic import BaseModel, validator +from typing import Any, Optional + +from pydantic import BaseModel +  from ..libs.llm import LLM  class Models(BaseModel):      """Main class that holds the current model configuration""" +      default: LLM      small: Optional[LLM] = None      medium: Optional[LLM] = None diff --git a/continuedev/src/continuedev/core/observation.py b/continuedev/src/continuedev/core/observation.py index 126cf19e..8a5e454e 100644 --- a/continuedev/src/continuedev/core/observation.py +++ b/continuedev/src/continuedev/core/observation.py @@ -1,4 +1,5 @@  from pydantic import BaseModel, validator +  from ..models.main import Traceback diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py index dba4874f..d77cce49 100644 --- a/continuedev/src/continuedev/libs/chroma/query.py +++ b/continuedev/src/continuedev/libs/chroma/query.py @@ -1,12 +1,19 @@  import json +import os  import subprocess +from functools import cached_property  from typing import List, Tuple -from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_storage, Document + +from llama_index import ( +    Document, +    GPTVectorStoreIndex, +    StorageContext, +    load_index_from_storage, +)  from llama_index.langchain_helpers.text_splitter import TokenTextSplitter -import os -from .update import filter_ignored_files, load_gpt_index_documents +  from ..util.logging import logger -from functools import cached_property +from .update import filter_ignored_files, load_gpt_index_documents  class ChromaIndexManager: @@ -18,23 +25,42 @@ class ChromaIndexManager:      @cached_property      def current_commit(self) -> str:          """Get the current commit.""" -        return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip() +        return ( +            subprocess.check_output( +                ["git", "rev-parse", "HEAD"], cwd=self.workspace_dir +            ) +            .decode("utf-8") +            .strip() +        )      @cached_property      def current_branch(self) -> str:          """Get the current branch.""" -        return subprocess.check_output( -            ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip() +        return ( +            subprocess.check_output( +                ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir +            ) +            .decode("utf-8") +            .strip() +        )      @cached_property      def index_dir(self) -> str: -        return os.path.join(self.workspace_dir, ".continue", "chroma", self.current_branch) +        return os.path.join( +            self.workspace_dir, ".continue", "chroma", self.current_branch +        )      @cached_property      def git_root_dir(self):          """Get the root directory of a Git repository."""          try: -            return subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], cwd=self.workspace_dir).strip().decode() +            return ( +                subprocess.check_output( +                    ["git", "rev-parse", "--show-toplevel"], cwd=self.workspace_dir +                ) +                .strip() +                .decode() +            )          except subprocess.CalledProcessError:              return None @@ -57,8 +83,7 @@ class ChromaIndexManager:              try:                  text_chunks = text_splitter.split_text(doc.text)              except: -                logger.warning( -                    f"ERROR (probably found special token): {doc.text}") +                logger.warning(f"ERROR (probably found special token): {doc.text}")                  continue  # lol              filename = doc.extra_info["filename"]              chunks[filename] = len(text_chunks) @@ -66,8 +91,7 @@ class ChromaIndexManager:                  doc_chunks.append(Document(text, doc_id=f"{filename}::{i}"))          with open(f"{self.index_dir}/metadata.json", "w") as f: -            json.dump({"commit": self.current_commit, -                       "chunks": chunks}, f, indent=4) +            json.dump({"commit": self.current_commit, "chunks": chunks}, f, indent=4)          index = GPTVectorStoreIndex([]) @@ -89,17 +113,30 @@ class ChromaIndexManager:          with open(metadata, "r") as f:              previous_commit = json.load(f)["commit"] -        modified_deleted_files = subprocess.check_output( -            ["git", "diff", "--name-only", previous_commit, self.current_commit]).decode("utf-8").strip() +        modified_deleted_files = ( +            subprocess.check_output( +                ["git", "diff", "--name-only", previous_commit, self.current_commit] +            ) +            .decode("utf-8") +            .strip() +        )          modified_deleted_files = modified_deleted_files.split("\n")          modified_deleted_files = [f for f in modified_deleted_files if f]          deleted_files = [ -            f for f in modified_deleted_files if not os.path.exists(os.path.join(self.workspace_dir, f))] +            f +            for f in modified_deleted_files +            if not os.path.exists(os.path.join(self.workspace_dir, f)) +        ]          modified_files = [ -            f for f in modified_deleted_files if os.path.exists(os.path.join(self.workspace_dir, f))] +            f +            for f in modified_deleted_files +            if os.path.exists(os.path.join(self.workspace_dir, f)) +        ] -        return filter_ignored_files(modified_files, self.index_dir), filter_ignored_files(deleted_files, self.index_dir) +        return filter_ignored_files( +            modified_files, self.index_dir +        ), filter_ignored_files(deleted_files, self.index_dir)      def update_codebase_index(self):          """Update the index with a list of files.""" @@ -108,15 +145,13 @@ class ChromaIndexManager:              self.create_codebase_index()          else:              # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") -            index = GPTVectorStoreIndex.load_from_disk( -                f"{self.index_dir}/index.json") +            index = GPTVectorStoreIndex.load_from_disk(f"{self.index_dir}/index.json")              modified_files, deleted_files = self.get_modified_deleted_files()              with open(f"{self.index_dir}/metadata.json", "r") as f:                  metadata = json.load(f)              for file in deleted_files: -                  num_chunks = metadata["chunks"][file]                  for i in range(num_chunks):                      index.delete(f"{file}::{i}") @@ -126,9 +161,7 @@ class ChromaIndexManager:                  logger.debug(f"Deleted {file}")              for file in modified_files: -                  if file in metadata["chunks"]: -                      num_chunks = metadata["chunks"][file]                      for i in range(num_chunks): @@ -159,12 +192,10 @@ class ChromaIndexManager:      def query_codebase_index(self, query: str) -> str:          """Query the codebase index."""          if not self.check_index_exists(): -            logger.debug( -                f"No index found for the codebase at {self.index_dir}") +            logger.debug(f"No index found for the codebase at {self.index_dir}")              return "" -        storage_context = StorageContext.from_defaults( -            persist_dir=self.index_dir) +        storage_context = StorageContext.from_defaults(persist_dir=self.index_dir)          index = load_index_from_storage(storage_context)          # index = GPTVectorStoreIndex.load_from_disk(path)          engine = index.as_query_engine() @@ -173,14 +204,15 @@ class ChromaIndexManager:      def query_additional_index(self, query: str) -> str:          """Query the additional index."""          index = GPTVectorStoreIndex.load_from_disk( -            os.path.join(self.index_dir, 'additional_index.json')) +            os.path.join(self.index_dir, "additional_index.json") +        )          return index.query(query)      def replace_additional_index(self, info: str):          """Replace the additional index with the given info.""" -        with open(f'{self.index_dir}/additional_context.txt', 'w') as f: +        with open(f"{self.index_dir}/additional_context.txt", "w") as f:              f.write(info)          documents = [Document(info)]          index = GPTVectorStoreIndex(documents) -        index.save_to_disk(f'{self.index_dir}/additional_index.json') +        index.save_to_disk(f"{self.index_dir}/additional_index.json")          logger.debug("Additional index replaced") diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py index d5326a06..7a1217f9 100644 --- a/continuedev/src/continuedev/libs/chroma/update.py +++ b/continuedev/src/continuedev/libs/chroma/update.py @@ -1,28 +1,24 @@  # import faiss  import os  import subprocess - -from llama_index import SimpleDirectoryReader, Document  from typing import List +  from dotenv import load_dotenv +from llama_index import Document, SimpleDirectoryReader  load_dotenv() -FILE_TYPES_TO_IGNORE = [ -    '.pyc', -    '.png', -    '.jpg', -    '.jpeg', -    '.gif', -    '.svg', -    '.ico' -] +FILE_TYPES_TO_IGNORE = [".pyc", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico"]  def filter_ignored_files(files: List[str], root_dir: str):      """Further filter files before indexing."""      for file in files: -        if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'): +        if ( +            file.endswith(tuple(FILE_TYPES_TO_IGNORE)) +            or file.startswith(".git") +            or file.startswith("archive") +        ):              continue  # nice          yield root_dir + "/" + file @@ -30,9 +26,15 @@ def filter_ignored_files(files: List[str], root_dir: str):  def get_git_ignored_files(root_dir: str):      """Get the list of ignored files in a Git repository."""      try: -        output = subprocess.check_output( -            ['git', 'ls-files', '--ignored', '--others', '--exclude-standard'], cwd=root_dir).strip().decode() -        return output.split('\n') +        output = ( +            subprocess.check_output( +                ["git", "ls-files", "--ignored", "--others", "--exclude-standard"], +                cwd=root_dir, +            ) +            .strip() +            .decode() +        ) +        return output.split("\n")      except subprocess.CalledProcessError:          return [] @@ -57,4 +59,8 @@ def load_gpt_index_documents(root: str) -> List[Document]:      # Get input files      input_files = get_input_files(root)      # Use SimpleDirectoryReader to load the files into Documents -    return SimpleDirectoryReader(root, input_files=input_files, file_metadata=lambda filename: {"filename": filename}).load_data() +    return SimpleDirectoryReader( +        root, +        input_files=input_files, +        file_metadata=lambda filename: {"filename": filename}, +    ).load_data() diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 70c67856..4af6b8e2 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,5 +1,5 @@  from abc import ABC, abstractproperty -from typing import Any, Coroutine, Dict, Generator, List, Union, Optional +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union  from ...core.main import ChatMessage  from ...models.main import ContinueBaseModel @@ -28,15 +28,21 @@ class LLM(ContinueBaseModel, ABC):          """Stop the connection to the LLM."""          raise NotImplementedError -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          """Return the completion of the text with the given temperature."""          raise NotImplementedError -    def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          """Stream the completion through generator."""          raise NotImplementedError -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          """Stream the chat through generator."""          raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 9d7bc93f..9a7d0ac9 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,10 +1,15 @@ -from functools import cached_property -import time  from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union + +from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic +  from ...core.main import ChatMessage -from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages +from ..util.count_tokens import ( +    DEFAULT_ARGS, +    compile_chat_messages, +    count_tokens, +    format_chat_messages, +)  class AnthropicLLM(LLM): @@ -19,7 +24,13 @@ class AnthropicLLM(LLM):      write_log: Optional[Callable[[str], None]] = None -    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs): +    async def start( +        self, +        *, +        api_key: Optional[str] = None, +        write_log: Callable[[str], None], +        **kwargs, +    ):          self.write_log = write_log          self._async_client = AsyncAnthropic(api_key=self.api_key) @@ -58,7 +69,11 @@ class AnthropicLLM(LLM):          prompt = ""          # Anthropic prompt must start with a Human turn -        if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system": +        if ( +            len(messages) > 0 +            and messages[0]["role"] != "user" +            and messages[0]["role"] != "system" +        ):              prompt += f"{HUMAN_PROMPT} Hello."          for msg in messages:              prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} " @@ -66,7 +81,9 @@ class AnthropicLLM(LLM):          prompt += AI_PROMPT          return prompt -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy()          args.update(kwargs)          args["stream"] = True @@ -76,50 +93,62 @@ class AnthropicLLM(LLM):          self.write_log(f"Prompt: \n\n{prompt}")          completion = ""          async for chunk in await self._async_client.completions.create( -            prompt=prompt, -            **args +            prompt=prompt, **args          ):              yield chunk.completion              completion += chunk.completion          self.write_log(f"Completion: \n\n{completion}") -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy()          args.update(kwargs)          args["stream"] = True          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], messages, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) +            args["model"], +            messages, +            self.context_length, +            args["max_tokens_to_sample"], +            functions=args.get("functions", None), +            system_message=self.system_message, +        )          completion = ""          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          async for chunk in await self._async_client.completions.create( -            prompt=self.__messages_to_prompt(messages), -            **args +            prompt=self.__messages_to_prompt(messages), **args          ): -            yield { -                "role": "assistant", -                "content": chunk.completion -            } +            yield {"role": "assistant", "content": chunk.completion}              completion += chunk.completion          self.write_log(f"Completion: \n\n{completion}") -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs}          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) +            args["model"], +            with_history, +            self.context_length, +            args["max_tokens_to_sample"], +            prompt, +            functions=None, +            system_message=self.system_message, +        ) -        completion = ""          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        resp = (await self._async_client.completions.create( -            prompt=self.__messages_to_prompt(messages), -            **args -        )).completion +        resp = ( +            await self._async_client.completions.create( +                prompt=self.__messages_to_prompt(messages), **args +            ) +        ).completion          self.write_log(f"Completion: \n\n{resp}")          return resp diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 25a61e63..2d60384b 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,12 +1,11 @@ -from functools import cached_property  import json  from typing import Any, Coroutine, Dict, Generator, List, Union -from pydantic import ConfigDict  import aiohttp +  from ...core.main import ChatMessage  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens +from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens  class GGML(LLM): @@ -22,7 +21,8 @@ class GGML(LLM):      async def start(self, **kwargs):          self._client_session = aiohttp.ClientSession( -            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)) +            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) +        )      async def stop(self):          await self._client_session.close() @@ -42,19 +42,27 @@ class GGML(LLM):      def count_tokens(self, text: str):          return count_tokens(self.name, text) -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy()          args.update(kwargs)          args["stream"] = True          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) - -        async with self._client_session.post(f"{self.server_url}/v1/completions", json={ -            "messages": messages, -            **args -        }) as resp: +            self.name, +            with_history, +            self.context_length, +            args["max_tokens"], +            prompt, +            functions=args.get("functions", None), +            system_message=self.system_message, +        ) + +        async with self._client_session.post( +            f"{self.server_url}/v1/completions", json={"messages": messages, **args} +        ) as resp:              async for line in resp.content.iter_any():                  if line:                      try: @@ -62,40 +70,66 @@ class GGML(LLM):                      except:                          raise Exception(str(line)) -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            self.name, +            messages, +            self.context_length, +            args["max_tokens"], +            None, +            functions=args.get("functions", None), +            system_message=self.system_message, +        )          args["stream"] = True -        async with self._client_session.post(f"{self.server_url}/v1/chat/completions", json={ -            "messages": messages, -            **args -        }) as resp: +        async with self._client_session.post( +            f"{self.server_url}/v1/chat/completions", +            json={"messages": messages, **args}, +        ) as resp:              # This is streaming application/json instaed of text/event-stream              async for line in resp.content.iter_chunks():                  if line[1]:                      try:                          json_chunk = line[0].decode("utf-8") -                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): +                        if json_chunk.startswith(": ping - ") or json_chunk.startswith( +                            "data: [DONE]" +                        ):                              continue                          chunks = json_chunk.split("\n")                          for chunk in chunks:                              if chunk.strip() != "":                                  yield {                                      "role": "assistant", -                                    "content": json.loads(chunk[6:])["choices"][0]["delta"] +                                    "content": json.loads(chunk[6:])["choices"][0][ +                                        "delta" +                                    ],                                  }                      except:                          raise Exception(str(line[0])) -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} -        async with self._client_session.post(f"{self.server_url}/v1/completions", json={ -            "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), -            **args -        }) as resp: +        async with self._client_session.post( +            f"{self.server_url}/v1/completions", +            json={ +                "messages": compile_chat_messages( +                    args["model"], +                    with_history, +                    self.context_length, +                    args["max_tokens"], +                    prompt, +                    functions=None, +                    system_message=self.system_message, +                ), +                **args, +            }, +        ) as resp:              try:                  return await resp.text()              except: diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 8945250c..76331a28 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,13 +1,13 @@ -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List +  import aiohttp  import requests  from ...core.main import ChatMessage -from ..util.count_tokens import DEFAULT_ARGS, count_tokens -from ...core.main import ChatMessage  from ..llm import LLM +from ..util.count_tokens import DEFAULT_ARGS, count_tokens -DEFAULT_MAX_TIME = 120. +DEFAULT_MAX_TIME = 120.0  class HuggingFaceInferenceAPI(LLM): @@ -24,7 +24,8 @@ class HuggingFaceInferenceAPI(LLM):      async def start(self, **kwargs):          self._client_session = aiohttp.ClientSession( -            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)) +            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) +        )      async def stop(self):          await self._client_session.close() @@ -44,35 +45,43 @@ class HuggingFaceInferenceAPI(LLM):      def count_tokens(self, text: str):          return count_tokens(self.name, text) -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ):          """Return the completion of the text with the given temperature."""          API_URL = f"https://api-inference.huggingface.co/models/{self.model}" -        headers = { -            "Authorization": f"Bearer {self.hf_token}"} - -        response = requests.post(API_URL, headers=headers, json={ -            "inputs": prompt, "parameters": { -                "max_new_tokens": min(250, self.max_context_length - self.count_tokens(prompt)), -                "max_time": DEFAULT_MAX_TIME, -                "return_full_text": False, -            } -        }) +        headers = {"Authorization": f"Bearer {self.hf_token}"} + +        response = requests.post( +            API_URL, +            headers=headers, +            json={ +                "inputs": prompt, +                "parameters": { +                    "max_new_tokens": min( +                        250, self.max_context_length - self.count_tokens(prompt) +                    ), +                    "max_time": DEFAULT_MAX_TIME, +                    "return_full_text": False, +                }, +            }, +        )          data = response.json()          # Error if the response is not a list          if not isinstance(data, list): -            raise Exception( -                "Hugging Face returned an error response: \n\n", data) +            raise Exception("Hugging Face returned an error response: \n\n", data)          return data[0]["generated_text"] -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:          response = await self.complete(messages[-1].content, messages[:-1]) -        yield { -            "content": response, -            "role": "assistant" -        } +        yield {"content": response, "role": "assistant"} -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Any | List | Dict, None, None]: +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Any | List | Dict, None, None]:          response = await self.complete(prompt, with_history)          yield response diff --git a/continuedev/src/continuedev/libs/llm/hugging_face.py b/continuedev/src/continuedev/libs/llm/hugging_face.py index b0db585b..f246a43c 100644 --- a/continuedev/src/continuedev/libs/llm/hugging_face.py +++ b/continuedev/src/continuedev/libs/llm/hugging_face.py @@ -1,5 +1,6 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +  from .llm import LLM -from transformers import AutoTokenizer, AutoModelForCausalLM  class HuggingFace(LLM): @@ -12,6 +13,5 @@ class HuggingFace(LLM):          args = {"max_tokens": 100}          args.update(kwargs)          input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids -        generated_ids = self.model.generate( -            input_ids, max_length=args["max_tokens"]) +        generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"])          return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index fbc2c43f..65e5db3a 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -1,9 +1,9 @@ -from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable +from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union  from ...core.main import ChatMessage  from . import LLM -from .proxy_server import ProxyServer  from .openai import OpenAI +from .proxy_server import ProxyServer  class MaybeProxyOpenAI(LLM): @@ -24,7 +24,13 @@ class MaybeProxyOpenAI(LLM):      def context_length(self):          return self.llm.context_length -    async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]): +    async def start( +        self, +        *, +        api_key: Optional[str] = None, +        unique_id: str, +        write_log: Callable[[str], None] +    ):          if self.api_key is None or self.api_key.strip() == "":              self.llm = ProxyServer(model=self.model)          else: @@ -35,16 +41,21 @@ class MaybeProxyOpenAI(LLM):      async def stop(self):          await self.llm.stop() -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          return await self.llm.complete(prompt, with_history=with_history, **kwargs) -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: -        resp = self.llm.stream_complete( -            prompt, with_history=with_history, **kwargs) +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]: +        resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs)          async for item in resp:              yield item -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          resp = self.llm.stream_chat(messages=messages, **kwargs)          async for item in resp:              yield item diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py index a9f9f7aa..ef3cdc66 100644 --- a/continuedev/src/continuedev/libs/llm/ollama.py +++ b/continuedev/src/continuedev/libs/llm/ollama.py @@ -1,11 +1,11 @@ -from functools import cached_property  import json  from typing import Any, Coroutine, Dict, Generator, List, Union  import aiohttp +  from ...core.main import ChatMessage  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens +from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens  class Ollama(LLM): @@ -62,20 +62,32 @@ class Ollama(LLM):              if msgs[i]["role"] == "user":                  prompt += f"[INST] {msgs[i]['content']} [/INST]"              else: -                prompt += msgs[i]['content'] +                prompt += msgs[i]["content"]          return prompt -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +            self.name, +            with_history, +            self.context_length, +            args["max_tokens"], +            prompt, +            functions=None, +            system_message=self.system_message, +        )          prompt = self.convert_to_chat(messages) -        async with self._client_session.post(f"{self.server_url}/api/generate", json={ -            "prompt": prompt, -            "model": self.model, -        }) as resp: +        async with self._client_session.post( +            f"{self.server_url}/api/generate", +            json={ +                "prompt": prompt, +                "model": self.model, +            }, +        ) as resp:              async for line in resp.content.iter_any():                  if line:                      try: @@ -89,16 +101,28 @@ class Ollama(LLM):                      except:                          raise Exception(str(line[0])) -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, messages, self.context_length, args["max_tokens"], None, functions=None, system_message=self.system_message) +            self.name, +            messages, +            self.context_length, +            args["max_tokens"], +            None, +            functions=None, +            system_message=self.system_message, +        )          prompt = self.convert_to_chat(messages) -        async with self._client_session.post(f"{self.server_url}/api/generate", json={ -            "prompt": prompt, -            "model": self.model, -        }) as resp: +        async with self._client_session.post( +            f"{self.server_url}/api/generate", +            json={ +                "prompt": prompt, +                "model": self.model, +            }, +        ) as resp:              # This is streaming application/json instaed of text/event-stream              async for line in resp.content.iter_chunks():                  if line[1]: @@ -111,18 +135,23 @@ class Ollama(LLM):                                  if "response" in j:                                      yield {                                          "role": "assistant", -                                        "content": j["response"] +                                        "content": j["response"],                                      }                      except:                          raise Exception(str(line[0])) -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          completion = "" -        async with self._client_session.post(f"{self.server_url}/api/generate", json={ -            "prompt": prompt, -            "model": self.model, -        }) as resp: +        async with self._client_session.post( +            f"{self.server_url}/api/generate", +            json={ +                "prompt": prompt, +                "model": self.model, +            }, +        ) as resp:              async for line in resp.content.iter_any():                  if line:                      try: diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 7eb516a3..c2d86841 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,12 +1,28 @@ -from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional +from typing import ( +    Any, +    Callable, +    Coroutine, +    Dict, +    Generator, +    List, +    Literal, +    Optional, +    Union, +) -from pydantic import BaseModel +import certifi  import openai +from pydantic import BaseModel  from ...core.main import ChatMessage -from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top  from ..llm import LLM -import certifi +from ..util.count_tokens import ( +    DEFAULT_ARGS, +    compile_chat_messages, +    count_tokens, +    format_chat_messages, +    prune_raw_prompt_from_top, +)  class OpenAIServerInfo(BaseModel): @@ -16,9 +32,7 @@ class OpenAIServerInfo(BaseModel):      api_type: Literal["azure", "openai"] = "openai" -CHAT_MODELS = { -    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" -} +CHAT_MODELS = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"}  MAX_TOKENS_FOR_MODEL = {      "gpt-3.5-turbo": 4096,      "gpt-3.5-turbo-0613": 4096, @@ -27,7 +41,7 @@ MAX_TOKENS_FOR_MODEL = {      "gpt-35-turbo-16k": 16_384,      "gpt-35-turbo-0613": 4096,      "gpt-35-turbo": 4096, -    "gpt-4-32k": 32_768 +    "gpt-4-32k": 32_768,  } @@ -43,7 +57,13 @@ class OpenAI(LLM):      system_message: Optional[str] = None      write_log: Optional[Callable[[str], None]] = None -    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs): +    async def start( +        self, +        *, +        api_key: Optional[str] = None, +        write_log: Callable[[str], None], +        **kwargs, +    ):          self.write_log = write_log          openai.api_key = self.api_key @@ -54,7 +74,7 @@ class OpenAI(LLM):              if self.openai_server_info.api_version is not None:                  openai.api_version = self.openai_server_info.api_version -        if self.verify_ssl == False: +        if self.verify_ssl is False:              openai.verify_ssl_certs = False          openai.ca_bundle_path = self.ca_bundle_path or certifi.where() @@ -80,14 +100,23 @@ class OpenAI(LLM):      def count_tokens(self, text: str):          return count_tokens(self.model, text) -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy()          args.update(kwargs)          args["stream"] = True          if args["model"] in CHAT_MODELS:              messages = compile_chat_messages( -                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +                args["model"], +                with_history, +                self.context_length, +                args["max_tokens"], +                prompt, +                functions=None, +                system_message=self.system_message, +            )              self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")              completion = ""              async for chunk in await openai.ChatCompletion.acreate( @@ -110,17 +139,28 @@ class OpenAI(LLM):              self.write_log(f"Completion:\n\n{completion}") -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy()          args.update(kwargs)          args["stream"] = True          # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it? -        args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613" +        args["model"] = ( +            self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613" +        )          if not args["model"].endswith("0613") and "functions" in args:              del args["functions"]          messages = compile_chat_messages( -            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            args["model"], +            messages, +            self.context_length, +            args["max_tokens"], +            None, +            functions=args.get("functions", None), +            system_message=self.system_message, +        )          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          completion = ""          async for chunk in await openai.ChatCompletion.acreate( @@ -132,26 +172,48 @@ class OpenAI(LLM):                  completion += chunk.choices[0].delta.content          self.write_log(f"Completion: \n\n{completion}") -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs}          if args["model"] in CHAT_MODELS:              messages = compile_chat_messages( -                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +                args["model"], +                with_history, +                self.context_length, +                args["max_tokens"], +                prompt, +                functions=None, +                system_message=self.system_message, +            )              self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -            resp = (await openai.ChatCompletion.acreate( -                messages=messages, -                **args, -            )).choices[0].message.content +            resp = ( +                ( +                    await openai.ChatCompletion.acreate( +                        messages=messages, +                        **args, +                    ) +                ) +                .choices[0] +                .message.content +            )              self.write_log(f"Completion: \n\n{resp}")          else:              prompt = prune_raw_prompt_from_top( -                args["model"], self.context_length, prompt, args["max_tokens"]) +                args["model"], self.context_length, prompt, args["max_tokens"] +            )              self.write_log(f"Prompt:\n\n{prompt}") -            resp = (await openai.Completion.acreate( -                prompt=prompt, -                **args, -            )).choices[0].text +            resp = ( +                ( +                    await openai.Completion.acreate( +                        prompt=prompt, +                        **args, +                    ) +                ) +                .choices[0] +                .text +            )              self.write_log(f"Completion:\n\n{resp}")          return resp diff --git a/continuedev/src/continuedev/libs/llm/prompt_utils.py b/continuedev/src/continuedev/libs/llm/prompt_utils.py index 51b64732..930b5220 100644 --- a/continuedev/src/continuedev/libs/llm/prompt_utils.py +++ b/continuedev/src/continuedev/libs/llm/prompt_utils.py @@ -1,4 +1,5 @@  from typing import Dict, List, Union +  from ...models.filesystem import RangeInFileWithContents  from ...models.filesystem_edit import FileEdit @@ -11,23 +12,28 @@ class MarkdownStyleEncoderDecoder:          self.range_in_files = range_in_files      def encode(self) -> str: -        return "\n\n".join([ -            f"File ({rif.filepath})\n```\n{rif.contents}\n```" -            for rif in self.range_in_files -        ]) +        return "\n\n".join( +            [ +                f"File ({rif.filepath})\n```\n{rif.contents}\n```" +                for rif in self.range_in_files +            ] +        )      def _suggestions_to_file_edits(self, suggestions: Dict[str, str]) -> List[FileEdit]:          file_edits: List[FileEdit] = []          for suggestion_filepath, suggestion in suggestions.items():              matching_rifs = list( -                filter(lambda r: r.filepath == suggestion_filepath, self.range_in_files)) -            if (len(matching_rifs) > 0): +                filter(lambda r: r.filepath == suggestion_filepath, self.range_in_files) +            ) +            if len(matching_rifs) > 0:                  range_in_file = matching_rifs[0] -                file_edits.append(FileEdit( -                    range=range_in_file.range, -                    filepath=range_in_file.filepath, -                    replacement=suggestion -                )) +                file_edits.append( +                    FileEdit( +                        range=range_in_file.range, +                        filepath=range_in_file.filepath, +                        replacement=suggestion, +                    ) +                )          return file_edits @@ -35,9 +41,9 @@ class MarkdownStyleEncoderDecoder:          if len(self.range_in_files) == 0:              return {} -        if not '```' in completion: +        if "```" not in completion:              completion = "```\n" + completion + "\n```" -        if completion.strip().splitlines()[0].strip() == '```': +        if completion.strip().splitlines()[0].strip() == "```":              first_filepath = self.range_in_files[0].filepath              completion = f"File ({first_filepath})\n" + completion @@ -56,8 +62,7 @@ class MarkdownStyleEncoderDecoder:              elif inside_file:                  if line.startswith("```"):                      inside_file = False -                    suggestions[current_filepath] = "\n".join( -                        current_file_lines) +                    suggestions[current_filepath] = "\n".join(current_file_lines)                      current_file_lines = []                      current_filepath = None                  else: diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index fd8210e9..acc6653d 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,13 +1,29 @@  import json +import ssl  import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional +from typing import ( +    Any, +    Callable, +    Coroutine, +    Dict, +    Generator, +    List, +    Optional, +    Union, +) +  import aiohttp +import certifi +  from ...core.main import ChatMessage  from ..llm import LLM +from ..util.count_tokens import ( +    DEFAULT_ARGS, +    compile_chat_messages, +    count_tokens, +    format_chat_messages, +)  from ..util.telemetry import posthog_logger -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens, format_chat_messages -import certifi -import ssl  ca_bundle_path = certifi.where()  ssl_context = ssl.create_default_context(cafile=ca_bundle_path) @@ -37,9 +53,17 @@ class ProxyServer(LLM):      class Config:          arbitrary_types_allowed = True -    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs): +    async def start( +        self, +        *, +        api_key: Optional[str] = None, +        write_log: Callable[[str], None], +        unique_id: str, +        **kwargs, +    ):          self._client_session = aiohttp.ClientSession( -            connector=aiohttp.TCPConnector(ssl_context=ssl_context)) +            connector=aiohttp.TCPConnector(ssl_context=ssl_context) +        )          self.write_log = write_log          self.unique_id = unique_id @@ -65,16 +89,26 @@ class ProxyServer(LLM):          # headers with unique id          return {"unique_id": self.unique_id} -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) +            args["model"], +            with_history, +            self.context_length, +            args["max_tokens"], +            prompt, +            functions=None, +            system_message=self.system_message, +        )          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with self._client_session.post(f"{SERVER_URL}/complete", json={ -            "messages": messages, -            **args -        }, headers=self.get_headers()) as resp: +        async with self._client_session.post( +            f"{SERVER_URL}/complete", +            json={"messages": messages, **args}, +            headers=self.get_headers(), +        ) as resp:              if resp.status != 200:                  raise Exception(await resp.text()) @@ -82,16 +116,26 @@ class ProxyServer(LLM):              self.write_log(f"Completion: \n\n{response_text}")              return response_text -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            args["model"], +            messages, +            self.context_length, +            args["max_tokens"], +            None, +            functions=args.get("functions", None), +            system_message=self.system_message, +        )          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={ -            "messages": messages, -            **args -        }, headers=self.get_headers()) as resp: +        async with self._client_session.post( +            f"{SERVER_URL}/stream_chat", +            json={"messages": messages, **args}, +            headers=self.get_headers(), +        ) as resp:              # This is streaming application/json instaed of text/event-stream              completion = ""              if resp.status != 200: @@ -109,23 +153,40 @@ class ProxyServer(LLM):                                  if "content" in loaded_chunk:                                      completion += loaded_chunk["content"]                      except Exception as e: -                        posthog_logger.capture_event("proxy_server_parse_error", { -                            "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) +                        posthog_logger.capture_event( +                            "proxy_server_parse_error", +                            { +                                "error_title": "Proxy server stream_chat parsing failed", +                                "error_message": "\n".join( +                                    traceback.format_exception(e) +                                ), +                            }, +                        )                  else:                      break              self.write_log(f"Completion: \n\n{completion}") -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) +            self.model, +            with_history, +            self.context_length, +            args["max_tokens"], +            prompt, +            functions=args.get("functions", None), +            system_message=self.system_message, +        )          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={ -            "messages": messages, -            **args -        }, headers=self.get_headers()) as resp: +        async with self._client_session.post( +            f"{SERVER_URL}/stream_complete", +            json={"messages": messages, **args}, +            headers=self.get_headers(), +        ) as resp:              completion = ""              if resp.status != 200:                  raise Exception(await resp.text()) diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py index 0dd359e7..c4373185 100644 --- a/continuedev/src/continuedev/libs/llm/replicate.py +++ b/continuedev/src/continuedev/libs/llm/replicate.py @@ -1,10 +1,10 @@ -from abc import abstractproperty -from typing import List, Optional -import replicate  import concurrent.futures +from typing import List + +import replicate -from ..util.count_tokens import DEFAULT_ARGS, count_tokens  from ...core.main import ChatMessage +from ..util.count_tokens import DEFAULT_ARGS, count_tokens  from . import LLM @@ -36,10 +36,12 @@ class ReplicateLLM(LLM):      async def stop(self):          pass -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ):          def helper():              output = self._client.run(self.model, input={"message": prompt}) -            completion = '' +            completion = ""              for item in output:                  completion += item @@ -51,13 +53,14 @@ class ReplicateLLM(LLM):          return completion -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs): +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ):          for item in self._client.run(self.model, input={"message": prompt}):              yield item      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs): -        for item in self._client.run(self.model, input={"message": messages[-1].content}): -            yield { -                "content": item, -                "role": "assistant" -            } +        for item in self._client.run( +            self.model, input={"message": messages[-1].content} +        ): +            yield {"content": item, "role": "assistant"} diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index 874dea07..44f5030c 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -2,9 +2,10 @@ import json  from typing import Any, Coroutine, Dict, Generator, List, Union  import aiohttp +  from ...core.main import ChatMessage  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens +from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens  class TogetherLLM(LLM): @@ -19,7 +20,8 @@ class TogetherLLM(LLM):      async def start(self, **kwargs):          self._client_session = aiohttp.ClientSession( -            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)) +            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) +        )      async def stop(self):          await self._client_session.close() @@ -53,21 +55,29 @@ class TogetherLLM(LLM):          prompt += "<bot>:"          return prompt -    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_complete( +        self, prompt, with_history: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy()          args.update(kwargs)          args["stream_tokens"] = True          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) - -        async with self._client_session.post(f"{self.base_url}/inference", json={ -            "prompt": self.convert_to_prompt(messages), -            **args -        }, headers={ -            "Authorization": f"Bearer {self.api_key}" -        }) as resp: +            self.name, +            with_history, +            self.context_length, +            args["max_tokens"], +            prompt, +            functions=args.get("functions", None), +            system_message=self.system_message, +        ) + +        async with self._client_session.post( +            f"{self.base_url}/inference", +            json={"prompt": self.convert_to_prompt(messages), **args}, +            headers={"Authorization": f"Bearer {self.api_key}"}, +        ) as resp:              async for line in resp.content.iter_any():                  if line:                      try: @@ -75,22 +85,32 @@ class TogetherLLM(LLM):                      except:                          raise Exception(str(line)) -    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat( +        self, messages: List[ChatMessage] = None, **kwargs +    ) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) +            self.name, +            messages, +            self.context_length, +            args["max_tokens"], +            None, +            functions=args.get("functions", None), +            system_message=self.system_message, +        )          args["stream_tokens"] = True -        async with self._client_session.post(f"{self.base_url}/inference", json={ -            "prompt": self.convert_to_prompt(messages), -            **args -        }, headers={ -            "Authorization": f"Bearer {self.api_key}" -        }) as resp: +        async with self._client_session.post( +            f"{self.base_url}/inference", +            json={"prompt": self.convert_to_prompt(messages), **args}, +            headers={"Authorization": f"Bearer {self.api_key}"}, +        ) as resp:              async for line in resp.content.iter_chunks():                  if line[1]:                      json_chunk = line[0].decode("utf-8") -                    if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): +                    if json_chunk.startswith(": ping - ") or json_chunk.startswith( +                        "data: [DONE]" +                    ):                          continue                      if json_chunk.startswith("data: "):                          json_chunk = json_chunk[6:] @@ -101,20 +121,28 @@ class TogetherLLM(LLM):                              if "choices" in json_chunk:                                  yield {                                      "role": "assistant", -                                    "content": json_chunk["choices"][0]["text"] +                                    "content": json_chunk["choices"][0]["text"],                                  } -    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: +    async def complete( +        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs +    ) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} -        messages = compile_chat_messages(args["model"], with_history, self.context_length, -                                         args["max_tokens"], prompt, functions=None, system_message=self.system_message) -        async with self._client_session.post(f"{self.base_url}/inference", json={ -            "prompt": self.convert_to_prompt(messages), -            **args -        }, headers={ -            "Authorization": f"Bearer {self.api_key}" -        }) as resp: +        messages = compile_chat_messages( +            args["model"], +            with_history, +            self.context_length, +            args["max_tokens"], +            prompt, +            functions=None, +            system_message=self.system_message, +        ) +        async with self._client_session.post( +            f"{self.base_url}/inference", +            json={"prompt": self.convert_to_prompt(messages), **args}, +            headers={"Authorization": f"Bearer {self.api_key}"}, +        ) as resp:              try:                  text = await resp.text()                  j = json.loads(text) diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py index e8e48839..99301ae7 100644 --- a/continuedev/src/continuedev/libs/util/calculate_diff.py +++ b/continuedev/src/continuedev/libs/util/calculate_diff.py @@ -1,7 +1,8 @@  import difflib  from typing import List -from ...models.main import Position, Range +  from ...models.filesystem import FileEdit +from ...models.main import Position, Range  def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]: @@ -14,16 +15,25 @@ def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]          if tag == "equal":              pass          elif tag == "delete": -            edits.append(FileEdit.from_deletion( -                filepath, Range.from_indices(original, i1, i2))) +            edits.append( +                FileEdit.from_deletion(filepath, Range.from_indices(original, i1, i2)) +            )              offset -= i2 - i1          elif tag == "insert": -            edits.append(FileEdit.from_insertion( -                filepath, Position.from_index(original, i1), replacement)) +            edits.append( +                FileEdit.from_insertion( +                    filepath, Position.from_index(original, i1), replacement +                ) +            )              offset += j2 - j1          elif tag == "replace": -            edits.append(FileEdit(filepath=filepath, range=Range.from_indices( -                original, i1, i2), replacement=replacement)) +            edits.append( +                FileEdit( +                    filepath=filepath, +                    range=Range.from_indices(original, i1, i2), +                    replacement=replacement, +                ) +            )              offset += (j2 - j1) - (i2 - i1)          else:              raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag) @@ -59,17 +69,27 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit              if tag == "equal":                  continue  # ;)              elif tag == "delete": -                edits.append(FileEdit.from_deletion( -                    filepath, Range.from_indices(original, i1, i2))) +                edits.append( +                    FileEdit.from_deletion( +                        filepath, Range.from_indices(original, i1, i2) +                    ) +                )              elif tag == "insert": -                edits.append(FileEdit.from_insertion( -                    filepath, Position.from_index(original, i1), replacement)) +                edits.append( +                    FileEdit.from_insertion( +                        filepath, Position.from_index(original, i1), replacement +                    ) +                )              elif tag == "replace": -                edits.append(FileEdit(filepath=filepath, range=Range.from_indices( -                    original, i1, i2), replacement=replacement)) +                edits.append( +                    FileEdit( +                        filepath=filepath, +                        range=Range.from_indices(original, i1, i2), +                        replacement=replacement, +                    ) +                )              else: -                raise Exception( -                    "Unexpected difflib.SequenceMatcher tag: " + tag) +                raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag)              break          original = apply_edit_to_str(original, edits[-1]) @@ -82,17 +102,17 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit  def read_range_in_str(s: str, r: Range) -> str: -    lines = s.splitlines()[r.start.line:r.end.line + 1] +    lines = s.splitlines()[r.start.line : r.end.line + 1]      if len(lines) == 0:          return "" -    lines[0] = lines[0][r.start.character:] -    lines[-1] = lines[-1][:r.end.character + 1] +    lines[0] = lines[0][r.start.character :] +    lines[-1] = lines[-1][: r.end.character + 1]      return "\n".join(lines)  def apply_edit_to_str(s: str, edit: FileEdit) -> str: -    original = read_range_in_str(s, edit.range) +    read_range_in_str(s, edit.range)      # Split lines and deal with some edge cases (could obviously be nicer)      lines = s.splitlines() @@ -104,27 +124,30 @@ def apply_edit_to_str(s: str, edit: FileEdit) -> str:      if len(lines) == 0:          lines = [""] -    end = Position(line=edit.range.end.line, -                   character=edit.range.end.character) +    end = Position(line=edit.range.end.line, character=edit.range.end.character)      if edit.range.end.line == len(lines) and edit.range.end.character == 0: -        end = Position(line=edit.range.end.line - 1, -                       character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)])) +        end = Position( +            line=edit.range.end.line - 1, +            character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)]), +        ) -    before_lines = lines[:edit.range.start.line] -    after_lines = lines[end.line + 1:] -    between_str = lines[min(len(lines) - 1, edit.range.start.line)][:edit.range.start.character] + \ -        edit.replacement + \ -        lines[min(len(lines) - 1, end.line)][end.character + 1:] +    before_lines = lines[: edit.range.start.line] +    after_lines = lines[end.line + 1 :] +    between_str = ( +        lines[min(len(lines) - 1, edit.range.start.line)][: edit.range.start.character] +        + edit.replacement +        + lines[min(len(lines) - 1, end.line)][end.character + 1 :] +    ) -    new_range = Range( +    Range(          start=edit.range.start,          end=Position( -            line=edit.range.start.line + -            len(edit.replacement.splitlines()) - 1, -            character=edit.range.start.character + -            len(edit.replacement.splitlines() -                [-1]) if edit.replacement != "" else 0 -        ) +            line=edit.range.start.line + len(edit.replacement.splitlines()) - 1, +            character=edit.range.start.character +            + len(edit.replacement.splitlines()[-1]) +            if edit.replacement != "" +            else 0, +        ),      )      lines = before_lines + between_str.splitlines() + after_lines diff --git a/continuedev/src/continuedev/libs/util/commonregex.py b/continuedev/src/continuedev/libs/util/commonregex.py index 55da7fc0..3c4fb38c 100644 --- a/continuedev/src/continuedev/libs/util/commonregex.py +++ b/continuedev/src/continuedev/libs/util/commonregex.py @@ -1,37 +1,54 @@  # coding: utf-8 -import json  import re -from typing import Any, Dict +from typing import Any  date = re.compile( -    '(?:(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?\s+(?:of\s+)?(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)|(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)\s+(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?)(?:\,)?\s*(?:\d{4})?|[0-3]?\d[-\./][0-3]?\d[-\./]\d{2,4}', re.IGNORECASE) -time = re.compile( -    '\d{1,2}:\d{2} ?(?:[ap]\.?m\.?)?|\d[ap]\.?m\.?', re.IGNORECASE) +    "(?:(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?\s+(?:of\s+)?(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)|(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)\s+(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?)(?:\,)?\s*(?:\d{4})?|[0-3]?\d[-\./][0-3]?\d[-\./]\d{2,4}", +    re.IGNORECASE, +) +time = re.compile("\d{1,2}:\d{2} ?(?:[ap]\.?m\.?)?|\d[ap]\.?m\.?", re.IGNORECASE)  phone = re.compile( -    '''((?:(?<![\d-])(?:\+?\d{1,3}[-.\s*]?)?(?:\(?\d{3}\)?[-.\s*]?)?\d{3}[-.\s*]?\d{4}(?![\d-]))|(?:(?<![\d-])(?:(?:\(\+?\d{2}\))|(?:\+?\d{2}))\s*\d{2}\s*\d{3}\s*\d{4}(?![\d-])))''') +    """((?:(?<![\d-])(?:\+?\d{1,3}[-.\s*]?)?(?:\(?\d{3}\)?[-.\s*]?)?\d{3}[-.\s*]?\d{4}(?![\d-]))|(?:(?<![\d-])(?:(?:\(\+?\d{2}\))|(?:\+?\d{2}))\s*\d{2}\s*\d{3}\s*\d{4}(?![\d-])))""" +)  phones_with_exts = re.compile( -    '((?:(?:\+?1\s*(?:[.-]\s*)?)?(?:\(\s*(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9])\s*\)|(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9]))\s*(?:[.-]\s*)?)?(?:[2-9]1[02-9]|[2-9][02-9]1|[2-9][02-9]{2})\s*(?:[.-]\s*)?(?:[0-9]{4})(?:\s*(?:#|x\.?|ext\.?|extension)\s*(?:\d+)?))', re.IGNORECASE) -link = re.compile('(?i)((?:https?://|www\d{0,3}[.])?[a-z0-9.\-]+[.](?:(?:international)|(?:construction)|(?:contractors)|(?:enterprises)|(?:photography)|(?:immobilien)|(?:management)|(?:technology)|(?:directory)|(?:education)|(?:equipment)|(?:institute)|(?:marketing)|(?:solutions)|(?:builders)|(?:clothing)|(?:computer)|(?:democrat)|(?:diamonds)|(?:graphics)|(?:holdings)|(?:lighting)|(?:plumbing)|(?:training)|(?:ventures)|(?:academy)|(?:careers)|(?:company)|(?:domains)|(?:florist)|(?:gallery)|(?:guitars)|(?:holiday)|(?:kitchen)|(?:recipes)|(?:shiksha)|(?:singles)|(?:support)|(?:systems)|(?:agency)|(?:berlin)|(?:camera)|(?:center)|(?:coffee)|(?:estate)|(?:kaufen)|(?:luxury)|(?:monash)|(?:museum)|(?:photos)|(?:repair)|(?:social)|(?:tattoo)|(?:travel)|(?:viajes)|(?:voyage)|(?:build)|(?:cheap)|(?:codes)|(?:dance)|(?:email)|(?:glass)|(?:house)|(?:ninja)|(?:photo)|(?:shoes)|(?:solar)|(?:today)|(?:aero)|(?:arpa)|(?:asia)|(?:bike)|(?:buzz)|(?:camp)|(?:club)|(?:coop)|(?:farm)|(?:gift)|(?:guru)|(?:info)|(?:jobs)|(?:kiwi)|(?:land)|(?:limo)|(?:link)|(?:menu)|(?:mobi)|(?:moda)|(?:name)|(?:pics)|(?:pink)|(?:post)|(?:rich)|(?:ruhr)|(?:sexy)|(?:tips)|(?:wang)|(?:wien)|(?:zone)|(?:biz)|(?:cab)|(?:cat)|(?:ceo)|(?:com)|(?:edu)|(?:gov)|(?:int)|(?:mil)|(?:net)|(?:onl)|(?:org)|(?:pro)|(?:red)|(?:tel)|(?:uno)|(?:xxx)|(?:ac)|(?:ad)|(?:ae)|(?:af)|(?:ag)|(?:ai)|(?:al)|(?:am)|(?:an)|(?:ao)|(?:aq)|(?:ar)|(?:as)|(?:at)|(?:au)|(?:aw)|(?:ax)|(?:az)|(?:ba)|(?:bb)|(?:bd)|(?:be)|(?:bf)|(?:bg)|(?:bh)|(?:bi)|(?:bj)|(?:bm)|(?:bn)|(?:bo)|(?:br)|(?:bs)|(?:bt)|(?:bv)|(?:bw)|(?:by)|(?:bz)|(?:ca)|(?:cc)|(?:cd)|(?:cf)|(?:cg)|(?:ch)|(?:ci)|(?:ck)|(?:cl)|(?:cm)|(?:cn)|(?:co)|(?:cr)|(?:cu)|(?:cv)|(?:cw)|(?:cx)|(?:cy)|(?:cz)|(?:de)|(?:dj)|(?:dk)|(?:dm)|(?:do)|(?:dz)|(?:ec)|(?:ee)|(?:eg)|(?:er)|(?:es)|(?:et)|(?:eu)|(?:fi)|(?:fj)|(?:fk)|(?:fm)|(?:fo)|(?:fr)|(?:ga)|(?:gb)|(?:gd)|(?:ge)|(?:gf)|(?:gg)|(?:gh)|(?:gi)|(?:gl)|(?:gm)|(?:gn)|(?:gp)|(?:gq)|(?:gr)|(?:gs)|(?:gt)|(?:gu)|(?:gw)|(?:gy)|(?:hk)|(?:hm)|(?:hn)|(?:hr)|(?:ht)|(?:hu)|(?:id)|(?:ie)|(?:il)|(?:im)|(?:in)|(?:io)|(?:iq)|(?:ir)|(?:is)|(?:it)|(?:je)|(?:jm)|(?:jo)|(?:jp)|(?:ke)|(?:kg)|(?:kh)|(?:ki)|(?:km)|(?:kn)|(?:kp)|(?:kr)|(?:kw)|(?:ky)|(?:kz)|(?:la)|(?:lb)|(?:lc)|(?:li)|(?:lk)|(?:lr)|(?:ls)|(?:lt)|(?:lu)|(?:lv)|(?:ly)|(?:ma)|(?:mc)|(?:md)|(?:me)|(?:mg)|(?:mh)|(?:mk)|(?:ml)|(?:mm)|(?:mn)|(?:mo)|(?:mp)|(?:mq)|(?:mr)|(?:ms)|(?:mt)|(?:mu)|(?:mv)|(?:mw)|(?:mx)|(?:my)|(?:mz)|(?:na)|(?:nc)|(?:ne)|(?:nf)|(?:ng)|(?:ni)|(?:nl)|(?:no)|(?:np)|(?:nr)|(?:nu)|(?:nz)|(?:om)|(?:pa)|(?:pe)|(?:pf)|(?:pg)|(?:ph)|(?:pk)|(?:pl)|(?:pm)|(?:pn)|(?:pr)|(?:ps)|(?:pt)|(?:pw)|(?:py)|(?:qa)|(?:re)|(?:ro)|(?:rs)|(?:ru)|(?:rw)|(?:sa)|(?:sb)|(?:sc)|(?:sd)|(?:se)|(?:sg)|(?:sh)|(?:si)|(?:sj)|(?:sk)|(?:sl)|(?:sm)|(?:sn)|(?:so)|(?:sr)|(?:st)|(?:su)|(?:sv)|(?:sx)|(?:sy)|(?:sz)|(?:tc)|(?:td)|(?:tf)|(?:tg)|(?:th)|(?:tj)|(?:tk)|(?:tl)|(?:tm)|(?:tn)|(?:to)|(?:tp)|(?:tr)|(?:tt)|(?:tv)|(?:tw)|(?:tz)|(?:ua)|(?:ug)|(?:uk)|(?:us)|(?:uy)|(?:uz)|(?:va)|(?:vc)|(?:ve)|(?:vg)|(?:vi)|(?:vn)|(?:vu)|(?:wf)|(?:ws)|(?:ye)|(?:yt)|(?:za)|(?:zm)|(?:zw))(?:/[^\s()<>]+[^\s`!()\[\]{};:\'".,<>?\xab\xbb\u201c\u201d\u2018\u2019])?)', re.IGNORECASE) +    "((?:(?:\+?1\s*(?:[.-]\s*)?)?(?:\(\s*(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9])\s*\)|(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9]))\s*(?:[.-]\s*)?)?(?:[2-9]1[02-9]|[2-9][02-9]1|[2-9][02-9]{2})\s*(?:[.-]\s*)?(?:[0-9]{4})(?:\s*(?:#|x\.?|ext\.?|extension)\s*(?:\d+)?))", +    re.IGNORECASE, +) +link = re.compile( +    "(?i)((?:https?://|www\d{0,3}[.])?[a-z0-9.\-]+[.](?:(?:international)|(?:construction)|(?:contractors)|(?:enterprises)|(?:photography)|(?:immobilien)|(?:management)|(?:technology)|(?:directory)|(?:education)|(?:equipment)|(?:institute)|(?:marketing)|(?:solutions)|(?:builders)|(?:clothing)|(?:computer)|(?:democrat)|(?:diamonds)|(?:graphics)|(?:holdings)|(?:lighting)|(?:plumbing)|(?:training)|(?:ventures)|(?:academy)|(?:careers)|(?:company)|(?:domains)|(?:florist)|(?:gallery)|(?:guitars)|(?:holiday)|(?:kitchen)|(?:recipes)|(?:shiksha)|(?:singles)|(?:support)|(?:systems)|(?:agency)|(?:berlin)|(?:camera)|(?:center)|(?:coffee)|(?:estate)|(?:kaufen)|(?:luxury)|(?:monash)|(?:museum)|(?:photos)|(?:repair)|(?:social)|(?:tattoo)|(?:travel)|(?:viajes)|(?:voyage)|(?:build)|(?:cheap)|(?:codes)|(?:dance)|(?:email)|(?:glass)|(?:house)|(?:ninja)|(?:photo)|(?:shoes)|(?:solar)|(?:today)|(?:aero)|(?:arpa)|(?:asia)|(?:bike)|(?:buzz)|(?:camp)|(?:club)|(?:coop)|(?:farm)|(?:gift)|(?:guru)|(?:info)|(?:jobs)|(?:kiwi)|(?:land)|(?:limo)|(?:link)|(?:menu)|(?:mobi)|(?:moda)|(?:name)|(?:pics)|(?:pink)|(?:post)|(?:rich)|(?:ruhr)|(?:sexy)|(?:tips)|(?:wang)|(?:wien)|(?:zone)|(?:biz)|(?:cab)|(?:cat)|(?:ceo)|(?:com)|(?:edu)|(?:gov)|(?:int)|(?:mil)|(?:net)|(?:onl)|(?:org)|(?:pro)|(?:red)|(?:tel)|(?:uno)|(?:xxx)|(?:ac)|(?:ad)|(?:ae)|(?:af)|(?:ag)|(?:ai)|(?:al)|(?:am)|(?:an)|(?:ao)|(?:aq)|(?:ar)|(?:as)|(?:at)|(?:au)|(?:aw)|(?:ax)|(?:az)|(?:ba)|(?:bb)|(?:bd)|(?:be)|(?:bf)|(?:bg)|(?:bh)|(?:bi)|(?:bj)|(?:bm)|(?:bn)|(?:bo)|(?:br)|(?:bs)|(?:bt)|(?:bv)|(?:bw)|(?:by)|(?:bz)|(?:ca)|(?:cc)|(?:cd)|(?:cf)|(?:cg)|(?:ch)|(?:ci)|(?:ck)|(?:cl)|(?:cm)|(?:cn)|(?:co)|(?:cr)|(?:cu)|(?:cv)|(?:cw)|(?:cx)|(?:cy)|(?:cz)|(?:de)|(?:dj)|(?:dk)|(?:dm)|(?:do)|(?:dz)|(?:ec)|(?:ee)|(?:eg)|(?:er)|(?:es)|(?:et)|(?:eu)|(?:fi)|(?:fj)|(?:fk)|(?:fm)|(?:fo)|(?:fr)|(?:ga)|(?:gb)|(?:gd)|(?:ge)|(?:gf)|(?:gg)|(?:gh)|(?:gi)|(?:gl)|(?:gm)|(?:gn)|(?:gp)|(?:gq)|(?:gr)|(?:gs)|(?:gt)|(?:gu)|(?:gw)|(?:gy)|(?:hk)|(?:hm)|(?:hn)|(?:hr)|(?:ht)|(?:hu)|(?:id)|(?:ie)|(?:il)|(?:im)|(?:in)|(?:io)|(?:iq)|(?:ir)|(?:is)|(?:it)|(?:je)|(?:jm)|(?:jo)|(?:jp)|(?:ke)|(?:kg)|(?:kh)|(?:ki)|(?:km)|(?:kn)|(?:kp)|(?:kr)|(?:kw)|(?:ky)|(?:kz)|(?:la)|(?:lb)|(?:lc)|(?:li)|(?:lk)|(?:lr)|(?:ls)|(?:lt)|(?:lu)|(?:lv)|(?:ly)|(?:ma)|(?:mc)|(?:md)|(?:me)|(?:mg)|(?:mh)|(?:mk)|(?:ml)|(?:mm)|(?:mn)|(?:mo)|(?:mp)|(?:mq)|(?:mr)|(?:ms)|(?:mt)|(?:mu)|(?:mv)|(?:mw)|(?:mx)|(?:my)|(?:mz)|(?:na)|(?:nc)|(?:ne)|(?:nf)|(?:ng)|(?:ni)|(?:nl)|(?:no)|(?:np)|(?:nr)|(?:nu)|(?:nz)|(?:om)|(?:pa)|(?:pe)|(?:pf)|(?:pg)|(?:ph)|(?:pk)|(?:pl)|(?:pm)|(?:pn)|(?:pr)|(?:ps)|(?:pt)|(?:pw)|(?:py)|(?:qa)|(?:re)|(?:ro)|(?:rs)|(?:ru)|(?:rw)|(?:sa)|(?:sb)|(?:sc)|(?:sd)|(?:se)|(?:sg)|(?:sh)|(?:si)|(?:sj)|(?:sk)|(?:sl)|(?:sm)|(?:sn)|(?:so)|(?:sr)|(?:st)|(?:su)|(?:sv)|(?:sx)|(?:sy)|(?:sz)|(?:tc)|(?:td)|(?:tf)|(?:tg)|(?:th)|(?:tj)|(?:tk)|(?:tl)|(?:tm)|(?:tn)|(?:to)|(?:tp)|(?:tr)|(?:tt)|(?:tv)|(?:tw)|(?:tz)|(?:ua)|(?:ug)|(?:uk)|(?:us)|(?:uy)|(?:uz)|(?:va)|(?:vc)|(?:ve)|(?:vg)|(?:vi)|(?:vn)|(?:vu)|(?:wf)|(?:ws)|(?:ye)|(?:yt)|(?:za)|(?:zm)|(?:zw))(?:/[^\s()<>]+[^\s`!()\[\]{};:'\".,<>?\xab\xbb\u201c\u201d\u2018\u2019])?)", +    re.IGNORECASE, +)  email = re.compile( -    "([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)", re.IGNORECASE) -ip = re.compile('(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)', re.IGNORECASE) +    "([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)", +    re.IGNORECASE, +) +ip = re.compile( +    "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)", +    re.IGNORECASE, +)  ipv6 = re.compile( -    '\s*(?!.*::.*::)(?:(?!:)|:(?=:))(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)){6}(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)[0-9a-f]{0,4}(?:(?<=::)|(?<!:)|(?<=:)(?<!::):)|(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)){3})\s*', re.VERBOSE | re.IGNORECASE | re.DOTALL) -price = re.compile('[$]\s?[+-]?[0-9]{1,3}(?:(?:,?[0-9]{3}))*(?:\.[0-9]{1,2})?') -hex_color = re.compile('(#(?:[0-9a-fA-F]{8})|#(?:[0-9a-fA-F]{3}){1,2})\\b') -credit_card = re.compile('((?:(?:\\d{4}[- ]?){3}\\d{4}|\\d{15,16}))(?![\\d])') +    "\s*(?!.*::.*::)(?:(?!:)|:(?=:))(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)){6}(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)[0-9a-f]{0,4}(?:(?<=::)|(?<!:)|(?<=:)(?<!::):)|(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)){3})\s*", +    re.VERBOSE | re.IGNORECASE | re.DOTALL, +) +price = re.compile("[$]\s?[+-]?[0-9]{1,3}(?:(?:,?[0-9]{3}))*(?:\.[0-9]{1,2})?") +hex_color = re.compile("(#(?:[0-9a-fA-F]{8})|#(?:[0-9a-fA-F]{3}){1,2})\\b") +credit_card = re.compile("((?:(?:\\d{4}[- ]?){3}\\d{4}|\\d{15,16}))(?![\\d])")  btc_address = re.compile( -    '(?<![a-km-zA-HJ-NP-Z0-9])[13][a-km-zA-HJ-NP-Z0-9]{26,33}(?![a-km-zA-HJ-NP-Z0-9])') +    "(?<![a-km-zA-HJ-NP-Z0-9])[13][a-km-zA-HJ-NP-Z0-9]{26,33}(?![a-km-zA-HJ-NP-Z0-9])" +)  street_address = re.compile( -    '\d{1,4} [\w\s]{1,20}(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\W?(?=\s|$)', re.IGNORECASE) -zip_code = re.compile(r'\b\d{5}(?:[-\s]\d{4})?\b') -po_box = re.compile(r'P\.? ?O\.? Box \d+', re.IGNORECASE) +    "\d{1,4} [\w\s]{1,20}(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\W?(?=\s|$)", +    re.IGNORECASE, +) +zip_code = re.compile(r"\b\d{5}(?:[-\s]\d{4})?\b") +po_box = re.compile(r"P\.? ?O\.? Box \d+", re.IGNORECASE)  ssn = re.compile( -    '(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}') +    "(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}" +)  win_absolute_filepath = re.compile( -    r'^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+', re.IGNORECASE) -unix_absolute_filepath = re.compile( -    r'^\/(?:[\/\w]+\/)*\w([\w.])+', re.IGNORECASE) +    r"^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+", re.IGNORECASE +) +unix_absolute_filepath = re.compile(r"^\/(?:[\/\w]+\/)*\w([\w.])+", re.IGNORECASE)  regexes = {      "win_absolute_filepath": win_absolute_filepath, @@ -77,7 +94,6 @@ placeholders = {  class regex: -      def __init__(self, obj, regex):          self.obj = obj          self.regex = regex @@ -85,11 +101,11 @@ class regex:      def __call__(self, *args):          def regex_method(text=None):              return [x.strip() for x in self.regex.findall(text or self.obj.text)] +          return regex_method  class CommonRegex(object): -      def __init__(self, text=""):          self.text = text diff --git a/continuedev/src/continuedev/libs/util/copy_codebase.py b/continuedev/src/continuedev/libs/util/copy_codebase.py index 97143faf..aafb435c 100644 --- a/continuedev/src/continuedev/libs/util/copy_codebase.py +++ b/continuedev/src/continuedev/libs/util/copy_codebase.py @@ -1,14 +1,24 @@  import os +import shutil  from pathlib import Path  from typing import Iterable, List, Union -from watchdog.observers import Observer +  from watchdog.events import PatternMatchingEventHandler -from ...models.main import FileEdit, DeleteDirectory, DeleteFile, AddDirectory, AddFile, FileSystemEdit, RenameFile, RenameDirectory, SequentialFileSystemEdit -from ...models.filesystem import FileSystem +from watchdog.observers import Observer +  from ...core.autopilot import Autopilot +from ...models.filesystem import FileSystem +from ...models.main import ( +    AddDirectory, +    AddFile, +    DeleteDirectory, +    DeleteFile, +    FileSystemEdit, +    RenameDirectory, +    RenameFile, +    SequentialFileSystemEdit, +)  from .map_path import map_path -from ...core.sdk import ManualEditStep -import shutil  def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = []): @@ -24,8 +34,7 @@ def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = [          if os.path.isdir(child):              if child not in ignore:                  os.mkdir(map_path(child)) -                create_copy(Path(orig_root) / child, -                            Path(copy_root) / child, ignore) +                create_copy(Path(orig_root) / child, Path(copy_root) / child, ignore)              else:                  os.symlink(child, map_path(child))          else: @@ -37,8 +46,18 @@ def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = [  # The whole usage of watchdog here should only be specific to RealFileSystem, you want to have a different "Observer" class for VirtualFileSystem, which would depend on being sent notifications  class CopyCodebaseEventHandler(PatternMatchingEventHandler): -    def __init__(self, ignore_directories: List[str], ignore_patterns: List[str], autopilot: Autopilot, orig_root: str, copy_root: str, filesystem: FileSystem): -        super().__init__(ignore_directories=ignore_directories, ignore_patterns=ignore_patterns) +    def __init__( +        self, +        ignore_directories: List[str], +        ignore_patterns: List[str], +        autopilot: Autopilot, +        orig_root: str, +        copy_root: str, +        filesystem: FileSystem, +    ): +        super().__init__( +            ignore_directories=ignore_directories, ignore_patterns=ignore_patterns +        )          self.autopilot = autopilot          self.orig_root = orig_root          self.copy_root = copy_root @@ -85,10 +104,13 @@ class CopyCodebaseEventHandler(PatternMatchingEventHandler):          self.autopilot.act(action) -def maintain_copy_workspace(autopilot: Autopilot, filesystem: FileSystem, orig_root: str, copy_root: str): +def maintain_copy_workspace( +    autopilot: Autopilot, filesystem: FileSystem, orig_root: str, copy_root: str +):      observer = Observer()      event_handler = CopyCodebaseEventHandler( -        [".git"], [], autopilot, orig_root, copy_root, filesystem) +        [".git"], [], autopilot, orig_root, copy_root, filesystem +    )      observer.schedule(event_handler, orig_root, recursive=True)      observer.start()      try: diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 3b594036..a339c288 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -1,10 +1,10 @@  import json  from typing import Dict, List, Union + +import tiktoken +  from ...core.main import ChatMessage  from .templating import render_templated_string -from ...libs.llm import LLM -from tiktoken_ext import openai_public -import tiktoken  # TODO move many of these into specific LLM.properties() function that  # contains max tokens, if its a chat model or not, default args (not all models @@ -15,8 +15,13 @@ aliases = {      "claude-2": "gpt-3.5-turbo",  }  DEFAULT_MAX_TOKENS = 2048 -DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1, -                "frequency_penalty": 0, "presence_penalty": 0} +DEFAULT_ARGS = { +    "max_tokens": DEFAULT_MAX_TOKENS, +    "temperature": 0.5, +    "top_p": 1, +    "frequency_penalty": 0, +    "presence_penalty": 0, +}  def encoding_for_model(model_name: str): @@ -41,7 +46,9 @@ def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int      return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE -def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int): +def prune_raw_prompt_from_top( +    model_name: str, context_length: int, prompt: str, tokens_for_completion: int +):      max_tokens = context_length - tokens_for_completion      encoding = encoding_for_model(model_name)      tokens = encoding.encode(prompt, disallowed_special=()) @@ -51,10 +58,15 @@ def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str,          return encoding.decode(tokens[-max_tokens:]) -def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int): -    total_tokens = tokens_for_completion + \ -        sum(count_chat_message_tokens(model_name, message) -            for message in chat_history) +def prune_chat_history( +    model_name: str, +    chat_history: List[ChatMessage], +    context_length: int, +    tokens_for_completion: int, +): +    total_tokens = tokens_for_completion + sum( +        count_chat_message_tokens(model_name, message) for message in chat_history +    )      # 1. Replace beyond last 5 messages with summary      i = 0 @@ -66,13 +78,21 @@ def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context          i += 1      # 2. Remove entire messages until the last 5 -    while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0: +    while ( +        len(chat_history) > 5 +        and total_tokens > context_length +        and len(chat_history) > 0 +    ):          message = chat_history.pop(0)          total_tokens -= count_tokens(model_name, message.content)      # 3. Truncate message in the last 5, except last 1      i = 0 -    while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1: +    while ( +        total_tokens > context_length +        and len(chat_history) > 0 +        and i < len(chat_history) - 1 +    ):          message = chat_history[i]          total_tokens -= count_tokens(model_name, message.content)          total_tokens += count_tokens(model_name, message.summary) @@ -88,7 +108,8 @@ def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context      if total_tokens > context_length and len(chat_history) > 0:          message = chat_history[0]          message.content = prune_raw_prompt_from_top( -            model_name, context_length, message.content, tokens_for_completion) +            model_name, context_length, message.content, tokens_for_completion +        )          total_tokens = context_length      return chat_history @@ -98,12 +119,19 @@ def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context  TOKEN_BUFFER_FOR_SAFETY = 100 -def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages( +    model_name: str, +    msgs: Union[List[ChatMessage], None], +    context_length: int, +    max_tokens: int, +    prompt: Union[str, None] = None, +    functions: Union[List, None] = None, +    system_message: Union[str, None] = None, +) -> List[Dict]:      """      The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message      """ -    msgs_copy = [msg.copy(deep=True) -                 for msg in msgs] if msgs is not None else [] +    msgs_copy = [msg.copy(deep=True) for msg in msgs] if msgs is not None else []      if prompt is not None:          prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt) @@ -114,7 +142,10 @@ def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None],          # but move back to start after processing          rendered_system_message = render_templated_string(system_message)          system_chat_msg = ChatMessage( -            role="system", content=rendered_system_message, summary=rendered_system_message) +            role="system", +            content=rendered_system_message, +            summary=rendered_system_message, +        )          # insert at second-to-last position          msgs_copy.insert(-1, system_chat_msg) @@ -125,13 +156,20 @@ def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None],              function_tokens += count_tokens(model_name, json.dumps(function))      msgs_copy = prune_chat_history( -        model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) +        model_name, +        msgs_copy, +        context_length, +        function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY, +    ) -    history = [msg.to_dict(with_functions=functions is not None) -               for msg in msgs_copy] +    history = [msg.to_dict(with_functions=functions is not None) for msg in msgs_copy]      # Move system message back to start -    if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system": +    if ( +        system_message is not None +        and len(history) >= 2 +        and history[-2]["role"] == "system" +    ):          system_message_dict = history.pop(-2)          history.insert(0, system_message_dict) diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 4c6d3c95..232d3fa1 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -1,13 +1,18 @@ -from typing import Callable, Coroutine, Optional, Union -import traceback -from .telemetry import posthog_logger -from .logging import logger  import asyncio +import traceback +from typing import Callable, Coroutine, Optional +  import nest_asyncio + +from .logging import logger +from .telemetry import posthog_logger +  nest_asyncio.apply() -def create_async_task(coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None): +def create_async_task( +    coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None +):      """asyncio.create_task and log errors by adding a callback"""      task = asyncio.create_task(coro) @@ -15,12 +20,15 @@ def create_async_task(coro: Coroutine, on_error: Optional[Callable[[Exception],          try:              future.result()          except Exception as e: -            formatted_tb = '\n'.join(traceback.format_exception(e)) -            logger.critical( -                f"Exception caught from async task: {formatted_tb}") -            posthog_logger.capture_event("async_task_error", { -                "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) -            }) +            formatted_tb = "\n".join(traceback.format_exception(e)) +            logger.critical(f"Exception caught from async task: {formatted_tb}") +            posthog_logger.capture_event( +                "async_task_error", +                { +                    "error_title": e.__str__() or e.__repr__(), +                    "error_message": "\n".join(traceback.format_exception(e)), +                }, +            )              # Log the error to the GUI              if on_error is not None: diff --git a/continuedev/src/continuedev/libs/util/logging.py b/continuedev/src/continuedev/libs/util/logging.py index 668d313f..4a550168 100644 --- a/continuedev/src/continuedev/libs/util/logging.py +++ b/continuedev/src/continuedev/libs/util/logging.py @@ -15,8 +15,7 @@ console_handler = logging.StreamHandler()  console_handler.setLevel(logging.DEBUG)  # Create a formatter -formatter = logging.Formatter( -    '[%(asctime)s] [%(levelname)s] %(message)s') +formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s")  # Add the formatter to the handlers  file_handler.setFormatter(formatter) @@ -27,4 +26,4 @@ logger.addHandler(file_handler)  logger.addHandler(console_handler)  # Log a test message -logger.debug('Testing logs') +logger.debug("Testing logs") diff --git a/continuedev/src/continuedev/libs/util/map_path.py b/continuedev/src/continuedev/libs/util/map_path.py index 8eb57712..1dddc2e9 100644 --- a/continuedev/src/continuedev/libs/util/map_path.py +++ b/continuedev/src/continuedev/libs/util/map_path.py @@ -1,16 +1,16 @@  from pathlib import Path -def map_path(path: str, orig_root: str, copy_root: str) -> Path: -	path = Path(path) -	if path.is_relative_to(orig_root): -		if path.is_absolute(): -			return Path(copy_root) / path.relative_to(orig_root) -		else: -			return path -	else: -		if path.is_absolute(): -			return path -		else: -			# For this one, you need to know the directory from which the relative path is being used. -			return Path(orig_root) / path +def map_path(path: str, orig_root: str, copy_root: str) -> Path: +    path = Path(path) +    if path.is_relative_to(orig_root): +        if path.is_absolute(): +            return Path(copy_root) / path.relative_to(orig_root) +        else: +            return path +    else: +        if path.is_absolute(): +            return path +        else: +            # For this one, you need to know the directory from which the relative path is being used. +            return Path(orig_root) / path diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index 01b594cf..483f6b63 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -1,6 +1,11 @@  import os -from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER +  from ..constants.default_config import default_config +from ..constants.main import ( +    CONTINUE_GLOBAL_FOLDER, +    CONTINUE_SERVER_FOLDER, +    CONTINUE_SESSIONS_FOLDER, +)  def find_data_file(filename): @@ -36,7 +41,7 @@ def getSessionsListFilePath():      path = os.path.join(getSessionsFolderPath(), "sessions.json")      os.makedirs(os.path.dirname(path), exist_ok=True)      if not os.path.exists(path): -        with open(path, 'w') as f: +        with open(path, "w") as f:              f.write("[]")      return path @@ -46,19 +51,22 @@ def getConfigFilePath() -> str:      os.makedirs(os.path.dirname(path), exist_ok=True)      if not os.path.exists(path): -        with open(path, 'w') as f: +        with open(path, "w") as f:              f.write(default_config)      else: -        with open(path, 'r') as f: +        with open(path, "r") as f:              existing_content = f.read()          if existing_content.strip() == "": -            with open(path, 'w') as f: +            with open(path, "w") as f:                  f.write(default_config)          elif " continuedev.core" in existing_content: -            with open(path, 'w') as f: -                f.write(existing_content.replace(" continuedev.", -                                                 " continuedev.src.continuedev.")) +            with open(path, "w") as f: +                f.write( +                    existing_content.replace( +                        " continuedev.", " continuedev.src.continuedev." +                    ) +                )      return path diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py index ed1e79b7..0cca261f 100644 --- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -1,20 +1,22 @@  from typing import Dict  from ...core.main import Step -from ...plugins.steps.core.core import UserInputStep -from ...plugins.steps.main import EditHighlightedCodeStep -from ...plugins.steps.chat import SimpleChatStep -from ...plugins.steps.comment_code import CommentCodeStep -from ...plugins.steps.feedback import FeedbackStep +from ...libs.util.logging import logger  from ...plugins.recipes.AddTransformRecipe.main import AddTransformRecipe  from ...plugins.recipes.CreatePipelineRecipe.main import CreatePipelineRecipe  from ...plugins.recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ...plugins.recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe -from ...plugins.steps.on_traceback import DefaultOnTracebackStep +from ...plugins.recipes.DeployPipelineAirflowRecipe.main import ( +    DeployPipelineAirflowRecipe, +) +from ...plugins.steps.chat import SimpleChatStep  from ...plugins.steps.clear_history import ClearHistoryStep -from ...plugins.steps.open_config import OpenConfigStep +from ...plugins.steps.comment_code import CommentCodeStep +from ...plugins.steps.core.core import UserInputStep +from ...plugins.steps.feedback import FeedbackStep  from ...plugins.steps.help import HelpStep -from ...libs.util.logging import logger +from ...plugins.steps.main import EditHighlightedCodeStep +from ...plugins.steps.on_traceback import DefaultOnTracebackStep +from ...plugins.steps.open_config import OpenConfigStep  # This mapping is used to convert from string in ContinueConfig json to corresponding Step class.  # Used for example in slash_commands and steps_on_startup @@ -40,5 +42,6 @@ def get_step_from_name(step_name: str, params: Dict) -> Step:          return step_name_to_step_class[step_name](**params)      except:          logger.error( -            f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") +            f"Incorrect parameters for step {step_name}. Parameters provided were: {params}" +        )          raise diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index 285c1e47..d33c46c4 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -43,7 +43,9 @@ def remove_quotes_and_escapes(output: str) -> str:      output = output.replace("\\n", "\n")      output = output.replace("\\t", "\t")      output = output.replace("\\\\", "\\") -    if (output.startswith('"') and output.endswith('"')) or (output.startswith("'") and output.endswith("'")): +    if (output.startswith('"') and output.endswith('"')) or ( +        output.startswith("'") and output.endswith("'") +    ):          output = output[1:-1]      return output diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py index 2c76f4a3..ab5ec328 100644 --- a/continuedev/src/continuedev/libs/util/telemetry.py +++ b/continuedev/src/continuedev/libs/util/telemetry.py @@ -1,15 +1,16 @@ -from typing import Any -from posthog import Posthog  import os +from typing import Any +  from dotenv import load_dotenv +from posthog import Posthog + +from ..constants.main import CONTINUE_SERVER_VERSION_FILE  from .commonregex import clean_pii_from_any -from .logging import logger  from .paths import getServerFolderPath -from ..constants.main import CONTINUE_SERVER_VERSION_FILE  load_dotenv()  in_codespaces = os.getenv("CODESPACES") == "true" -POSTHOG_API_KEY = 'phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs' +POSTHOG_API_KEY = "phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs"  class PostHogLogger: @@ -20,16 +21,14 @@ class PostHogLogger:          self.api_key = api_key          # The personal API key is necessary only if you want to use local evaluation of feature flags. -        self.posthog = Posthog(self.api_key, host='https://app.posthog.com') +        self.posthog = Posthog(self.api_key, host="https://app.posthog.com")      def setup(self, unique_id: str, allow_anonymous_telemetry: bool):          self.unique_id = unique_id or "NO_UNIQUE_ID"          self.allow_anonymous_telemetry = allow_anonymous_telemetry or True          # Capture initial event -        self.capture_event("session_start", { -            "os": os.name -        }) +        self.capture_event("session_start", {"os": os.name})      def capture_event(self, event_name: str, event_properties: Any):          # logger.debug( @@ -53,13 +52,14 @@ class PostHogLogger:          # Add additional properties that are on every event          if in_codespaces: -            event_properties['codespaces'] = True +            event_properties["codespaces"] = True          server_version_file = os.path.join( -            getServerFolderPath(), CONTINUE_SERVER_VERSION_FILE) +            getServerFolderPath(), CONTINUE_SERVER_VERSION_FILE +        )          if os.path.exists(server_version_file):              with open(server_version_file, "r") as f: -                event_properties['server_version'] = f.read() +                event_properties["server_version"] = f.read()          # Send event to PostHog          self.posthog.capture(self.unique_id, event_name, event_properties) diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py index bb922ad7..edcf2884 100644 --- a/continuedev/src/continuedev/libs/util/templating.py +++ b/continuedev/src/continuedev/libs/util/templating.py @@ -1,4 +1,5 @@  import os +  import chevron @@ -6,14 +7,18 @@ def get_vars_in_template(template):      """      Get the variables in a template      """ -    return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable'] +    return [ +        token[1] +        for token in chevron.tokenizer.tokenize(template) +        if token[0] == "variable" +    ]  def escape_var(var: str) -> str:      """      Escape a variable so it can be used in a template      """ -    return var.replace(os.path.sep, '').replace('.', '') +    return var.replace(os.path.sep, "").replace(".", "")  def render_templated_string(template: str) -> str: @@ -28,12 +33,11 @@ def render_templated_string(template: str) -> str:          if var.startswith(os.path.sep):              # Escape vars which are filenames, because mustache doesn't allow / in variable names              escaped_var = escape_var(var) -            template = template.replace( -                var, escaped_var) +            template = template.replace(var, escaped_var)              if os.path.exists(var): -                args[escaped_var] = open(var, 'r').read() +                args[escaped_var] = open(var, "r").read()              else: -                args[escaped_var] = '' +                args[escaped_var] = ""      return chevron.render(template, args) diff --git a/continuedev/src/continuedev/libs/util/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback_parsers.py index a2e94c26..2b164c0f 100644 --- a/continuedev/src/continuedev/libs/util/traceback_parsers.py +++ b/continuedev/src/continuedev/libs/util/traceback_parsers.py @@ -15,11 +15,16 @@ def get_javascript_traceback(output: str) -> str:      first_line = None      for i in range(len(lines) - 1):          segs = lines[i].split(":") -        if len(segs) > 1 and segs[0] != "" and segs[1].startswith(" ") and lines[i + 1].strip().startswith("at"): +        if ( +            len(segs) > 1 +            and segs[0] != "" +            and segs[1].startswith(" ") +            and lines[i + 1].strip().startswith("at") +        ):              first_line = lines[i]              break      if first_line is not None: -        return "\n".join(lines[lines.index(first_line):]) +        return "\n".join(lines[lines.index(first_line) :])      else:          return None diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index ca12579c..de426282 100644 --- a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -1,9 +1,22 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple  import os -from ..models.main import Position, Range, AbstractModel +from abc import abstractmethod +from typing import Dict, List, Tuple +  from pydantic import BaseModel -from .filesystem_edit import FileSystemEdit, FileEdit, AddFile, DeleteFile, RenameDirectory, RenameFile, AddDirectory, DeleteDirectory, EditDiff, SequentialFileSystemEdit + +from ..models.main import AbstractModel, Position, Range +from .filesystem_edit import ( +    AddDirectory, +    AddFile, +    DeleteDirectory, +    DeleteFile, +    EditDiff, +    FileEdit, +    FileSystemEdit, +    RenameDirectory, +    RenameFile, +    SequentialFileSystemEdit, +)  class RangeInFile(BaseModel): @@ -16,14 +29,12 @@ class RangeInFile(BaseModel):      @staticmethod      def from_entire_file(filepath: str, content: str) -> "RangeInFile":          range = Range.from_entire_file(content) -        return RangeInFile( -            filepath=filepath, -            range=range -        ) +        return RangeInFile(filepath=filepath, range=range)  class RangeInFileWithContents(RangeInFile):      """A range in a file with the contents of the range.""" +      contents: str      def __hash__(self): @@ -42,13 +53,15 @@ class RangeInFileWithContents(RangeInFile):          # Calculate union of contents          num_overlapping_lines = first.range.end.line - second.range.start.line + 1 -        union_lines = first.contents.splitlines()[:-num_overlapping_lines] + \ -            second.contents.splitlines() +        union_lines = ( +            first.contents.splitlines()[:-num_overlapping_lines] +            + second.contents.splitlines() +        )          return RangeInFileWithContents(              filepath=first.filepath,              range=first.range.union(second.range), -            contents="\n".join(union_lines) +            contents="\n".join(union_lines),          )      @staticmethod @@ -56,28 +69,24 @@ class RangeInFileWithContents(RangeInFile):          lines = content.splitlines()          if not lines:              return RangeInFileWithContents( -                filepath=filepath, -                range=Range.from_shorthand(0, 0, 0, 0), -                contents="" +                filepath=filepath, range=Range.from_shorthand(0, 0, 0, 0), contents=""              )          return RangeInFileWithContents(              filepath=filepath, -            range=Range.from_shorthand( -                0, 0, len(lines) - 1, len(lines[-1]) - 1), -            contents=content +            range=Range.from_shorthand(0, 0, len(lines) - 1, len(lines[-1]) - 1), +            contents=content,          )      @staticmethod      def from_range_in_file(rif: RangeInFile, content: str) -> "RangeInFileWithContents":          return RangeInFileWithContents( -            filepath=rif.filepath, -            range=rif.range, -            contents=content +            filepath=rif.filepath, range=rif.range, contents=content          )  class FileSystem(AbstractModel):      """An abstract filesystem that can read/write from a set of files.""" +      @abstractmethod      def read(self, path) -> str:          raise NotImplementedError @@ -129,12 +138,12 @@ class FileSystem(AbstractModel):      @classmethod      def read_range_in_str(self, s: str, r: Range) -> str: -        lines = s.split("\n")[r.start.line:r.end.line + 1] +        lines = s.split("\n")[r.start.line : r.end.line + 1]          if len(lines) == 0:              return "" -        lines[0] = lines[0][r.start.character:] -        lines[-1] = lines[-1][:r.end.character + 1] +        lines[0] = lines[0][r.start.character :] +        lines[-1] = lines[-1][: r.end.character + 1]          return "\n".join(lines)      @classmethod @@ -151,37 +160,40 @@ class FileSystem(AbstractModel):          if len(lines) == 0:              lines = [""] -        end = Position(line=edit.range.end.line, -                       character=edit.range.end.character) +        end = Position(line=edit.range.end.line, character=edit.range.end.character)          if edit.range.end.line == len(lines) and edit.range.end.character == 0: -            end = Position(line=edit.range.end.line - 1, -                           character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)])) +            end = Position( +                line=edit.range.end.line - 1, +                character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)]), +            ) -        before_lines = lines[:edit.range.start.line] -        after_lines = lines[end.line + 1:] -        between_str = lines[min(len(lines) - 1, edit.range.start.line)][:edit.range.start.character] + \ -            edit.replacement + \ -            lines[min(len(lines) - 1, end.line)][end.character + 1:] +        before_lines = lines[: edit.range.start.line] +        after_lines = lines[end.line + 1 :] +        between_str = ( +            lines[min(len(lines) - 1, edit.range.start.line)][ +                : edit.range.start.character +            ] +            + edit.replacement +            + lines[min(len(lines) - 1, end.line)][end.character + 1 :] +        )          new_range = Range(              start=edit.range.start,              end=Position( -                line=edit.range.start.line + -                len(edit.replacement.splitlines()) - 1, -                character=edit.range.start.character + -                len(edit.replacement.splitlines() -                    [-1]) if edit.replacement != "" else 0 -            ) +                line=edit.range.start.line + len(edit.replacement.splitlines()) - 1, +                character=edit.range.start.character +                + len(edit.replacement.splitlines()[-1]) +                if edit.replacement != "" +                else 0, +            ),          )          lines = before_lines + between_str.splitlines() + after_lines          return "\n".join(lines), EditDiff(              forward=edit,              backward=FileEdit( -                filepath=edit.filepath, -                range=new_range, -                replacement=original -            ) +                filepath=edit.filepath, range=new_range, replacement=original +            ),          )      def reverse_edit_on_str(self, s: str, diff: EditDiff) -> str: @@ -194,15 +206,17 @@ class FileSystem(AbstractModel):              start=diff.edit.range.start,              end=Position(                  line=diff.edit.range.start + replacement_d_lines, -                character=diff.edit.range.start.character + replacement_d_chars -            ) +                character=diff.edit.range.start.character + replacement_d_chars, +            ),          ) -        before_lines = lines[:replacement_range.start.line] -        after_lines = lines[replacement_range.end.line + 1:] -        between_str = lines[replacement_range.start.line][:replacement_range.start.character] + \ -            diff.original + \ -            lines[replacement_range.end.line][replacement_range.end.character + 1:] +        before_lines = lines[: replacement_range.start.line] +        after_lines = lines[replacement_range.end.line + 1 :] +        between_str = ( +            lines[replacement_range.start.line][: replacement_range.start.character] +            + diff.original +            + lines[replacement_range.end.line][replacement_range.end.character + 1 :] +        )          lines = before_lines + between_str.splitlines() + after_lines          return "\n".join(lines) @@ -221,8 +235,9 @@ class FileSystem(AbstractModel):              self.delete_file(edit.filepath)          elif isinstance(edit, RenameFile):              self.rename_file(edit.filepath, edit.new_filepath) -            backward = RenameFile(filepath=edit.new_filepath, -                                  new_filepath=edit.filepath) +            backward = RenameFile( +                filepath=edit.new_filepath, new_filepath=edit.filepath +            )          elif isinstance(edit, AddDirectory):              self.add_directory(edit.path)              backward = DeleteDirectory(edit.path) @@ -235,8 +250,7 @@ class FileSystem(AbstractModel):                      backward_edits.append(self.apply_edit(DeleteFile(path)))                  for d in dirs:                      path = os.path.join(root, d) -                    backward_edits.append( -                        self.apply_edit(DeleteDirectory(path))) +                    backward_edits.append(self.apply_edit(DeleteDirectory(path)))              backward_edits.append(self.apply_edit(DeleteDirectory(edit.path)))              backward_edits.reverse() @@ -253,10 +267,7 @@ class FileSystem(AbstractModel):          else:              raise TypeError("Unknown FileSystemEdit type: " + str(type(edit))) -        return EditDiff( -            forward=edit, -            backward=backward -        ) +        return EditDiff(forward=edit, backward=backward)  class RealFileSystem(FileSystem): @@ -304,6 +315,7 @@ class RealFileSystem(FileSystem):  class VirtualFileSystem(FileSystem):      """A simulated filesystem from a mapping of filepath to file contents.""" +      files: Dict[str, str]      def __init__(self, files: Dict[str, str]): @@ -331,7 +343,7 @@ class VirtualFileSystem(FileSystem):      def rename_directory(self, path: str, new_path: str):          for filepath in self.files:              if filepath.startswith(path): -                new_filepath = new_path + filepath[len(path):] +                new_filepath = new_path + filepath[len(path) :]                  self.files[new_filepath] = self.files[filepath]                  del self.files[filepath] @@ -349,9 +361,7 @@ class VirtualFileSystem(FileSystem):          old_content = self.read(edit.filepath)          new_content, original = FileSystem.apply_edit_to_str(old_content, edit)          self.write(edit.filepath, new_content) -        return EditDiff( -            edit=edit, -            original=original -        ) +        return EditDiff(edit=edit, original=original) +  # TODO: Uniform errors thrown by any FileSystem subclass. diff --git a/continuedev/src/continuedev/models/filesystem_edit.py b/continuedev/src/continuedev/models/filesystem_edit.py index b06ca2b3..9316ff46 100644 --- a/continuedev/src/continuedev/models/filesystem_edit.py +++ b/continuedev/src/continuedev/models/filesystem_edit.py @@ -1,9 +1,11 @@ +import os  from abc import abstractmethod  from typing import Generator, List +  from pydantic import BaseModel -from .main import Position, Range +  from ..libs.util.map_path import map_path -import os +from .main import Position, Range  class FileSystemEdit(BaseModel): @@ -27,7 +29,9 @@ class FileEdit(AtomicFileSystemEdit):      replacement: str      def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": -        return FileEdit(map_path(self.filepath, orig_root, copy_root), self.range, self.replacement) +        return FileEdit( +            map_path(self.filepath, orig_root, copy_root), self.range, self.replacement +        )      @staticmethod      def from_deletion(filepath: str, range: Range) -> "FileEdit": @@ -35,11 +39,23 @@ class FileEdit(AtomicFileSystemEdit):      @staticmethod      def from_insertion(filepath: str, position: Position, content: str) -> "FileEdit": -        return FileEdit(filepath=filepath, range=Range.from_shorthand(position.line, position.character, position.line, position.character), replacement=content) +        return FileEdit( +            filepath=filepath, +            range=Range.from_shorthand( +                position.line, position.character, position.line, position.character +            ), +            replacement=content, +        )      @staticmethod -    def from_append(filepath: str, previous_content: str, appended_content: str) -> "FileEdit": -        return FileEdit(filepath=filepath, range=Range.from_position(Position.from_end_of_file(previous_content)), replacement=appended_content) +    def from_append( +        filepath: str, previous_content: str, appended_content: str +    ) -> "FileEdit": +        return FileEdit( +            filepath=filepath, +            range=Range.from_position(Position.from_end_of_file(previous_content)), +            replacement=appended_content, +        )  class FileEditWithFullContents(BaseModel): @@ -52,7 +68,9 @@ class AddFile(AtomicFileSystemEdit):      content: str      def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": -        return AddFile(self, map_path(self.filepath, orig_root, copy_root), self.content) +        return AddFile( +            self, map_path(self.filepath, orig_root, copy_root), self.content +        )  class DeleteFile(AtomicFileSystemEdit): @@ -67,7 +85,10 @@ class RenameFile(AtomicFileSystemEdit):      new_filepath: str      def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": -        return RenameFile(map_path(self.filepath, orig_root, copy_root), map_path(self.new_filepath, orig_root, copy_root)) +        return RenameFile( +            map_path(self.filepath, orig_root, copy_root), +            map_path(self.new_filepath, orig_root, copy_root), +        )  class AddDirectory(AtomicFileSystemEdit): @@ -89,7 +110,10 @@ class RenameDirectory(AtomicFileSystemEdit):      new_path: str      def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": -        return RenameDirectory(map_path(self.filepath, orig_root, copy_root), map_path(self.new_path, orig_root, copy_root)) +        return RenameDirectory( +            map_path(self.filepath, orig_root, copy_root), +            map_path(self.new_path, orig_root, copy_root), +        )  class DeleteDirectoryRecursive(FileSystemEdit): @@ -112,10 +136,9 @@ class SequentialFileSystemEdit(FileSystemEdit):      edits: List[FileSystemEdit]      def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": -        return SequentialFileSystemEdit([ -            edit.with_mapped_paths(orig_root, copy_root) -            for edit in self.edits -        ]) +        return SequentialFileSystemEdit( +            [edit.with_mapped_paths(orig_root, copy_root) for edit in self.edits] +        )      def next_edit(self) -> Generator["FileSystemEdit", None, None]:          for edit in self.edits: @@ -124,6 +147,7 @@ class SequentialFileSystemEdit(FileSystemEdit):  class EditDiff(BaseModel):      """A reversible edit that can be applied to a file.""" +      forward: FileSystemEdit      backward: FileSystemEdit @@ -136,5 +160,5 @@ class EditDiff(BaseModel):              backwards.insert(0, diff.backward)          return cls(              forward=SequentialFileSystemEdit(edits=forwards), -            backward=SequentialFileSystemEdit(edits=backwards) +            backward=SequentialFileSystemEdit(edits=backwards),          ) diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 4262ac55..1c43f0a0 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -1,26 +1,22 @@ -from .main import * -from .filesystem import RangeInFile, FileEdit -from .filesystem_edit import FileEditWithFullContents -from ..core.main import History, HistoryNode, FullState, SessionInfo -from ..core.context import ContextItem -from pydantic import schema_json_of  import os -MODELS_TO_GENERATE = [ -    Position, Range, Traceback, TracebackFrame -] + [ -    RangeInFile, FileEdit -] + [ -    FileEditWithFullContents -] + [ -    History, HistoryNode, FullState, SessionInfo -] + [ -    ContextItem -] - -RENAMES = { -    "ExampleClass": "RenamedName" -} +from pydantic import schema_json_of + +from ..core.context import ContextItem +from ..core.main import FullState, History, HistoryNode, SessionInfo +from .filesystem import FileEdit, RangeInFile +from .filesystem_edit import FileEditWithFullContents +from .main import * + +MODELS_TO_GENERATE = ( +    [Position, Range, Traceback, TracebackFrame] +    + [RangeInFile, FileEdit] +    + [FileEditWithFullContents] +    + [History, HistoryNode, FullState, SessionInfo] +    + [ContextItem] +) + +RENAMES = {"ExampleClass": "RenamedName"}  SCHEMA_DIR = "../schema/json" diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index fa736772..d442a415 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -1,7 +1,8 @@  from abc import ABC -from typing import List, Union, Tuple -from pydantic import BaseModel, root_validator  from functools import total_ordering +from typing import List, Tuple, Union + +from pydantic import BaseModel, root_validator  class ContinueBaseModel(BaseModel): @@ -46,16 +47,19 @@ class Position(BaseModel):      def to_index(self, string: str) -> int:          """Convert line and character to index in string"""          lines = string.splitlines() -        return sum(map(len, lines[:self.line])) + self.character +        return sum(map(len, lines[: self.line])) + self.character  class Range(BaseModel):      """A range in a file. 0-indexed.""" +      start: Position      end: Position      def __lt__(self, other: "Range") -> bool: -        return self.start < other.start or (self.start == other.start and self.end < other.end) +        return self.start < other.start or ( +            self.start == other.start and self.end < other.end +        )      def __eq__(self, other: "Range") -> bool:          return self.start == other.start and self.end == other.end @@ -78,10 +82,13 @@ class Range(BaseModel):          if len(lines) == 0:              return (0, 0) -        start_index = sum( -            [len(line) + 1 for line in lines[:self.start.line]]) + self.start.character -        end_index = sum( -            [len(line) + 1 for line in lines[:self.end.line]]) + self.end.character +        start_index = ( +            sum([len(line) + 1 for line in lines[: self.start.line]]) +            + self.start.character +        ) +        end_index = ( +            sum([len(line) + 1 for line in lines[: self.end.line]]) + self.end.character +        )          return (start_index, end_index)      def overlaps_with(self, other: "Range") -> bool: @@ -90,27 +97,23 @@ class Range(BaseModel):      def to_full_lines(self) -> "Range":          return Range(              start=Position(line=self.start.line, character=0), -            end=Position(line=self.end.line + 1, character=0) +            end=Position(line=self.end.line + 1, character=0),          )      @staticmethod      def from_indices(string: str, start_index: int, end_index: int) -> "Range":          return Range(              start=Position.from_index(string, start_index), -            end=Position.from_index(string, end_index) +            end=Position.from_index(string, end_index),          )      @staticmethod -    def from_shorthand(start_line: int, start_char: int, end_line: int, end_char: int) -> "Range": +    def from_shorthand( +        start_line: int, start_char: int, end_line: int, end_char: int +    ) -> "Range":          return Range( -            start=Position( -                line=start_line, -                character=start_char -            ), -            end=Position( -                line=end_line, -                character=end_char -            ) +            start=Position(line=start_line, character=start_char), +            end=Position(line=end_line, character=end_char),          )      @staticmethod @@ -148,7 +151,9 @@ class Range(BaseModel):          if start_line == -1 or end_line == -1:              raise ValueError("Snippet not found in content") -        return Range.from_shorthand(start_line, 0, end_line, len(content_lines[end_line]) - 1) +        return Range.from_shorthand( +            start_line, 0, end_line, len(content_lines[end_line]) - 1 +        )      @staticmethod      def from_position(position: Position) -> "Range": @@ -160,7 +165,8 @@ class AbstractModel(ABC, BaseModel):      def check_is_subclass(cls, values):          if not issubclass(cls, AbstractModel):              raise TypeError( -                "AbstractModel subclasses must be subclasses of AbstractModel") +                "AbstractModel subclasses must be subclasses of AbstractModel" +            )  class TracebackFrame(BaseModel): @@ -170,7 +176,11 @@ class TracebackFrame(BaseModel):      code: Union[str, None]      def __eq__(self, other): -        return self.filepath == other.filepath and self.lineno == other.lineno and self.function == other.function +        return ( +            self.filepath == other.filepath +            and self.lineno == other.lineno +            and self.function == other.function +        )  class Traceback(BaseModel): diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py index 7a53e87a..c8345d02 100644 --- a/continuedev/src/continuedev/plugins/context_providers/diff.py +++ b/continuedev/src/continuedev/plugins/context_providers/diff.py @@ -1,9 +1,8 @@  import subprocess  from typing import List -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId  class DiffContextProvider(ContextProvider): @@ -21,10 +20,9 @@ class DiffContextProvider(ContextProvider):                  name="Diff",                  description="Reference the output of 'git diff' for the current workspace",                  id=ContextItemId( -                    provider_title=self.title, -                    item_id=self.DIFF_CONTEXT_ITEM_ID -                ) -            ) +                    provider_title=self.title, item_id=self.DIFF_CONTEXT_ITEM_ID +                ), +            ),          )      async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: @@ -35,8 +33,9 @@ class DiffContextProvider(ContextProvider):          if not id.item_id == self.DIFF_CONTEXT_ITEM_ID:              raise Exception("Invalid item id") -        diff = subprocess.check_output( -            ["git", "diff"], cwd=self.workspace_dir).decode("utf-8") +        diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode( +            "utf-8" +        )          ctx_item = self.BASE_CONTEXT_ITEM.copy()          ctx_item.content = diff diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py index 42d1f754..3f37232e 100644 --- a/continuedev/src/continuedev/plugins/context_providers/embeddings.py +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -1,13 +1,12 @@  import os -from typing import List, Optional  import uuid +from typing import List, Optional +  from pydantic import BaseModel -from ...core.main import ContextItemId  from ...core.context import ContextProvider  from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...libs.chroma.query import ChromaIndexManager -from .util import remove_meilisearch_disallowed_chars  class EmbeddingResult(BaseModel): @@ -41,10 +40,9 @@ class EmbeddingsProvider(ContextProvider):                  name="Embedding Search",                  description="Enter a query to embedding search codebase",                  id=ContextItemId( -                    provider_title=self.title, -                    item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID -                ) -            ) +                    provider_title=self.title, item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID +                ), +            ),          )      async def _get_query_results(self, query: str) -> str: @@ -53,9 +51,8 @@ class EmbeddingsProvider(ContextProvider):          ret = []          for node in results.source_nodes:              resource_name = list(node.node.relationships.values())[0] -            filepath = resource_name[:resource_name.index("::")] -            ret.append(EmbeddingResult( -                filename=filepath, content=node.node.text)) +            filepath = resource_name[: resource_name.index("::")] +            ret.append(EmbeddingResult(filename=filepath, content=node.node.text))          return ret diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index b40092af..33e20662 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -1,11 +1,10 @@  import os -import re +from fnmatch import fnmatch  from typing import List -from ...core.main import ContextItem, ContextItemDescription, ContextItemId +  from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from .util import remove_meilisearch_disallowed_chars -from fnmatch import fnmatch -  MAX_SIZE_IN_BYTES = 1024 * 1024 * 1 @@ -18,7 +17,7 @@ def get_file_contents(filepath: str) -> str:          with open(filepath, "r") as f:              return f.read() -    except Exception as e: +    except Exception:          # Some files cannot be read, e.g. binary files          return None @@ -40,7 +39,7 @@ DEFAULT_IGNORE_DIRS = [      ".pytest_cache",      ".vscode-test",      ".continue", -    "__pycache__" +    "__pycache__",  ] @@ -50,14 +49,18 @@ class FileContextProvider(ContextProvider):      """      title = "file" -    ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + \ -        list(filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)) +    ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + list( +        filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) +    )      async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:          absolute_filepaths: List[str] = []          for root, dir_names, file_names in os.walk(workspace_dir): -            dir_names[:] = [d for d in dir_names if not any( -                fnmatch(d, pattern) for pattern in self.ignore_patterns)] +            dir_names[:] = [ +                d +                for d in dir_names +                if not any(fnmatch(d, pattern) for pattern in self.ignore_patterns) +            ]              for file_name in file_names:                  absolute_filepaths.append(os.path.join(root, file_name)) @@ -72,20 +75,24 @@ class FileContextProvider(ContextProvider):              content = get_file_contents(absolute_filepath)              if content is None:                  continue  # no pun intended -             +              relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir) -            items.append(ContextItem( -                content=content[:min(2000, len(content))], -                description=ContextItemDescription( -                    name=os.path.basename(absolute_filepath), -                    # We should add the full path to the ContextItem -                    # It warrants a data modeling discussion and has no immediate use case -                    description=relative_to_workspace, -                    id=ContextItemId( -                        provider_title=self.title, -                        item_id=remove_meilisearch_disallowed_chars(absolute_filepath) -                    ) +            items.append( +                ContextItem( +                    content=content[: min(2000, len(content))], +                    description=ContextItemDescription( +                        name=os.path.basename(absolute_filepath), +                        # We should add the full path to the ContextItem +                        # It warrants a data modeling discussion and has no immediate use case +                        description=relative_to_workspace, +                        id=ContextItemId( +                            provider_title=self.title, +                            item_id=remove_meilisearch_disallowed_chars( +                                absolute_filepath +                            ), +                        ), +                    ),                  ) -            )) +            )          return items diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py index c7b4806b..ea86f214 100644 --- a/continuedev/src/continuedev/plugins/context_providers/filetree.py +++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py @@ -1,21 +1,19 @@ -import json -from typing import List  import os -import aiohttp +from typing import List -from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId  def format_file_tree(startpath) -> str:      result = ""      for root, dirs, files in os.walk(startpath): -        level = root.replace(startpath, '').count(os.sep) -        indent = ' ' * 4 * (level) -        result += '{}{}/'.format(indent, os.path.basename(root)) + "\n" -        subindent = ' ' * 4 * (level + 1) +        level = root.replace(startpath, "").count(os.sep) +        indent = " " * 4 * (level) +        result += "{}{}/".format(indent, os.path.basename(root)) + "\n" +        subindent = " " * 4 * (level + 1)          for f in files: -            result += '{}{}'.format(subindent, f) + "\n" +            result += "{}{}".format(subindent, f) + "\n"      return result @@ -31,11 +29,8 @@ class FileTreeContextProvider(ContextProvider):              description=ContextItemDescription(                  name="File Tree",                  description="Add a formatted file tree of this directory to the context", -                id=ContextItemId( -                    provider_title=self.title, -                    item_id=self.title -                ) -            ) +                id=ContextItemId(provider_title=self.title, item_id=self.title), +            ),          )      async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py index 2e7047f2..d394add1 100644 --- a/continuedev/src/continuedev/plugins/context_providers/github.py +++ b/continuedev/src/continuedev/plugins/context_providers/github.py @@ -1,8 +1,13 @@  from typing import List -from github import Github -from github import Auth -from ...core.context import ContextProvider, ContextItemDescription, ContextItem, ContextItemId +from github import Auth, Github + +from ...core.context import ( +    ContextItem, +    ContextItemDescription, +    ContextItemId, +    ContextProvider, +)  class GitHubIssuesContextProvider(ContextProvider): @@ -22,14 +27,14 @@ class GitHubIssuesContextProvider(ContextProvider):          repo = gh.get_repo(self.repo_name)          issues = repo.get_issues().get_page(0) -        return [ContextItem( -            content=issue.body, -            description=ContextItemDescription( -                name=f"Issue #{issue.number}", -                description=issue.title, -                id=ContextItemId( -                    provider_title=self.title, -                    item_id=issue.id -                ) +        return [ +            ContextItem( +                content=issue.body, +                description=ContextItemDescription( +                    name=f"Issue #{issue.number}", +                    description=issue.title, +                    id=ContextItemId(provider_title=self.title, item_id=issue.id), +                ),              ) -        ) for issue in issues] +            for issue in issues +        ] diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index 4b0a59ec..d716c9d3 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -2,9 +2,10 @@ import json  from typing import List  import aiohttp -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId +  from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from .util import remove_meilisearch_disallowed_chars  class GoogleContextProvider(ContextProvider): @@ -22,22 +23,16 @@ class GoogleContextProvider(ContextProvider):                  name="Google Search",                  description="Enter a query to search google",                  id=ContextItemId( -                    provider_title=self.title, -                    item_id=self.GOOGLE_CONTEXT_ITEM_ID -                ) -            ) +                    provider_title=self.title, item_id=self.GOOGLE_CONTEXT_ITEM_ID +                ), +            ),          )      async def _google_search(self, query: str) -> str:          url = "https://google.serper.dev/search" -        payload = json.dumps({ -            "q": query -        }) -        headers = { -            'X-API-KEY': self.serper_api_key, -            'Content-Type': 'application/json' -        } +        payload = json.dumps({"q": query}) +        headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"}          async with aiohttp.ClientSession() as session:              async with session.post(url, headers=headers, data=payload) as response: @@ -61,6 +56,5 @@ class GoogleContextProvider(ContextProvider):          ctx_item = self.BASE_CONTEXT_ITEM.copy()          ctx_item.content = content -        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( -            query) +        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query)          return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 750775ac..ed293124 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -1,11 +1,17 @@  import os  from typing import Any, Dict, List -from ...core.main import ChatMessage -from ...models.filesystem import RangeInFile, RangeInFileWithContents -from ...core.context import ContextItem, ContextItemDescription, ContextItemId, ContextProvider  from pydantic import BaseModel +from ...core.context import ( +    ContextItem, +    ContextItemDescription, +    ContextItemId, +    ContextProvider, +) +from ...core.main import ChatMessage +from ...models.filesystem import RangeInFileWithContents +  class HighlightedRangeContextItem(BaseModel):      rif: RangeInFileWithContents @@ -40,12 +46,10 @@ class HighlightedCodeContextProvider(ContextProvider):          visible_files = await self.ide.getVisibleFiles()          if len(visible_files) > 0:              content = await self.ide.readFile(visible_files[0]) -            rif = RangeInFileWithContents.from_entire_file( -                visible_files[0], content) +            rif = RangeInFileWithContents.from_entire_file(visible_files[0], content)              item = self._rif_to_context_item(rif, 0, True) -            item.description.name = self._rif_to_name( -                rif, show_line_nums=False) +            item.description.name = self._rif_to_name(rif, show_line_nums=False)              self.last_added_fallback = True              return HighlightedRangeContextItem(rif=rif, item=item) @@ -55,21 +59,28 @@ class HighlightedCodeContextProvider(ContextProvider):      async def get_selected_items(self) -> List[ContextItem]:          items = [hr.item for hr in self.highlighted_ranges] -        if len(items) == 0 and (fallback_item := await self._get_fallback_context_item()): +        if len(items) == 0 and ( +            fallback_item := await self._get_fallback_context_item() +        ):              items = [fallback_item.item]          return items      async def get_chat_messages(self) -> List[ContextItem]:          ranges = self.highlighted_ranges -        if len(ranges) == 0 and (fallback_item := await self._get_fallback_context_item()): +        if len(ranges) == 0 and ( +            fallback_item := await self._get_fallback_context_item() +        ):              ranges = [fallback_item] -        return [ChatMessage( -            role="user", -            content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```", -            summary=f"Code in this file is highlighted: {r.rif.filepath}" -        ) for r in ranges] +        return [ +            ChatMessage( +                role="user", +                content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```", +                summary=f"Code in this file is highlighted: {r.rif.filepath}", +            ) +            for r in ranges +        ]      def _make_sure_is_editing_range(self):          """If none of the highlighted ranges are currently being edited, the first should be selected""" @@ -80,8 +91,9 @@ class HighlightedCodeContextProvider(ContextProvider):      def _disambiguate_highlighted_ranges(self):          """If any files have the same name, also display their folder name""" -        name_status: Dict[str, set] = { -        }  # basename -> set of full paths with that basename +        name_status: Dict[ +            str, set +        ] = {}  # basename -> set of full paths with that basename          for hr in self.highlighted_ranges:              basename = os.path.basename(hr.rif.filepath)              if basename in name_status: @@ -92,11 +104,16 @@ class HighlightedCodeContextProvider(ContextProvider):          for hr in self.highlighted_ranges:              basename = os.path.basename(hr.rif.filepath)              if len(name_status[basename]) > 1: -                hr.item.description.name = self._rif_to_name(hr.rif, display_filename=os.path.join( -                    os.path.basename(os.path.dirname(hr.rif.filepath)), basename)) +                hr.item.description.name = self._rif_to_name( +                    hr.rif, +                    display_filename=os.path.join( +                        os.path.basename(os.path.dirname(hr.rif.filepath)), basename +                    ), +                )              else:                  hr.item.description.name = self._rif_to_name( -                    hr.rif, display_filename=basename) +                    hr.rif, display_filename=basename +                )      async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:          return [] @@ -110,7 +127,9 @@ class HighlightedCodeContextProvider(ContextProvider):          self.should_get_fallback_context_item = True          self.last_added_fallback = False -    async def delete_context_with_ids(self, ids: List[ContextItemId]) -> List[ContextItem]: +    async def delete_context_with_ids( +        self, ids: List[ContextItemId] +    ) -> List[ContextItem]:          ids_to_delete = [id.item_id for id in ids]          kept_ranges = [] @@ -126,36 +145,57 @@ class HighlightedCodeContextProvider(ContextProvider):          return [hr.item for hr in self.highlighted_ranges] -    def _rif_to_name(self, rif: RangeInFileWithContents, display_filename: str = None, show_line_nums: bool = True) -> str: -        line_nums = f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" if show_line_nums else "" +    def _rif_to_name( +        self, +        rif: RangeInFileWithContents, +        display_filename: str = None, +        show_line_nums: bool = True, +    ) -> str: +        line_nums = ( +            f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" +            if show_line_nums +            else "" +        )          return f"{display_filename or os.path.basename(rif.filepath)}{line_nums}" -    def _rif_to_context_item(self, rif: RangeInFileWithContents, idx: int, editing: bool) -> ContextItem: +    def _rif_to_context_item( +        self, rif: RangeInFileWithContents, idx: int, editing: bool +    ) -> ContextItem:          return ContextItem(              description=ContextItemDescription(                  name=self._rif_to_name(rif),                  description=rif.filepath, -                id=ContextItemId( -                    provider_title=self.title, -                    item_id=str(idx) -                ) +                id=ContextItemId(provider_title=self.title, item_id=str(idx)),              ),              content=rif.contents,              editing=editing, -            editable=True +            editable=True,          ) -    async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): +    async def handle_highlighted_code( +        self, range_in_files: List[RangeInFileWithContents] +    ):          self.should_get_fallback_context_item = True          self.last_added_fallback = False          # Filter out rifs from ~/.continue/diffs folder          range_in_files = [ -            rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")] +            rif +            for rif in range_in_files +            if not os.path.dirname(rif.filepath) +            == os.path.expanduser("~/.continue/diffs") +        ]          # If not adding highlighted code          if not self.adding_highlighted_code: -            if len(self.highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end): +            if ( +                len(self.highlighted_ranges) == 1 +                and len(range_in_files) <= 1 +                and ( +                    len(range_in_files) == 0 +                    or range_in_files[0].range.start == range_in_files[0].range.end +                ) +            ):                  # If un-highlighting the range to edit, then remove the range                  self.highlighted_ranges = []              elif len(range_in_files) > 0: @@ -164,7 +204,9 @@ class HighlightedCodeContextProvider(ContextProvider):                  self.highlighted_ranges = [                      HighlightedRangeContextItem(                          rif=range_in_files[0], -                        item=self._rif_to_context_item(range_in_files[0], 0, True))] +                        item=self._rif_to_context_item(range_in_files[0], 0, True), +                    ) +                ]              return @@ -173,22 +215,36 @@ class HighlightedCodeContextProvider(ContextProvider):          for i, hr in enumerate(self.highlighted_ranges):              found_overlap = False              for new_rif in range_in_files: -                if hr.rif.filepath == new_rif.filepath and hr.rif.range.overlaps_with(new_rif.range): +                if hr.rif.filepath == new_rif.filepath and hr.rif.range.overlaps_with( +                    new_rif.range +                ):                      found_overlap = True                      break                  # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids                  # the bug where cmd+f causes repeated highlights -                if hr.rif.filepath == new_rif.filepath and hr.rif.contents == new_rif.contents: +                if ( +                    hr.rif.filepath == new_rif.filepath +                    and hr.rif.contents == new_rif.contents +                ):                      found_overlap = True                      break              if not found_overlap: -                new_ranges.append(HighlightedRangeContextItem(rif=hr.rif, item=self._rif_to_context_item( -                    hr.rif, len(new_ranges), False))) +                new_ranges.append( +                    HighlightedRangeContextItem( +                        rif=hr.rif, +                        item=self._rif_to_context_item(hr.rif, len(new_ranges), False), +                    ) +                ) -        self.highlighted_ranges = new_ranges + [HighlightedRangeContextItem(rif=rif, item=self._rif_to_context_item( -            rif, len(new_ranges) + idx, False)) for idx, rif in enumerate(range_in_files)] +        self.highlighted_ranges = new_ranges + [ +            HighlightedRangeContextItem( +                rif=rif, +                item=self._rif_to_context_item(rif, len(new_ranges) + idx, False), +            ) +            for idx, rif in enumerate(range_in_files) +        ]          self._make_sure_is_editing_range()          self._disambiguate_highlighted_ranges() @@ -197,5 +253,7 @@ class HighlightedCodeContextProvider(ContextProvider):          for hr in self.highlighted_ranges:              hr.item.editing = hr.item.description.id.to_string() in ids -    async def add_context_item(self, id: ContextItemId, query: str, prev: List[ContextItem] = None) -> List[ContextItem]: +    async def add_context_item( +        self, id: ContextItemId, query: str, prev: List[ContextItem] = None +    ) -> List[ContextItem]:          raise NotImplementedError() diff --git a/continuedev/src/continuedev/plugins/context_providers/search.py b/continuedev/src/continuedev/plugins/context_providers/search.py index da991a78..6aecb5bf 100644 --- a/continuedev/src/continuedev/plugins/context_providers/search.py +++ b/continuedev/src/continuedev/plugins/context_providers/search.py @@ -1,10 +1,11 @@  import os  from typing import List +  from ripgrepy import Ripgrepy -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from .util import remove_meilisearch_disallowed_chars  class SearchContextProvider(ContextProvider): @@ -22,17 +23,16 @@ class SearchContextProvider(ContextProvider):                  name="Search",                  description="Search the workspace for all matches of an exact string (e.g. '@search console.log')",                  id=ContextItemId( -                    provider_title=self.title, -                    item_id=self.SEARCH_CONTEXT_ITEM_ID -                ) -            ) +                    provider_title=self.title, item_id=self.SEARCH_CONTEXT_ITEM_ID +                ), +            ),          )      def _get_rg_path(self): -        if os.name == 'nt': +        if os.name == "nt":              rg_path = f"C:\\Users\\{os.getlogin()}\\AppData\\Local\\Programs\\Microsoft VS Code\\resources\\app\\node_modules.asar.unpacked\\vscode-ripgrep\\bin\\rg.exe" -        elif os.name == 'posix': -            if 'darwin' in os.sys.platform: +        elif os.name == "posix": +            if "darwin" in os.sys.platform:                  rg_path = "/Applications/Visual Studio Code.app/Contents/Resources/app/node_modules.asar.unpacked/@vscode/ripgrep/bin/rg"              else:                  rg_path = "/usr/share/code/resources/app/node_modules.asar.unpacked/vscode-ripgrep/bin/rg" @@ -87,6 +87,5 @@ class SearchContextProvider(ContextProvider):          ctx_item = self.BASE_CONTEXT_ITEM.copy()          ctx_item.content = results          ctx_item.description.name = f"Search: '{query}'" -        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( -            query) +        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query)          return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/url.py b/continuedev/src/continuedev/plugins/context_providers/url.py index 5b67608d..a5ec8990 100644 --- a/continuedev/src/continuedev/plugins/context_providers/url.py +++ b/continuedev/src/continuedev/plugins/context_providers/url.py @@ -1,11 +1,11 @@ -  from typing import List -from bs4 import BeautifulSoup +  import requests +from bs4 import BeautifulSoup -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId  from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from .util import remove_meilisearch_disallowed_chars  class URLContextProvider(ContextProvider): @@ -30,10 +30,9 @@ class URLContextProvider(ContextProvider):                  name="Dynamic URL",                  description="Reference the contents of a webpage (e.g. '@url https://www.w3schools.com/python/python_ref_functions.asp')",                  id=ContextItemId( -                    provider_title=self.title, -                    item_id=self.DYNAMIC_URL_CONTEXT_ITEM_ID -                ) -            ) +                    provider_title=self.title, item_id=self.DYNAMIC_URL_CONTEXT_ITEM_ID +                ), +            ),          )      def static_url_context_item_from_url(self, url: str) -> ContextItem: @@ -45,30 +44,36 @@ class URLContextProvider(ContextProvider):                  description=f"Contents of {url}",                  id=ContextItemId(                      provider_title=self.title, -                    item_id=remove_meilisearch_disallowed_chars(url) -                ) -            ) +                    item_id=remove_meilisearch_disallowed_chars(url), +                ), +            ),          )      def _get_url_text_contents_and_title(self, url: str) -> (str, str):          response = requests.get(url) -        soup = BeautifulSoup(response.text, 'html.parser') -        title = url.replace( -            "https://", "").replace("http://", "").replace("www.", "") +        soup = BeautifulSoup(response.text, "html.parser") +        title = url.replace("https://", "").replace("http://", "").replace("www.", "")          if soup.title is not None:              title = soup.title.string          return soup.get_text(), title      async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:          self.static_url_context_items = [ -            self.static_url_context_item_from_url(url) for url in self.preset_urls] +            self.static_url_context_item_from_url(url) for url in self.preset_urls +        ]          return [self.DYNAMIC_CONTEXT_ITEM] + self.static_url_context_items      async def get_item(self, id: ContextItemId, query: str) -> ContextItem:          # Check if the item is a static item          matching_static_item = next( -            (item for item in self.static_url_context_items if item.description.id.item_id == id.item_id), None) +            ( +                item +                for item in self.static_url_context_items +                if item.description.id.item_id == id.item_id +            ), +            None, +        )          if matching_static_item:              return matching_static_item @@ -85,6 +90,5 @@ class URLContextProvider(ContextProvider):          ctx_item = self.DYNAMIC_CONTEXT_ITEM.copy()          ctx_item.content = content          ctx_item.description.name = title -        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( -            url) +        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(url)          return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py index da2e6b17..61bea8aa 100644 --- a/continuedev/src/continuedev/plugins/context_providers/util.py +++ b/continuedev/src/continuedev/plugins/context_providers/util.py @@ -2,4 +2,4 @@ import re  def remove_meilisearch_disallowed_chars(id: str) -> str: -    return re.sub(r'[^0-9a-zA-Z_-]', '', id) +    return re.sub(r"[^0-9a-zA-Z_-]", "", id) diff --git a/continuedev/src/continuedev/plugins/policies/default.py b/continuedev/src/continuedev/plugins/policies/default.py index 2382f33a..ef88c8d6 100644 --- a/continuedev/src/continuedev/plugins/policies/default.py +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -1,15 +1,15 @@  from textwrap import dedent  from typing import Type, Union -from ..steps.chat import SimpleChatStep -from ..steps.welcome import WelcomeStep  from ...core.config import ContinueConfig -from ..steps.steps_on_startup import StepsOnStartupStep -from ...core.main import Step, History, Policy +from ...core.main import History, Policy, Step  from ...core.observation import UserInputObservation +from ..steps.chat import SimpleChatStep  from ..steps.core.core import MessageStep  from ..steps.custom_command import CustomCommandStep  from ..steps.main import EditHighlightedCodeStep +from ..steps.steps_on_startup import StepsOnStartupStep +from ..steps.welcome import WelcomeStep  def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: @@ -28,7 +28,8 @@ def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]:                      return slash_command.step(**params)                  except TypeError as e:                      raise Exception( -                        f"Incorrect params used for slash command '{command_name}': {e}") +                        f"Incorrect params used for slash command '{command_name}': {e}" +                    )      return None @@ -40,12 +41,17 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:              slash_command = parse_slash_command(custom_cmd.prompt, config)              if slash_command is not None:                  return slash_command -            return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name) +            return CustomCommandStep( +                name=custom_cmd.name, +                description=custom_cmd.description, +                prompt=custom_cmd.prompt, +                user_input=after_command, +                slash_command=command_name, +            )      return None  class DefaultPolicy(Policy): -      default_step: Type[Step] = SimpleChatStep      default_params: dict = {} @@ -53,13 +59,19 @@ class DefaultPolicy(Policy):          # At the very start, run initial Steps spcecified in the config          if history.get_current() is None:              return ( -                MessageStep(name="Welcome to Continue", message=dedent("""\ +                MessageStep( +                    name="Welcome to Continue", +                    message=dedent( +                        """\                      - Highlight code section and ask a question or give instructions                      - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue                      - Use `/help` to ask questions about how to use Continue -                    - [Customize Continue](https://continue.dev/docs/customization) (e.g. use your own API key) by typing '/config'.""")) >> -                WelcomeStep() >> -                StepsOnStartupStep()) +                    - [Customize Continue](https://continue.dev/docs/customization) (e.g. use your own API key) by typing '/config'.""" +                    ), +                ) +                >> WelcomeStep() +                >> StepsOnStartupStep() +            )          observation = history.get_current().observation          if observation is not None and isinstance(observation, UserInputObservation): diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md index d735e0cd..78d603a2 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md @@ -3,6 +3,7 @@  Uses the Chess.com API example to show how to add map and filter Python transforms to a dlt pipeline.
  Background
 +
  - https://dlthub.com/docs/general-usage/resource#filter-transform-and-pivot-data
  - https://dlthub.com/docs/customizations/customizing-pipelines/renaming_columns
 -- https://dlthub.com/docs/customizations/customizing-pipelines/pseudonymizing_columns
\ No newline at end of file +- https://dlthub.com/docs/customizations/customizing-pipelines/pseudonymizing_columns
 diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md index 658b285f..864aea87 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md @@ -1,15 +1,19 @@  # Customize resources +  ## Filter, transform and pivot data  You can attach any number of transformations that are evaluated on item per item basis to your resource. The available transformation types: +  - map - transform the data item (resource.add_map)  - filter - filter the data item (resource.add_filter)  - yield map - a map that returns iterator (so single row may generate many rows - resource.add_yield_map)  Example: We have a resource that loads a list of users from an api endpoint. We want to customize it so: +  - we remove users with user_id == 'me'  - we anonymize user data -Here's our resource: +  Here's our resource: +  ```python  import dlt @@ -22,6 +26,7 @@ def users():  ```  Here's our script that defines transformations and loads the data. +  ```python  from pipedrive import users @@ -34,9 +39,9 @@ def anonymize_user(user_data):  for user in users().add_filter(lambda user: user['user_id'] != 'me').add_map(anonymize_user):  print(user)  ``` -                 +  Here is a more complex example of a filter transformation: -                 +      # Renaming columns      ## Renaming columns by replacing the special characters @@ -80,11 +85,13 @@ Here is a more complex example of a filter transformation:      # {'Objekt_0': {'Groesse': 0, 'Aequivalenzpruefung': True}}      # ...      ``` -                 +  Here is a more complex example of a map transformation: -                 +  # Pseudonymizing columns +  ## Pseudonymizing (or anonymizing) columns by replacing the special characters +  Pseudonymization is a deterministic way to hide personally identifiable info (PII), enabling us to consistently achieve the same mapping. If instead you wish to anonymize, you can delete the data, or replace it with a constant. In the example below, we create a dummy source with a PII column called 'name', which we replace with deterministic hashes (i.e. replacing the German umlaut).  ```python @@ -132,4 +139,4 @@ def pseudonymize_name(doc):      pipeline = dlt.pipeline(pipeline_name='example', destination='bigquery', dataset_name='normalized_data')      load_info = pipeline.run(data_source) -```
\ No newline at end of file +``` diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py index 5d242f7c..54207399 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py @@ -2,9 +2,8 @@ from textwrap import dedent  from ....core.main import Step  from ....core.sdk import ContinueSDK -from ....plugins.steps.core.core import WaitForUserInputStep -from ....plugins.steps.core.core import MessageStep -from .steps import SetUpChessPipelineStep, AddTransformStep +from ....plugins.steps.core.core import MessageStep, WaitForUserInputStep +from .steps import AddTransformStep, SetUpChessPipelineStep  class AddTransformRecipe(Step): @@ -12,16 +11,21 @@ class AddTransformRecipe(Step):      async def run(self, sdk: ContinueSDK):          text_observation = await sdk.run_step( -            MessageStep(message=dedent("""\ +            MessageStep( +                message=dedent( +                    """\                  This recipe will walk you through the process of adding a transform to a dlt pipeline that uses the chess.com API source. With the help of Continue, you will:                  - Set up a dlt pipeline for the chess.com API                  - Add a filter or map transform to the pipeline -                - Run the pipeline and view the transformed data in a Streamlit app"""), name="Add transformation to a dlt pipeline") >> -            SetUpChessPipelineStep() >> -            WaitForUserInputStep( -                prompt="How do you want to transform the Chess.com API data before loading it? For example, you could filter out games that ended in a draw.") +                - Run the pipeline and view the transformed data in a Streamlit app""" +                ), +                name="Add transformation to a dlt pipeline", +            ) +            >> SetUpChessPipelineStep() +            >> WaitForUserInputStep( +                prompt="How do you want to transform the Chess.com API data before loading it? For example, you could filter out games that ended in a draw." +            )          )          await sdk.run_step( -            AddTransformStep( -                transform_description=text_observation.text) +            AddTransformStep(transform_description=text_observation.text)          ) diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py index e589fc36..0091af97 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py @@ -1,12 +1,10 @@  import os  from textwrap import dedent - -from ....plugins.steps.core.core import MessageStep -from ....libs.util.paths import find_data_file -from ....core.sdk import Models  from ....core.main import Step -from ....core.sdk import ContinueSDK +from ....core.sdk import ContinueSDK, Models +from ....libs.util.paths import find_data_file +from ....plugins.steps.core.core import MessageStep  AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -19,21 +17,26 @@ class SetUpChessPipelineStep(Step):          return "This step will create a new dlt pipeline that loads data from the chess.com API."      async def run(self, sdk: ContinueSDK): -          # running commands to get started when creating a new dlt pipeline -        await sdk.run([ -            'python3 -m venv .env', -            'source .env/bin/activate', -            'pip install dlt', -            'dlt --non-interactive init chess duckdb', -            'pip install -r requirements.txt', -            'pip install pandas streamlit'  # Needed for the pipeline show step later -        ], name="Set up Python environment", description=dedent(f"""\ +        await sdk.run( +            [ +                "python3 -m venv .env", +                "source .env/bin/activate", +                "pip install dlt", +                "dlt --non-interactive init chess duckdb", +                "pip install -r requirements.txt", +                "pip install pandas streamlit",  # Needed for the pipeline show step later +            ], +            name="Set up Python environment", +            description=dedent( +                """\              - Create a Python virtual environment: `python3 -m venv .env`              - Activate the virtual environment: `source .env/bin/activate`              - Install dlt: `pip install dlt`              - Create a new dlt pipeline called "chess" that loads data into a local DuckDB instance: `dlt init chess duckdb` -            - Install the Python dependencies for the pipeline: `pip install -r requirements.txt`""")) +            - Install the Python dependencies for the pipeline: `pip install -r requirements.txt`""" +            ), +        )  class AddTransformStep(Step): @@ -43,42 +46,61 @@ class AddTransformStep(Step):      transform_description: str      async def run(self, sdk: ContinueSDK): -        source_name = 'chess' -        filename = f'{source_name}_pipeline.py' +        source_name = "chess" +        filename = f"{source_name}_pipeline.py"          abs_filepath = os.path.join(sdk.ide.workspace_directory, filename)          # Open the file and highlight the function to be edited          await sdk.ide.setFileOpen(abs_filepath) -        await sdk.run_step(MessageStep(message=dedent("""\ +        await sdk.run_step( +            MessageStep( +                message=dedent( +                    """\                  This step will customize your resource function with a transform of your choice:                  - Add a filter or map transformation depending on your request                  - Load the data into a local DuckDB instance -                - Open up a Streamlit app for you to view the data"""), name="Write transformation function")) +                - Open up a Streamlit app for you to view the data""" +                ), +                name="Write transformation function", +            ) +        ) -        with open(find_data_file('dlt_transform_docs.md')) as f: +        with open(find_data_file("dlt_transform_docs.md")) as f:              dlt_transform_docs = f.read() -        prompt = dedent(f"""\ +        prompt = dedent( +            f"""\              Task: Write a transform function using the description below and then use `add_map` or `add_filter` from the `dlt` library to attach it a resource.              Description: {self.transform_description}              Here are some docs pages that will help you better understand how to use `dlt`. -            {dlt_transform_docs}""") +            {dlt_transform_docs}""" +        )          # edit the pipeline to add a tranform function and attach it to a resource          await sdk.edit_file(              filename=filename,              prompt=prompt, -            name=f"Writing transform function {AI_ASSISTED_STRING}" +            name=f"Writing transform function {AI_ASSISTED_STRING}",          ) -        await sdk.wait_for_user_confirmation("Press Continue to confirm that the changes are okay before we run the pipeline.") +        await sdk.wait_for_user_confirmation( +            "Press Continue to confirm that the changes are okay before we run the pipeline." +        )          # run the pipeline and load the data -        await sdk.run(f'python3 {filename}', name="Run the pipeline", description=f"Running `python3 {filename}` to load the data into a local DuckDB instance") +        await sdk.run( +            f"python3 {filename}", +            name="Run the pipeline", +            description=f"Running `python3 {filename}` to load the data into a local DuckDB instance", +        )          # run a streamlit app to show the data -        await sdk.run(f'dlt pipeline {source_name}_pipeline show', name="Show data in a Streamlit app", description=f"Running `dlt pipeline {source_name} show` to show the data in a Streamlit app, where you can view and play with the data.") +        await sdk.run( +            f"dlt pipeline {source_name}_pipeline show", +            name="Show data in a Streamlit app", +            description=f"Running `dlt pipeline {source_name} show` to show the data in a Streamlit app, where you can view and play with the data.", +        ) diff --git a/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py index c0f9e7e3..e67ea557 100644 --- a/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py @@ -1,7 +1,8 @@  from textwrap import dedent -from ....plugins.steps.main import EditHighlightedCodeStep +  from ....core.main import Step  from ....core.sdk import ContinueSDK +from ....plugins.steps.main import EditHighlightedCodeStep  class ContinueStepStep(Step): @@ -9,7 +10,10 @@ class ContinueStepStep(Step):      prompt: str      async def run(self, sdk: ContinueSDK): -        await sdk.run_step(EditHighlightedCodeStep(user_input=dedent(f"""\ +        await sdk.run_step( +            EditHighlightedCodeStep( +                user_input=dedent( +                    f"""\          Here is an example of a Step that runs a command and then edits a file.          ```python @@ -33,4 +37,7 @@ class ContinueStepStep(Step):          It should be a subclass of Step as above, implementing the `run` method, and using pydantic attributes to define the parameters. -        """))) +        """ +                ) +            ) +        ) diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py index 84363e02..4b259769 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py @@ -1,10 +1,9 @@  from textwrap import dedent -from ....core.sdk import ContinueSDK  from ....core.main import Step -from ....plugins.steps.core.core import WaitForUserInputStep -from ....plugins.steps.core.core import MessageStep -from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep +from ....core.sdk import ContinueSDK +from ....plugins.steps.core.core import MessageStep, WaitForUserInputStep +from .steps import RunQueryStep, SetupPipelineStep, ValidatePipelineStep  class CreatePipelineRecipe(Step): @@ -12,7 +11,10 @@ class CreatePipelineRecipe(Step):      async def run(self, sdk: ContinueSDK):          text_observation = await sdk.run_step( -            MessageStep(name="Building your first dlt pipeline", message=dedent("""\ +            MessageStep( +                name="Building your first dlt pipeline", +                message=dedent( +                    """\                  This recipe will walk you through the process of creating a dlt pipeline for your chosen data source. With the help of Continue, you will:                  - Create a Python virtual environment with dlt installed                  - Run `dlt init` to generate a pipeline template @@ -20,14 +22,19 @@ class CreatePipelineRecipe(Step):                  - Add any required API keys to the `secrets.toml` file                  - Test that the API call works                  - Load the data into a local DuckDB instance -                - Write a query to view the data""")) >> -            WaitForUserInputStep( -                prompt="What API do you want to load data from? (e.g. weatherapi.com, chess.com)") +                - Write a query to view the data""" +                ), +            ) +            >> WaitForUserInputStep( +                prompt="What API do you want to load data from? (e.g. weatherapi.com, chess.com)" +            )          )          await sdk.run_step( -            SetupPipelineStep(api_description=text_observation.text) >> -            ValidatePipelineStep() >> -            RunQueryStep() >> -            MessageStep( -                name="Congrats!", message="You've successfully created your first dlt pipeline! 🎉") +            SetupPipelineStep(api_description=text_observation.text) +            >> ValidatePipelineStep() +            >> RunQueryStep() +            >> MessageStep( +                name="Congrats!", +                message="You've successfully created your first dlt pipeline! 🎉", +            )          ) diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 872f8d62..43a2b800 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -1,13 +1,13 @@  import os -from textwrap import dedent  import time +from textwrap import dedent -from ....models.main import Range -from ....models.filesystem import RangeInFile -from ....plugins.steps.core.core import MessageStep -from ....models.filesystem_edit import AddFile, FileEdit  from ....core.main import Step  from ....core.sdk import ContinueSDK, Models +from ....models.filesystem import RangeInFile +from ....models.filesystem_edit import AddFile, FileEdit +from ....models.main import Range +from ....plugins.steps.core.core import MessageStep  AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -19,50 +19,71 @@ class SetupPipelineStep(Step):      api_description: str  # e.g. "I want to load data from the weatherapi.com API"      async def describe(self, models: Models): -        return dedent(f"""\ +        return dedent( +            f"""\          This step will create a new dlt pipeline that loads data from an API, as per your request:          {self.api_description} -        """) +        """ +        )      async def run(self, sdk: ContinueSDK):          sdk.context.set("api_description", self.api_description) -        source_name = (await sdk.models.medium.complete( -            f"Write a snake_case name for the data source described by {self.api_description}: ")).strip() -        filename = f'{source_name}.py' +        source_name = ( +            await sdk.models.medium.complete( +                f"Write a snake_case name for the data source described by {self.api_description}: " +            ) +        ).strip() +        filename = f"{source_name}.py"          # running commands to get started when creating a new dlt pipeline -        await sdk.run([ -            'python3 -m venv .env', -            'source .env/bin/activate', -            'pip install dlt', -            f'dlt --non-interactive init {source_name} duckdb', -            'pip install -r requirements.txt' -        ], description=dedent(f"""\ +        await sdk.run( +            [ +                "python3 -m venv .env", +                "source .env/bin/activate", +                "pip install dlt", +                f"dlt --non-interactive init {source_name} duckdb", +                "pip install -r requirements.txt", +            ], +            description=dedent( +                f"""\              Running the following commands:              - `python3 -m venv .env`: Create a Python virtual environment              - `source .env/bin/activate`: Activate the virtual environment              - `pip install dlt`: Install dlt              - `dlt init {source_name} duckdb`: Create a new dlt pipeline called {source_name} that loads data into a local DuckDB instance -            - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline"""), name="Setup Python environment") +            - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""" +            ), +            name="Setup Python environment", +        )          # editing the resource function to call the requested API          resource_function_range = Range.from_shorthand(15, 0, 30, 0) -        await sdk.ide.highlightCode(RangeInFile(filepath=os.path.join(await sdk.ide.getWorkspaceDirectory(), filename), range=resource_function_range), "#ffa50033") +        await sdk.ide.highlightCode( +            RangeInFile( +                filepath=os.path.join(await sdk.ide.getWorkspaceDirectory(), filename), +                range=resource_function_range, +            ), +            "#ffa50033", +        )          # sdk.set_loading_message("Writing code to call the API...")          await sdk.edit_file(              range=resource_function_range,              filename=filename, -            prompt=f'Edit the resource function to call the API described by this: {self.api_description}. Do not move or remove the exit() call in __main__.', -            name=f"Edit the resource function to call the API {AI_ASSISTED_STRING}" +            prompt=f"Edit the resource function to call the API described by this: {self.api_description}. Do not move or remove the exit() call in __main__.", +            name=f"Edit the resource function to call the API {AI_ASSISTED_STRING}",          )          time.sleep(1)          # wait for user to put API key in secrets.toml -        await sdk.ide.setFileOpen(await sdk.ide.getWorkspaceDirectory() + "/.dlt/secrets.toml") -        await sdk.wait_for_user_confirmation("If this service requires an API key, please add it to the `secrets.toml` file and then press `Continue`.") +        await sdk.ide.setFileOpen( +            await sdk.ide.getWorkspaceDirectory() + "/.dlt/secrets.toml" +        ) +        await sdk.wait_for_user_confirmation( +            "If this service requires an API key, please add it to the `secrets.toml` file and then press `Continue`." +        )          sdk.context.set("source_name", source_name) @@ -73,7 +94,7 @@ class ValidatePipelineStep(Step):      async def run(self, sdk: ContinueSDK):          workspace_dir = await sdk.ide.getWorkspaceDirectory()          source_name = sdk.context.get("source_name") -        filename = f'{source_name}.py' +        filename = f"{source_name}.py"          # await sdk.run_step(MessageStep(name="Validate the pipeline", message=dedent("""\          #         Next, we will validate that your dlt pipeline is working as expected: @@ -83,13 +104,20 @@ class ValidatePipelineStep(Step):          #         """)))          # test that the API call works -        output = await sdk.run(f'python3 {filename}', name="Test the pipeline", description=f"Running `python3 {filename}` to test loading data from the API", handle_error=False) +        output = await sdk.run( +            f"python3 {filename}", +            name="Test the pipeline", +            description=f"Running `python3 {filename}` to test loading data from the API", +            handle_error=False, +        )          # If it fails, return the error          if "Traceback" in output or "SyntaxError" in output:              output = "Traceback" + output.split("Traceback")[-1]              file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) -            suggestion = await sdk.models.medium.complete(dedent(f"""\ +            suggestion = await sdk.models.medium.complete( +                dedent( +                    f"""\                  ```python                  {file_content}                  ``` @@ -99,59 +127,98 @@ class ValidatePipelineStep(Step):                  {output}                  ``` -                This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""")) +                This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""" +                ) +            ) -            api_documentation_url = await sdk.models.medium.complete(dedent(f"""\ +            api_documentation_url = await sdk.models.medium.complete( +                dedent( +                    f"""\                  The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:                  ```python                         {file_content}                  ``` -                What is the URL for the API documentation that will help me learn how to make this call? Please format in markdown so I can click the link.""")) +                What is the URL for the API documentation that will help me learn how to make this call? Please format in markdown so I can click the link.""" +                ) +            )              sdk.raise_exception( -                title=f"Error while running pipeline.\nFix the resource function in {filename} and rerun this step", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=dedent(f"""\ +                title=f"Error while running pipeline.\nFix the resource function in {filename} and rerun this step", +                message=output, +                with_step=MessageStep( +                    name=f"Suggestion to solve error {AI_ASSISTED_STRING}", +                    message=dedent( +                        f"""\                  {suggestion}                  {api_documentation_url} -                After you've fixed the code, click the retry button at the top of the Validate Pipeline step above."""))) +                After you've fixed the code, click the retry button at the top of the Validate Pipeline step above.""" +                    ), +                ), +            )          # remove exit() from the main main function -        await sdk.run_step(MessageStep(name="Remove early exit() from main function", message="Remove the early exit() from the main function now that we are done testing and want the pipeline to load the data into DuckDB.")) +        await sdk.run_step( +            MessageStep( +                name="Remove early exit() from main function", +                message="Remove the early exit() from the main function now that we are done testing and want the pipeline to load the data into DuckDB.", +            ) +        )          contents = await sdk.ide.readFile(os.path.join(workspace_dir, filename))          replacement = "\n".join( -            list(filter(lambda line: line.strip() != "exit()", contents.split("\n")))) -        await sdk.ide.applyFileSystemEdit(FileEdit( -            filepath=os.path.join(workspace_dir, filename), -            replacement=replacement, -            range=Range.from_entire_file(contents) -        )) +            list(filter(lambda line: line.strip() != "exit()", contents.split("\n"))) +        ) +        await sdk.ide.applyFileSystemEdit( +            FileEdit( +                filepath=os.path.join(workspace_dir, filename), +                replacement=replacement, +                range=Range.from_entire_file(contents), +            ) +        )          # load the data into the DuckDB instance -        await sdk.run(f'python3 {filename}', name="Load data into DuckDB", description=f"Running python3 {filename} to load data into DuckDB") +        await sdk.run( +            f"python3 {filename}", +            name="Load data into DuckDB", +            description=f"Running python3 {filename} to load data into DuckDB", +        ) -        tables_query_code = dedent(f'''\ +        tables_query_code = dedent( +            f"""\              import duckdb              # connect to DuckDB instance              conn = duckdb.connect(database="{source_name}.duckdb")              # list all tables -            print(conn.sql("DESCRIBE"))''') +            print(conn.sql("DESCRIBE"))""" +        )          query_filename = os.path.join(workspace_dir, "query.py") -        await sdk.apply_filesystem_edit(AddFile(filepath=query_filename, content=tables_query_code), name="Add query.py file", description="Adding a file called `query.py` to the workspace that will run a test query on the DuckDB instance") +        await sdk.apply_filesystem_edit( +            AddFile(filepath=query_filename, content=tables_query_code), +            name="Add query.py file", +            description="Adding a file called `query.py` to the workspace that will run a test query on the DuckDB instance", +        )  class RunQueryStep(Step):      hide: bool = True      async def run(self, sdk: ContinueSDK): -        output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False) +        output = await sdk.run( +            ".env/bin/python3 query.py", +            name="Run test query", +            description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", +            handle_error=False, +        )          if "Traceback" in output or "SyntaxError" in output: -            suggestion = await sdk.models.medium.complete(dedent(f"""\ +            suggestion = await sdk.models.medium.complete( +                dedent( +                    f"""\                  ```python                  {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}                  ``` @@ -161,8 +228,16 @@ class RunQueryStep(Step):                  {output}                  ``` -                This is a brief summary of the error followed by a suggestion on how it can be fixed:""")) +                This is a brief summary of the error followed by a suggestion on how it can be fixed:""" +                ) +            )              sdk.raise_exception( -                title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=suggestion + "\n\nIt is also very likely that no duckdb table was created, which can happen if the resource function did not yield any data. Please make sure that it is yielding data and then rerun this step.") +                title="Error while running query", +                message=output, +                with_step=MessageStep( +                    name=f"Suggestion to solve error {AI_ASSISTED_STRING}", +                    message=suggestion +                    + "\n\nIt is also very likely that no duckdb table was created, which can happen if the resource function did not yield any data. Please make sure that it is yielding data and then rerun this step.", +                ),              ) diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md index c4981e56..d50324f7 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md @@ -1,3 +1,3 @@  # DDtoBQRecipe -Move from using DuckDB to Google BigQuery as the destination for your `dlt` pipeline
\ No newline at end of file +Move from using DuckDB to Google BigQuery as the destination for your `dlt` pipeline diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py index 5b6aa8f0..5348321d 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py @@ -3,7 +3,7 @@ from textwrap import dedent  from ....core.main import Step  from ....core.sdk import ContinueSDK  from ....plugins.steps.core.core import MessageStep -from .steps import SetUpChessPipelineStep, SwitchDestinationStep, LoadDataStep +from .steps import LoadDataStep, SetUpChessPipelineStep, SwitchDestinationStep  # Based on the following guide:  # https://github.com/dlt-hub/dlt/pull/392 @@ -14,13 +14,18 @@ class DDtoBQRecipe(Step):      async def run(self, sdk: ContinueSDK):          await sdk.run_step( -            MessageStep(name="Move from using DuckDB to Google BigQuery as the destination", message=dedent("""\ +            MessageStep( +                name="Move from using DuckDB to Google BigQuery as the destination", +                message=dedent( +                    """\                  This recipe will walk you through the process of moving from using DuckDB to Google BigQuery as the destination for your dlt pipeline. With the help of Continue, you will:                  - Set up a dlt pipeline for the chess.com API                  - Switch destination from DuckDB to Google BigQuery                  - Add BigQuery credentials to your secrets.toml file -                - Run the pipeline again to load data to BigQuery""")) >> -            SetUpChessPipelineStep() >> -            SwitchDestinationStep() >> -            LoadDataStep() +                - Run the pipeline again to load data to BigQuery""" +                ), +            ) +            >> SetUpChessPipelineStep() +            >> SwitchDestinationStep() +            >> LoadDataStep()          ) diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py index 14972142..d6769148 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py @@ -1,12 +1,11 @@  import os  from textwrap import dedent -from ....plugins.steps.find_and_replace import FindAndReplaceStep -from ....plugins.steps.core.core import MessageStep -from ....core.sdk import Models  from ....core.main import Step -from ....core.sdk import ContinueSDK +from ....core.sdk import ContinueSDK, Models  from ....libs.util.paths import find_data_file +from ....plugins.steps.core.core import MessageStep +from ....plugins.steps.find_and_replace import FindAndReplaceStep  AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -19,49 +18,61 @@ class SetUpChessPipelineStep(Step):          return "This step will create a new dlt pipeline that loads data from the chess.com API."      async def run(self, sdk: ContinueSDK): -          # running commands to get started when creating a new dlt pipeline -        await sdk.run([ -            'python3 -m venv .env', -            'source .env/bin/activate', -            'pip install dlt', -            'dlt --non-interactive init chess duckdb', -            'pip install -r requirements.txt', -        ], name="Set up Python environment", description=dedent(f"""\ +        await sdk.run( +            [ +                "python3 -m venv .env", +                "source .env/bin/activate", +                "pip install dlt", +                "dlt --non-interactive init chess duckdb", +                "pip install -r requirements.txt", +            ], +            name="Set up Python environment", +            description=dedent( +                """\              Running the following commands:              - `python3 -m venv .env`: Create a Python virtual environment              - `source .env/bin/activate`: Activate the virtual environment              - `pip install dlt`: Install dlt              - `dlt init chess duckdb`: Create a new dlt pipeline called "chess" that loads data into a local DuckDB instance -            - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""")) +            - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""" +            ), +        )  class SwitchDestinationStep(Step):      hide: bool = True      async def run(self, sdk: ContinueSDK): -          # Switch destination from DuckDB to Google BigQuery -        filepath = os.path.join( -            sdk.ide.workspace_directory, 'chess_pipeline.py') -        await sdk.run_step(FindAndReplaceStep(filepath=filepath, pattern="destination='duckdb'", replacement="destination='bigquery'")) +        filepath = os.path.join(sdk.ide.workspace_directory, "chess_pipeline.py") +        await sdk.run_step( +            FindAndReplaceStep( +                filepath=filepath, +                pattern="destination='duckdb'", +                replacement="destination='bigquery'", +            ) +        )          # Add BigQuery credentials to your secrets.toml file -        template = dedent(f"""\ +        template = dedent( +            """\              [destination.bigquery.credentials]              location = "US"  # change the location of the data              project_id = "project_id" # please set me up!              private_key = "private_key" # please set me up! -            client_email = "client_email" # please set me up!""") +            client_email = "client_email" # please set me up!""" +        )          # wait for user to put API key in secrets.toml -        secrets_path = os.path.join( -            sdk.ide.workspace_directory, ".dlt/secrets.toml") +        secrets_path = os.path.join(sdk.ide.workspace_directory, ".dlt/secrets.toml")          await sdk.ide.setFileOpen(secrets_path)          await sdk.append_to_file(secrets_path, template)          # append template to bottom of secrets.toml -        await sdk.wait_for_user_confirmation("Please add your GCP credentials to `secrets.toml` file and then press `Continue`") +        await sdk.wait_for_user_confirmation( +            "Please add your GCP credentials to `secrets.toml` file and then press `Continue`" +        )  class LoadDataStep(Step): @@ -70,14 +81,20 @@ class LoadDataStep(Step):      async def run(self, sdk: ContinueSDK):          # Run the pipeline again to load data to BigQuery -        output = await sdk.run('.env/bin/python3 chess_pipeline.py', name="Load data to BigQuery", description="Running `.env/bin/python3 chess_pipeline.py` to load data to Google BigQuery") +        output = await sdk.run( +            ".env/bin/python3 chess_pipeline.py", +            name="Load data to BigQuery", +            description="Running `.env/bin/python3 chess_pipeline.py` to load data to Google BigQuery", +        )          if "Traceback" in output or "SyntaxError" in output:              with open(find_data_file("dlt_duckdb_to_bigquery_docs.md"), "r") as f:                  docs = f.read()              output = "Traceback" + output.split("Traceback")[-1] -            suggestion = await sdk.models.default.complete(dedent(f"""\ +            suggestion = await sdk.models.default.complete( +                dedent( +                    f"""\                  When trying to load data into BigQuery, the following error occurred:                  ```ascii @@ -88,8 +105,15 @@ class LoadDataStep(Step):                  {docs} -                This is a brief summary of the error followed by a suggestion on how it can be fixed:""")) +                This is a brief summary of the error followed by a suggestion on how it can be fixed:""" +                ) +            )              sdk.raise_exception( -                title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=suggestion) +                title="Error while running query", +                message=output, +                with_step=MessageStep( +                    name=f"Suggestion to solve error {AI_ASSISTED_STRING}", +                    message=suggestion, +                ),              ) diff --git a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py index 54cba45f..8f16cb34 100644 --- a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py @@ -1,11 +1,10 @@  from textwrap import dedent -from ....plugins.steps.input.nl_multiselect import NLMultiselectStep  from ....core.main import Step  from ....core.sdk import ContinueSDK  from ....plugins.steps.core.core import MessageStep -from .steps import SetupPipelineStep, DeployAirflowStep, RunPipelineStep - +from ....plugins.steps.input.nl_multiselect import NLMultiselectStep +from .steps import DeployAirflowStep, RunPipelineStep, SetupPipelineStep  # https://github.com/dlt-hub/dlt-deploy-template/blob/master/airflow-composer/dag_template.py  # https://www.notion.so/dlthub/Deploy-a-pipeline-with-Airflow-245fd1058652479494307ead0b5565f3 @@ -20,14 +19,20 @@ class DeployPipelineAirflowRecipe(Step):      async def run(self, sdk: ContinueSDK):          source_name = await sdk.run_step( -            MessageStep(name="Deploying a pipeline to Airflow", message=dedent("""\ +            MessageStep( +                name="Deploying a pipeline to Airflow", +                message=dedent( +                    """\                  This recipe will show you how to deploy a pipeline to Airflow. With the help of Continue, you will:                  - Select a dlt-verified pipeline                  - Setup the pipeline                  - Deploy it to Airflow -                - Optionally, setup Airflow locally""")) >> -            NLMultiselectStep( -                prompt=dedent("""\ +                - Optionally, setup Airflow locally""" +                ), +            ) +            >> NLMultiselectStep( +                prompt=dedent( +                    """\                      Which verified pipeline do you want to deploy with Airflow? The options are:                      - Asana                      - Chess.com @@ -48,14 +53,34 @@ class DeployPipelineAirflowRecipe(Step):                      - Stripe                      - SQL Database                      - Workable -                    - Zendesk"""), +                    - Zendesk""" +                ),                  options=[ -                    "asana_dlt", "chess", "github", "google_analytics", "google_sheets", "hubspot", "matomo", "pipedrive", "shopify_dlt", "strapi", "zendesk", -                    "facebook_ads", "jira", "mux", "notion", "pokemon", "salesforce", "stripe_analytics", "sql_database", "workable" -                ]) +                    "asana_dlt", +                    "chess", +                    "github", +                    "google_analytics", +                    "google_sheets", +                    "hubspot", +                    "matomo", +                    "pipedrive", +                    "shopify_dlt", +                    "strapi", +                    "zendesk", +                    "facebook_ads", +                    "jira", +                    "mux", +                    "notion", +                    "pokemon", +                    "salesforce", +                    "stripe_analytics", +                    "sql_database", +                    "workable", +                ], +            )          )          await sdk.run_step( -            SetupPipelineStep(source_name=source_name) >> -            RunPipelineStep(source_name=source_name) >> -            DeployAirflowStep(source_name=source_name) +            SetupPipelineStep(source_name=source_name) +            >> RunPipelineStep(source_name=source_name) +            >> DeployAirflowStep(source_name=source_name)          ) diff --git a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py index 83067d52..d09cf8bb 100644 --- a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py @@ -1,10 +1,9 @@  import os  from textwrap import dedent -from ....plugins.steps.core.core import MessageStep -from ....core.sdk import Models  from ....core.main import Step -from ....core.sdk import ContinueSDK +from ....core.sdk import ContinueSDK, Models +from ....plugins.steps.core.core import MessageStep  from ....plugins.steps.find_and_replace import FindAndReplaceStep  AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -20,19 +19,25 @@ class SetupPipelineStep(Step):          pass      async def run(self, sdk: ContinueSDK): -        await sdk.run([ -            'python3 -m venv .env', -            'source .env/bin/activate', -            'pip install dlt', -            f'dlt --non-interactive init {self.source_name} duckdb', -            'pip install -r requirements.txt' -        ], description=dedent(f"""\ +        await sdk.run( +            [ +                "python3 -m venv .env", +                "source .env/bin/activate", +                "pip install dlt", +                f"dlt --non-interactive init {self.source_name} duckdb", +                "pip install -r requirements.txt", +            ], +            description=dedent( +                f"""\              Running the following commands:              - `python3 -m venv .env`: Create a Python virtual environment              - `source .env/bin/activate`: Activate the virtual environment              - `pip install dlt`: Install dlt              - `dlt init {self.source_name} duckdb`: Create a new dlt pipeline called {self.source_name} that loads data into a local DuckDB instance -            - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline"""), name="Setup Python environment") +            - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""" +            ), +            name="Setup Python environment", +        )  class RunPipelineStep(Step): @@ -45,10 +50,16 @@ class RunPipelineStep(Step):          pass      async def run(self, sdk: ContinueSDK): -        await sdk.run([ -            f'python3 {self.source_name}_pipeline.py', -        ], description=dedent(f"""\ -            Running the command `python3 {self.source_name}_pipeline.py to run the pipeline: """), name="Run dlt pipeline") +        await sdk.run( +            [ +                f"python3 {self.source_name}_pipeline.py", +            ], +            description=dedent( +                f"""\ +            Running the command `python3 {self.source_name}_pipeline.py to run the pipeline: """ +            ), +            name="Run dlt pipeline", +        )  class DeployAirflowStep(Step): @@ -56,26 +67,47 @@ class DeployAirflowStep(Step):      source_name: str      async def run(self, sdk: ContinueSDK): -          # Run dlt command to deploy pipeline to Airflow          await sdk.run( -            ['git init', -                f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer'], -            description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow") +            [ +                "git init", +                f"dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer", +            ], +            description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", +            name="Deploy dlt pipeline to Airflow", +        )          # Get filepaths, open the DAG file          directory = await sdk.ide.getWorkspaceDirectory() -        pipeline_filepath = os.path.join( -            directory, f"{self.source_name}_pipeline.py") +        pipeline_filepath = os.path.join(directory, f"{self.source_name}_pipeline.py")          dag_filepath = os.path.join( -            directory, f"dags/dag_{self.source_name}_pipeline.py") +            directory, f"dags/dag_{self.source_name}_pipeline.py" +        )          await sdk.ide.setFileOpen(dag_filepath)          # Replace the pipeline name and dataset name -        await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="'pipeline_name'", replacement=f"'{self.source_name}_pipeline'")) -        await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="'dataset_name'", replacement=f"'{self.source_name}_data'")) -        await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="pipeline_or_source_script", replacement=f"{self.source_name}_pipeline")) +        await sdk.run_step( +            FindAndReplaceStep( +                filepath=pipeline_filepath, +                pattern="'pipeline_name'", +                replacement=f"'{self.source_name}_pipeline'", +            ) +        ) +        await sdk.run_step( +            FindAndReplaceStep( +                filepath=pipeline_filepath, +                pattern="'dataset_name'", +                replacement=f"'{self.source_name}_data'", +            ) +        ) +        await sdk.run_step( +            FindAndReplaceStep( +                filepath=pipeline_filepath, +                pattern="pipeline_or_source_script", +                replacement=f"{self.source_name}_pipeline", +            ) +        )          # Prompt the user for the DAG schedule          # edit_dag_range = Range.from_shorthand(18, 0, 23, 0) @@ -85,4 +117,9 @@ class DeployAirflowStep(Step):          #                     range=edit_dag_range)          # Tell the user to check the schedule and fill in owner, email, other default_args -        await sdk.run_step(MessageStep(message="Fill in the owner, email, and other default_args in the DAG file with your own personal information. Then the DAG will be ready to run!", name="Fill in default_args")) +        await sdk.run_step( +            MessageStep( +                message="Fill in the owner, email, and other default_args in the DAG file with your own personal information. Then the DAG will be ready to run!", +                name="Fill in default_args", +            ) +        ) diff --git a/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py index 197abe85..2ca65b8e 100644 --- a/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py @@ -1,7 +1,7 @@  from typing import Coroutine -from ....core.main import Step, Observation -from ....core.sdk import ContinueSDK -from ....core.sdk import Models + +from ....core.main import Observation, Step +from ....core.sdk import ContinueSDK, Models  class TemplateRecipe(Step): @@ -25,5 +25,5 @@ class TemplateRecipe(Step):          visible_files = await sdk.ide.getVisibleFiles()          await sdk.edit_file(              filename=visible_files[0], -            prompt=f"Append a statement to print `Hello, {self.name}!` at the end of the file." +            prompt=f"Append a statement to print `Hello, {self.name}!` at the end of the file.",          ) diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index c66cd629..e2712746 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -1,9 +1,10 @@ +import os  from textwrap import dedent  from typing import Union -from ....models.filesystem_edit import AddDirectory, AddFile +  from ....core.main import Step  from ....core.sdk import ContinueSDK -import os +from ....models.filesystem_edit import AddDirectory, AddFile  class WritePytestsRecipe(Step): @@ -30,7 +31,8 @@ class WritePytestsRecipe(Step):          for_file_contents = await sdk.ide.readFile(self.for_filepath) -        prompt = dedent(f"""\ +        prompt = dedent( +            f"""\              This is the file you will write unit tests for:              ```python @@ -41,7 +43,8 @@ class WritePytestsRecipe(Step):              "{self.user_input}" -            Here is a complete set of pytest unit tests:""") +            Here is a complete set of pytest unit tests:""" +        )          tests = await sdk.models.medium.complete(prompt)          await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index 658cc7f3..25633942 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -1,9 +1,9 @@  from textwrap import dedent  from typing import Coroutine, Union -from ...core.observation import Observation, TextObservation +  from ...core.main import Step +from ...core.observation import Observation  from ...core.sdk import ContinueSDK -from .core.core import EditFileStep  from ...libs.chroma.query import ChromaIndexManager  from .core.core import EditFileStep @@ -42,11 +42,12 @@ class AnswerQuestionChroma(Step):          files = []          for node in results.source_nodes:              resource_name = list(node.node.relationships.values())[0] -            filepath = resource_name[:resource_name.index("::")] +            filepath = resource_name[: resource_name.index("::")]              files.append(filepath)              code_snippets += f"""{filepath}```\n{node.node.text}\n```\n\n""" -        prompt = dedent(f"""Here are a few snippets of code that might be useful in answering the question: +        prompt = dedent( +            f"""Here are a few snippets of code that might be useful in answering the question:              {code_snippets} @@ -54,7 +55,8 @@ class AnswerQuestionChroma(Step):              {self.question} -            Here is the answer:""") +            Here is the answer:""" +        )          answer = await sdk.models.medium.complete(prompt)          # Make paths relative to the workspace directory @@ -73,8 +75,12 @@ class EditFileChroma(Step):          index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory())          results = index.query_codebase_index(self.request) -        resource_name = list( -            results.source_nodes[0].node.relationships.values())[0] -        filepath = resource_name[:resource_name.index("::")] +        resource_name = list(results.source_nodes[0].node.relationships.values())[0] +        filepath = resource_name[: resource_name.index("::")] -        await sdk.run_step(EditFileStep(filepath=filepath, prompt=f"Here is the code:\n\n{{code}}\n\nHere is the user request:\n\n{self.request}\n\nHere is the code after making the requested changes:\n")) +        await sdk.run_step( +            EditFileStep( +                filepath=filepath, +                prompt=f"Here is the code:\n\n{{code}}\n\nHere is the user request:\n\n{self.request}\n\nHere is the code after making the requested changes:\n", +            ) +        ) diff --git a/continuedev/src/continuedev/plugins/steps/comment_code.py b/continuedev/src/continuedev/plugins/steps/comment_code.py index 3e34ab52..1eee791d 100644 --- a/continuedev/src/continuedev/plugins/steps/comment_code.py +++ b/continuedev/src/continuedev/plugins/steps/comment_code.py @@ -9,4 +9,8 @@ class CommentCodeStep(Step):          return "Writing comments"      async def run(self, sdk: ContinueSDK): -        await sdk.run_step(EditHighlightedCodeStep(user_input="Write comprehensive comments in the canonical format for every class and function")) +        await sdk.run_step( +            EditHighlightedCodeStep( +                user_input="Write comprehensive comments in the canonical format for every class and function" +            ) +        ) diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 1ef201da..86569dcb 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -1,25 +1,36 @@  # These steps are depended upon by ContinueSDK -import os -import json  import difflib -from textwrap import dedent +import os  import traceback +from textwrap import dedent  from typing import Any, Coroutine, List, Union -import difflib  from pydantic import validator +from ....core.main import ChatMessage, ContinueCustomException, Step +from ....core.observation import ( +    Observation, +    TextObservation, +    UserInputObservation, +)  from ....libs.llm.ggml import GGML -# from ....libs.llm.replicate import ReplicateLLM -from ....models.main import Range  from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI -from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit -from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation -from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep  from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS -from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes +from ....libs.util.strings import ( +    dedent_and_get_common_whitespace, +    remove_quotes_and_escapes, +)  from ....libs.util.telemetry import posthog_logger +from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents +from ....models.filesystem_edit import ( +    EditDiff, +    FileEdit, +    FileEditWithFullContents, +    FileSystemEdit, +) + +# from ....libs.llm.replicate import ReplicateLLM +from ....models.main import Range  class ContinueSDK: @@ -56,7 +67,7 @@ class DisplayErrorStep(Step):      @validator("e", pre=True, always=True)      def validate_e(cls, v):          if isinstance(v, Exception): -            return '\n'.join(traceback.format_exception(v)) +            return "\n".join(traceback.format_exception(v))      async def describe(self, models: Models) -> Coroutine[str, None, None]:          return self.e @@ -100,25 +111,41 @@ class ShellCommandsStep(Step):              return f"Error when running shell commands:\n```\n{self._err_text}\n```"          cmds_str = "\n".join(self.cmds) -        return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") +        return await models.medium.complete( +            f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:" +        )      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd +        await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd          for cmd in self.cmds:              output = await sdk.ide.runCommand(cmd) -            if self.handle_error and output is not None and output_contains_error(output): -                suggestion = await sdk.models.medium.complete(dedent(f"""\ +            if ( +                self.handle_error +                and output is not None +                and output_contains_error(output) +            ): +                suggestion = await sdk.models.medium.complete( +                    dedent( +                        f"""\                      While running the command `{cmd}`, the following error occurred:                      ```ascii                      {output}                      ``` -                    This is a brief summary of the error followed by a suggestion on how it can be fixed:"""), with_history=await sdk.get_chat_context()) +                    This is a brief summary of the error followed by a suggestion on how it can be fixed:""" +                    ), +                    with_history=await sdk.get_chat_context(), +                )                  sdk.raise_exception( -                    title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.") +                    title="Error while running query", +                    message=output, +                    with_step=MessageStep( +                        name=f"Suggestion to solve error {AI_ASSISTED_STRING}", +                        message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.", +                    ),                  )          return TextObservation(text=output) @@ -143,7 +170,8 @@ class DefaultModelEditCodeStep(Step):      name: str = "Editing Code"      hide = False      description: str = "" -    _prompt: str = dedent("""\ +    _prompt: str = dedent( +        """\          Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. The code you write in modified_code_to_edit will replace the code between the code_to_edit tags. Do NOT preface your answer or write anything other than code. The </modified_code_to_edit> tag should be written to indicate the end of the modified code section. Do not ever use nested tags.          Example: @@ -177,7 +205,8 @@ class DefaultModelEditCodeStep(Step):          </modified_code_to_edit>          Main task: -        """) +        """ +    )      _previous_contents: str = ""      _new_contents: str = ""      _prompt_and_completion: str = "" @@ -186,22 +215,34 @@ class DefaultModelEditCodeStep(Step):          if self._previous_contents.strip() == self._new_contents.strip():              description = "No edits were made"          else: -            changes = '\n'.join(difflib.ndiff( -                self._previous_contents.splitlines(), self._new_contents.splitlines())) -            description = await models.medium.complete(dedent(f"""\ +            changes = "\n".join( +                difflib.ndiff( +                    self._previous_contents.splitlines(), +                    self._new_contents.splitlines(), +                ) +            ) +            description = await models.medium.complete( +                dedent( +                    f"""\                  Diff summary: "{self.user_input}"                  ```diff                  {changes}                  ``` -                Please give brief a description of the changes made above using markdown bullet points. Be concise:""")) -        name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") +                Please give brief a description of the changes made above using markdown bullet points. Be concise:""" +                ) +            ) +        name = await models.medium.complete( +            f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:" +        )          self.name = remove_quotes_and_escapes(name)          return f"{remove_quotes_and_escapes(description)}" -    async def get_prompt_parts(self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str): +    async def get_prompt_parts( +        self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str +    ):          # We don't know here all of the functions being passed in.          # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.          # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. @@ -209,18 +250,27 @@ class DefaultModelEditCodeStep(Step):          max_tokens = int(model_to_use.context_length / 2)          TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 -        if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: +        if ( +            model_to_use.count_tokens(rif.contents) +            > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE +        ):              self.description += "\n\n**It looks like you've selected a large range to edit, which may take a while to complete. If you'd like to cancel, click the 'X' button above. If you highlight a more specific range, Continue will only edit within it.**"              # At this point, we also increase the max_tokens parameter so it doesn't stop in the middle of generation              # Increase max_tokens to be double the size of the range              # But don't exceed twice default max tokens -            max_tokens = int(min(model_to_use.count_tokens( -                rif.contents), DEFAULT_MAX_TOKENS) * 2.5) +            max_tokens = int( +                min(model_to_use.count_tokens(rif.contents), DEFAULT_MAX_TOKENS) * 2.5 +            )          BUFFER_FOR_FUNCTIONS = 400 -        total_tokens = model_to_use.count_tokens( -            full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + max_tokens +        total_tokens = ( +            model_to_use.count_tokens( +                full_file_contents + self._prompt + self.user_input +            ) +            + BUFFER_FOR_FUNCTIONS +            + max_tokens +        )          # If using 3.5 and overflows, upgrade to 3.5.16k          if model_to_use.name == "gpt-3.5-turbo": @@ -239,7 +289,8 @@ class DefaultModelEditCodeStep(Step):          if total_tokens > model_to_use.context_length:              while cur_end_line > min_end_line:                  total_tokens -= model_to_use.count_tokens( -                    full_file_contents_lst[cur_end_line]) +                    full_file_contents_lst[cur_end_line] +                )                  cur_end_line -= 1                  if total_tokens < model_to_use.context_length:                      break @@ -248,33 +299,30 @@ class DefaultModelEditCodeStep(Step):              while cur_start_line < max_start_line:                  cur_start_line += 1                  total_tokens -= model_to_use.count_tokens( -                    full_file_contents_lst[cur_start_line]) +                    full_file_contents_lst[cur_start_line] +                )                  if total_tokens < model_to_use.context_length:                      break          # Now use the found start/end lines to get the prefix and suffix strings -        file_prefix = "\n".join( -            full_file_contents_lst[cur_start_line:max_start_line]) -        file_suffix = "\n".join( -            full_file_contents_lst[min_end_line:cur_end_line - 1]) +        file_prefix = "\n".join(full_file_contents_lst[cur_start_line:max_start_line]) +        file_suffix = "\n".join(full_file_contents_lst[min_end_line : cur_end_line - 1])          # Move any surrounding blank line in rif.contents to the prefix/suffix          # TODO: Keep track of start line of the range, because it's needed below for offset stuff -        rif_start_line = rif.range.start.line          if len(rif.contents) > 0:              lines = rif.contents.splitlines(keepends=True)              first_line = lines[0] if lines else None              while first_line and first_line.strip() == "":                  file_prefix += first_line -                rif.contents = rif.contents[len(first_line):] +                rif.contents = rif.contents[len(first_line) :]                  lines = rif.contents.splitlines(keepends=True)                  first_line = lines[0] if lines else None              last_line = lines[-1] if lines else None              while last_line and last_line.strip() == "":                  file_suffix = last_line + file_suffix -                rif.contents = rif.contents[:len( -                    rif.contents) - len(last_line)] +                rif.contents = rif.contents[: len(rif.contents) - len(last_line)]                  lines = rif.contents.splitlines(keepends=True)                  last_line = lines[-1] if lines else None @@ -287,10 +335,13 @@ class DefaultModelEditCodeStep(Step):          return file_prefix, rif.contents, file_suffix, model_to_use, max_tokens -    def compile_prompt(self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK) -> str: +    def compile_prompt( +        self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK +    ) -> str:          if contents.strip() == "":              # Seperate prompt for insertion at the cursor, the other tends to cause it to repeat whole file -            prompt = dedent(f"""\ +            prompt = dedent( +                f"""\  <file_prefix>  {file_prefix}  </file_prefix> @@ -302,30 +353,39 @@ class DefaultModelEditCodeStep(Step):  {self.user_input}  </user_request> -Please output the code to be inserted at the cursor in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code. Make sure to correctly indent the code:""") +Please output the code to be inserted at the cursor in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code. Make sure to correctly indent the code:""" +            )              return prompt          prompt = self._prompt          if file_prefix.strip() != "": -            prompt += dedent(f""" +            prompt += dedent( +                f"""  <file_prefix>  {file_prefix} -</file_prefix>""") -        prompt += dedent(f""" +</file_prefix>""" +            ) +        prompt += dedent( +            f"""  <code_to_edit>  {contents} -</code_to_edit>""") +</code_to_edit>""" +        )          if file_suffix.strip() != "": -            prompt += dedent(f""" +            prompt += dedent( +                f"""  <file_suffix>  {file_suffix} -</file_suffix>""") -        prompt += dedent(f""" +</file_suffix>""" +            ) +        prompt += dedent( +            f"""  <user_request>  {self.user_input}  </user_request>  <modified_code_to_edit> -""") +""" +        )          return prompt @@ -333,28 +393,44 @@ Please output the code to be inserted at the cursor in order to fulfill the user          return "</modified_code_to_edit>" in line or "</code_to_edit>" in line      def line_to_be_ignored(self, line: str, is_first_line: bool = False) -> bool: -        return "```" in line or "<modified_code_to_edit>" in line or "<file_prefix>" in line or "</file_prefix>" in line or "<file_suffix>" in line or "</file_suffix>" in line or "<user_request>" in line or "</user_request>" in line or "<code_to_edit>" in line +        return ( +            "```" in line +            or "<modified_code_to_edit>" in line +            or "<file_prefix>" in line +            or "</file_prefix>" in line +            or "<file_suffix>" in line +            or "</file_suffix>" in line +            or "<user_request>" in line +            or "</user_request>" in line +            or "<code_to_edit>" in line +        )      async def stream_rif(self, rif: RangeInFileWithContents, sdk: ContinueSDK):          await sdk.ide.saveFile(rif.filepath)          full_file_contents = await sdk.ide.readFile(rif.filepath) -        file_prefix, contents, file_suffix, model_to_use, max_tokens = await self.get_prompt_parts( -            rif, sdk, full_file_contents) -        contents, common_whitespace = dedent_and_get_common_whitespace( -            contents) +        ( +            file_prefix, +            contents, +            file_suffix, +            model_to_use, +            max_tokens, +        ) = await self.get_prompt_parts(rif, sdk, full_file_contents) +        contents, common_whitespace = dedent_and_get_common_whitespace(contents)          prompt = self.compile_prompt(file_prefix, contents, file_suffix, sdk)          full_file_contents_lines = full_file_contents.split("\n")          lines_to_display = [] -        async def sendDiffUpdate(lines: List[str], sdk: ContinueSDK, final: bool = False): +        async def sendDiffUpdate( +            lines: List[str], sdk: ContinueSDK, final: bool = False +        ):              nonlocal full_file_contents_lines, rif, lines_to_display              completion = "\n".join(lines) -            full_prefix_lines = full_file_contents_lines[:rif.range.start.line] -            full_suffix_lines = full_file_contents_lines[rif.range.end.line:] +            full_prefix_lines = full_file_contents_lines[: rif.range.start.line] +            full_suffix_lines = full_file_contents_lines[rif.range.end.line :]              # Don't do this at the very end, just show the inserted code              if final: @@ -365,13 +441,29 @@ Please output the code to be inserted at the cursor in order to fulfill the user                  rewritten_lines = 0                  for line in lines:                      for i in range(rewritten_lines, len(contents_lines)): -                        if difflib.SequenceMatcher(None, line, contents_lines[i]).ratio() > 0.7 and contents_lines[i].strip() != "": +                        if ( +                            difflib.SequenceMatcher( +                                None, line, contents_lines[i] +                            ).ratio() +                            > 0.7 +                            and contents_lines[i].strip() != "" +                        ):                              rewritten_lines = i + 1                              break                  lines_to_display = contents_lines[rewritten_lines:] -            new_file_contents = "\n".join( -                full_prefix_lines) + "\n" + completion + "\n" + ("\n".join(lines_to_display) + "\n" if len(lines_to_display) > 0 else "") + "\n".join(full_suffix_lines) +            new_file_contents = ( +                "\n".join(full_prefix_lines) +                + "\n" +                + completion +                + "\n" +                + ( +                    "\n".join(lines_to_display) + "\n" +                    if len(lines_to_display) > 0 +                    else "" +                ) +                + "\n".join(full_suffix_lines) +            )              step_index = sdk.history.current_index @@ -403,30 +495,61 @@ Please output the code to be inserted at the cursor in order to fulfill the user              # Highlight the line to show progress              line_to_highlight = current_line_in_file - len(current_block_lines)              if False: -                await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand( -                    line_to_highlight, 0, line_to_highlight, 0)), "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022") +                await sdk.ide.highlightCode( +                    RangeInFile( +                        filepath=rif.filepath, +                        range=Range.from_shorthand( +                            line_to_highlight, 0, line_to_highlight, 0 +                        ), +                    ), +                    "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022", +                )              if len(current_block_lines) == 0:                  # Set this as the start of the next block -                current_block_start = rif.range.start.line + len(original_lines) - len( -                    original_lines_below_previous_blocks) + offset_from_blocks -                if len(original_lines_below_previous_blocks) > 0 and line == original_lines_below_previous_blocks[0]: +                current_block_start = ( +                    rif.range.start.line +                    + len(original_lines) +                    - len(original_lines_below_previous_blocks) +                    + offset_from_blocks +                ) +                if ( +                    len(original_lines_below_previous_blocks) > 0 +                    and line == original_lines_below_previous_blocks[0] +                ):                      # Line is equal to the next line in file, move past this line -                    original_lines_below_previous_blocks = original_lines_below_previous_blocks[ -                        1:] +                    original_lines_below_previous_blocks = ( +                        original_lines_below_previous_blocks[1:] +                    )                      return              # In a block, and have already matched at least one line              # Check if the next line matches, for each of the candidates              matches_found = []              first_valid_match = None -            for index_of_last_matched_line, num_lines_matched in indices_of_last_matched_lines: -                if index_of_last_matched_line + 1 < len(original_lines_below_previous_blocks) and line == original_lines_below_previous_blocks[index_of_last_matched_line + 1]: +            for ( +                index_of_last_matched_line, +                num_lines_matched, +            ) in indices_of_last_matched_lines: +                if ( +                    index_of_last_matched_line + 1 +                    < len(original_lines_below_previous_blocks) +                    and line +                    == original_lines_below_previous_blocks[ +                        index_of_last_matched_line + 1 +                    ] +                ):                      matches_found.append( -                        (index_of_last_matched_line + 1, num_lines_matched + 1)) -                    if first_valid_match is None and num_lines_matched + 1 >= LINES_TO_MATCH_BEFORE_ENDING_BLOCK: +                        (index_of_last_matched_line + 1, num_lines_matched + 1) +                    ) +                    if ( +                        first_valid_match is None +                        and num_lines_matched + 1 >= LINES_TO_MATCH_BEFORE_ENDING_BLOCK +                    ):                          first_valid_match = ( -                            index_of_last_matched_line + 1, num_lines_matched + 1) +                            index_of_last_matched_line + 1, +                            num_lines_matched + 1, +                        )              indices_of_last_matched_lines = matches_found              if first_valid_match is not None: @@ -436,7 +559,13 @@ Please output the code to be inserted at the cursor in order to fulfill the user                  # So here we will strip all matching lines from the end of current_block_lines                  lines_stripped = []                  index_of_last_line_in_block = first_valid_match[0] -                while len(current_block_lines) > 0 and current_block_lines[-1] == original_lines_below_previous_blocks[index_of_last_line_in_block - 1]: +                while ( +                    len(current_block_lines) > 0 +                    and current_block_lines[-1] +                    == original_lines_below_previous_blocks[ +                        index_of_last_line_in_block - 1 +                    ] +                ):                      lines_stripped.append(current_block_lines.pop())                      index_of_last_line_in_block -= 1 @@ -455,18 +584,22 @@ Please output the code to be inserted at the cursor in order to fulfill the user                  end_line = current_block_start + index_of_last_line_in_block                  if False: -                    await sdk.ide.showSuggestion(FileEdit( -                        filepath=rif.filepath, -                        range=Range.from_shorthand( -                            start_line, 0, end_line, 0), -                        replacement=replacement -                    )) +                    await sdk.ide.showSuggestion( +                        FileEdit( +                            filepath=rif.filepath, +                            range=Range.from_shorthand(start_line, 0, end_line, 0), +                            replacement=replacement, +                        ) +                    )                  # Reset current block / update variables                  current_line_in_file += 1                  offset_from_blocks += len(current_block_lines) -                original_lines_below_previous_blocks = original_lines_below_previous_blocks[ -                    index_of_last_line_in_block + 1:] +                original_lines_below_previous_blocks = ( +                    original_lines_below_previous_blocks[ +                        index_of_last_line_in_block + 1 : +                    ] +                )                  current_block_lines = []                  current_block_start = -1                  indices_of_last_matched_lines = [] @@ -485,7 +618,8 @@ Please output the code to be inserted at the cursor in order to fulfill the user              # Make sure they are sorted by index              indices_of_last_matched_lines = sorted( -                indices_of_last_matched_lines, key=lambda x: x[0]) +                indices_of_last_matched_lines, key=lambda x: x[0] +            )              current_block_lines.append(line) @@ -498,11 +632,9 @@ Please output the code to be inserted at the cursor in order to fulfill the user                  messages.pop(i)                  deleted += 1              i -= 1 -        messages.append(ChatMessage( -            role="user", -            content=prompt, -            summary=self.user_input -        )) +        messages.append( +            ChatMessage(role="user", content=prompt, summary=self.user_input) +        )          lines_of_prefix_copied = 0          lines = [] @@ -512,18 +644,22 @@ Please output the code to be inserted at the cursor in order to fulfill the user          line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]          if isinstance(model_to_use, GGML): -            messages = [ChatMessage( -                role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)] +            messages = [ +                ChatMessage( +                    role="user", +                    content=f'```\n{rif.contents}\n```\n\nUser request: "{self.user_input}"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n', +                    summary=self.user_input, +                ) +            ]          # elif isinstance(model_to_use, ReplicateLLM):          #     messages = [ChatMessage(          #         role="user", content=f"// Previous implementation\n\n{rif.contents}\n\n// Updated implementation (after following directions: {self.user_input})\n\n", summary=self.user_input)]          generator = model_to_use.stream_chat( -            messages, temperature=sdk.config.temperature, max_tokens=max_tokens) +            messages, temperature=sdk.config.temperature, max_tokens=max_tokens +        ) -        posthog_logger.capture_event("model_use", { -            "model": model_to_use.name -        }) +        posthog_logger.capture_event("model_use", {"model": model_to_use.name})          try:              async for chunk in generator: @@ -555,16 +691,31 @@ Please output the code to be inserted at the cursor in order to fulfill the user                      if self.is_end_line(chunk_lines[i]):                          break                      # Lines that should be ignored, like the <> tags -                    elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): +                    elif self.line_to_be_ignored( +                        chunk_lines[i], completion_lines_covered == 0 +                    ):                          continue  # noice                      # Check if we are currently just copying the prefix -                    elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: +                    elif ( +                        (lines_of_prefix_copied > 0 or completion_lines_covered == 0) +                        and lines_of_prefix_copied < len(file_prefix.splitlines()) +                        and chunk_lines[i] +                        == full_file_contents_lines[lines_of_prefix_copied] +                    ):                          # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line                          lines_of_prefix_copied += 1                          continue  # also nice                      # Because really short lines might be expected to be repeated, this is only a !heuristic!                      # Stop when it starts copying the file_suffix -                    elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): +                    elif ( +                        chunk_lines[i].strip() == line_below_highlighted_range.strip() +                        and len(chunk_lines[i].strip()) > 4 +                        and not ( +                            len(original_lines_below_previous_blocks) > 0 +                            and chunk_lines[i].strip() +                            == original_lines_below_previous_blocks[0].strip() +                        ) +                    ):                          repeating_file_suffix = True                          break @@ -576,11 +727,25 @@ Please output the code to be inserted at the cursor in order to fulfill the user                      completion_lines_covered += 1                      current_line_in_file += 1 -                await sendDiffUpdate(lines + [common_whitespace if unfinished_line.startswith("<") else (common_whitespace + unfinished_line)], sdk) +                await sendDiffUpdate( +                    lines +                    + [ +                        common_whitespace +                        if unfinished_line.startswith("<") +                        else (common_whitespace + unfinished_line) +                    ], +                    sdk, +                )          finally:              await generator.aclose()          # Add the unfinished line -        if unfinished_line != "" and not self.line_to_be_ignored(unfinished_line, completion_lines_covered == 0) and not self.is_end_line(unfinished_line): +        if ( +            unfinished_line != "" +            and not self.line_to_be_ignored( +                unfinished_line, completion_lines_covered == 0 +            ) +            and not self.is_end_line(unfinished_line) +        ):              unfinished_line = common_whitespace + unfinished_line              lines.append(unfinished_line)              await handle_generated_line(unfinished_line) @@ -598,13 +763,19 @@ Please output the code to be inserted at the cursor in order to fulfill the user                  for i in range(-1, -len(current_block_lines) - 1, -1):                      if len(original_lines_below_previous_blocks) == 0:                          break -                    if current_block_lines[i] == original_lines_below_previous_blocks[-1]: +                    if ( +                        current_block_lines[i] +                        == original_lines_below_previous_blocks[-1] +                    ):                          num_to_remove += 1                          original_lines_below_previous_blocks.pop()                      else:                          break -                current_block_lines = current_block_lines[:- -                                                          num_to_remove] if num_to_remove > 0 else current_block_lines +                current_block_lines = ( +                    current_block_lines[:-num_to_remove] +                    if num_to_remove > 0 +                    else current_block_lines +                )                  # It's also possible that some lines match at the beginning of the block                  # while len(current_block_lines) > 0 and len(original_lines_below_previous_blocks) > 0 and current_block_lines[0] == original_lines_below_previous_blocks[0]: @@ -612,12 +783,19 @@ Please output the code to be inserted at the cursor in order to fulfill the user                  #     original_lines_below_previous_blocks.pop(0)                  #     current_block_start += 1 -                await sdk.ide.showSuggestion(FileEdit( -                    filepath=rif.filepath, -                    range=Range.from_shorthand( -                        current_block_start, 0, current_block_start + len(original_lines_below_previous_blocks), 0), -                    replacement="\n".join(current_block_lines) -                )) +                await sdk.ide.showSuggestion( +                    FileEdit( +                        filepath=rif.filepath, +                        range=Range.from_shorthand( +                            current_block_start, +                            0, +                            current_block_start +                            + len(original_lines_below_previous_blocks), +                            0, +                        ), +                        replacement="\n".join(current_block_lines), +                    ) +                )          # Record the completion          completion = "\n".join(lines) @@ -629,14 +807,18 @@ Please output the code to be inserted at the cursor in order to fulfill the user          await sdk.update_ui()          rif_with_contents = [] -        for range_in_file in map(lambda x: RangeInFile( -            filepath=x.filepath, -            # Only consider the range line-by-line. Maybe later don't if it's only a single line. -            range=x.range.to_full_lines() -        ), self.range_in_files): +        for range_in_file in map( +            lambda x: RangeInFile( +                filepath=x.filepath, +                # Only consider the range line-by-line. Maybe later don't if it's only a single line. +                range=x.range.to_full_lines(), +            ), +            self.range_in_files, +        ):              file_contents = await sdk.ide.readRangeInFile(range_in_file)              rif_with_contents.append( -                RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) +                RangeInFileWithContents.from_range_in_file(range_in_file, file_contents) +            )          rif_dict = {}          for rif in rif_with_contents: @@ -645,10 +827,10 @@ Please output the code to be inserted at the cursor in order to fulfill the user          for rif in rif_with_contents:              # If the file doesn't exist, ask them to save it first              if not os.path.exists(rif.filepath): -                message = f"The file {rif.filepath} does not exist. Please save it first." -                raise ContinueCustomException( -                    title=message, message=message +                message = ( +                    f"The file {rif.filepath} does not exist. Please save it first."                  ) +                raise ContinueCustomException(title=message, message=message)              await sdk.ide.setFileOpen(rif.filepath)              await sdk.ide.setSuggestionsLocked(rif.filepath, True) @@ -666,11 +848,14 @@ class EditFileStep(Step):      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          file_contents = await sdk.ide.readFile(self.filepath) -        await sdk.run_step(DefaultModelEditCodeStep( -            range_in_files=[RangeInFile.from_entire_file( -                self.filepath, file_contents)], -            user_input=self.prompt -        )) +        await sdk.run_step( +            DefaultModelEditCodeStep( +                range_in_files=[ +                    RangeInFile.from_entire_file(self.filepath, file_contents) +                ], +                user_input=self.prompt, +            ) +        )  class ManualEditStep(ReversibleStep): @@ -698,8 +883,7 @@ class ManualEditStep(ReversibleStep):      def from_sequence(cls, edits: List[FileEditWithFullContents]) -> "ManualEditStep":          diffs = []          for edit in edits: -            _, diff = FileSystem.apply_edit_to_str( -                edit.fileContents, edit.fileEdit) +            _, diff = FileSystem.apply_edit_to_str(edit.fileContents, edit.fileEdit)              diffs.append(diff)          return cls(edit_diff=EditDiff.from_sequence(diffs)) @@ -720,12 +904,12 @@ class UserInputStep(Step):      async def describe(self, models: Models) -> Coroutine[str, None, None]:          return self.user_input -    async def run(self, sdk: ContinueSDK) -> Coroutine[UserInputObservation, None, None]: -        self.chat_context.append(ChatMessage( -            role="user", -            content=self.user_input, -            summary=self.user_input -        )) +    async def run( +        self, sdk: ContinueSDK +    ) -> Coroutine[UserInputObservation, None, None]: +        self.chat_context.append( +            ChatMessage(role="user", content=self.user_input, summary=self.user_input) +        )          return UserInputObservation(user_input=self.user_input) diff --git a/continuedev/src/continuedev/plugins/steps/custom_command.py b/continuedev/src/continuedev/plugins/steps/custom_command.py index 419b3c3d..4128415b 100644 --- a/continuedev/src/continuedev/plugins/steps/custom_command.py +++ b/continuedev/src/continuedev/plugins/steps/custom_command.py @@ -1,6 +1,6 @@ -from ...libs.util.templating import render_templated_string  from ...core.main import Step  from ...core.sdk import ContinueSDK, Models +from ...libs.util.templating import render_templated_string  from ..steps.chat import SimpleChatStep @@ -21,8 +21,9 @@ class CustomCommandStep(Step):          messages = await sdk.get_chat_context()          # Find the last chat message with this slash command and replace it with the user input          for i in range(len(messages) - 1, -1, -1): -            if messages[i].role == "user" and messages[i].content.startswith(self.slash_command): -                messages[i] = messages[i].copy( -                    update={"content": prompt_user_input}) +            if messages[i].role == "user" and messages[i].content.startswith( +                self.slash_command +            ): +                messages[i] = messages[i].copy(update={"content": prompt_user_input})                  break          await sdk.run_step(SimpleChatStep(messages=messages)) diff --git a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py index f3131c4b..1d135b3e 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py +++ b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py @@ -1,5 +1,5 @@ -from ....core.sdk import ContinueSDK  from ....core.main import Step +from ....core.sdk import ContinueSDK  class ImplementAbstractMethodStep(Step): @@ -8,11 +8,9 @@ class ImplementAbstractMethodStep(Step):      class_name: str      async def run(self, sdk: ContinueSDK): -          implementations = await sdk.lsp.go_to_implementations(self.class_name)          for implementation in implementations: -              await sdk.edit_file(                  range_in_files=[implementation.range_in_file],                  prompt=f"Implement method `{self.method_name}` for this subclass of `{self.class_name}`", diff --git a/continuedev/src/continuedev/plugins/steps/draft/redux.py b/continuedev/src/continuedev/plugins/steps/draft/redux.py index 30c8fdbb..5a351e6f 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/redux.py +++ b/continuedev/src/continuedev/plugins/steps/draft/redux.py @@ -4,7 +4,6 @@ from ..core.core import EditFileStep  class EditReduxStateStep(Step): -      description: str  # e.g. "I want to load data from the weatherapi.com API"      async def run(self, sdk: ContinueSDK): @@ -15,24 +14,28 @@ class EditReduxStateStep(Step):          sdk.run_step(              EditFileStep(                  filename=store_filename, -                prompt=f"Edit the root store to add a new slice for {self.description}" +                prompt=f"Edit the root store to add a new slice for {self.description}",              )          )          store_file_contents = await sdk.ide.readFile(store_filename)          # Selector          selector_filename = "" -        sdk.run_step(EditFileStep( -            filepath=selector_filename, -            prompt=f"Edit the selector to add a new property for {self.description}. The store looks like this: {store_file_contents}" -        )) +        sdk.run_step( +            EditFileStep( +                filepath=selector_filename, +                prompt=f"Edit the selector to add a new property for {self.description}. The store looks like this: {store_file_contents}", +            ) +        )          # Reducer          reducer_filename = "" -        sdk.run_step(EditFileStep( -            filepath=reducer_filename, -            prompt=f"Edit the reducer to add a new property for {self.description}. The store looks like this: {store_file_contents}" -        )) +        sdk.run_step( +            EditFileStep( +                filepath=reducer_filename, +                prompt=f"Edit the reducer to add a new property for {self.description}. The store looks like this: {store_file_contents}", +            ) +        )          """          Starts with implementing selector          1. RootStore diff --git a/continuedev/src/continuedev/plugins/steps/draft/typeorm.py b/continuedev/src/continuedev/plugins/steps/draft/typeorm.py index d06a6fb4..c79fa041 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/typeorm.py +++ b/continuedev/src/continuedev/plugins/steps/draft/typeorm.py @@ -1,4 +1,5 @@  from textwrap import dedent +  from ....core.main import Step  from ....core.sdk import ContinueSDK @@ -12,17 +13,23 @@ class CreateTableStep(Step):          entity_name = self.sql_str.split(" ")[2].capitalize()          await sdk.edit_file(              f"src/entity/{entity_name}.ts", -            dedent(f"""\ +            dedent( +                f"""\              {self.sql_str} -            Write a TypeORM entity called {entity_name} for this table, importing as necessary:""") +            Write a TypeORM entity called {entity_name} for this table, importing as necessary:""" +            ),          )          # Add entity to data-source.ts -        await sdk.edit_file(filepath="src/data-source.ts", prompt=f"Add the {entity_name} entity:") +        await sdk.edit_file( +            filepath="src/data-source.ts", prompt=f"Add the {entity_name} entity:" +        )          # Generate blank migration for the entity -        out = await sdk.run(f"npx typeorm migration:create ./src/migration/Create{entity_name}Table") +        out = await sdk.run( +            f"npx typeorm migration:create ./src/migration/Create{entity_name}Table" +        )          migration_filepath = out.text.split(" ")[1]          # Wait for user input @@ -31,13 +38,17 @@ class CreateTableStep(Step):          # Fill in the migration          await sdk.edit_file(              migration_filepath, -            dedent(f"""\ +            dedent( +                f"""\                  This is the table that was created:                  {self.sql_str} -                Fill in the migration for the table:"""), +                Fill in the migration for the table:""" +            ),          )          # Run the migration -        await sdk.run("npx typeorm-ts-node-commonjs migration:run -d ./src/data-source.ts") +        await sdk.run( +            "npx typeorm-ts-node-commonjs migration:run -d ./src/data-source.ts" +        ) diff --git a/continuedev/src/continuedev/plugins/steps/feedback.py b/continuedev/src/continuedev/plugins/steps/feedback.py index fa56a4d9..df1142a1 100644 --- a/continuedev/src/continuedev/plugins/steps/feedback.py +++ b/continuedev/src/continuedev/plugins/steps/feedback.py @@ -1,6 +1,4 @@ -from typing import Coroutine -from ...core.main import Models -from ...core.main import Step +from ...core.main import Models, Step  from ...core.sdk import ContinueSDK  from ...libs.util.telemetry import posthog_logger diff --git a/continuedev/src/continuedev/plugins/steps/find_and_replace.py b/continuedev/src/continuedev/plugins/steps/find_and_replace.py index a2c9c44e..287e286d 100644 --- a/continuedev/src/continuedev/plugins/steps/find_and_replace.py +++ b/continuedev/src/continuedev/plugins/steps/find_and_replace.py @@ -1,6 +1,6 @@ -from ...models.filesystem_edit import FileEdit, Range  from ...core.main import Models, Step  from ...core.sdk import ContinueSDK +from ...models.filesystem_edit import FileEdit, Range  class FindAndReplaceStep(Step): @@ -17,12 +17,14 @@ class FindAndReplaceStep(Step):          while self.pattern in file_content:              start_index = file_content.index(self.pattern)              end_index = start_index + len(self.pattern) -            await sdk.ide.applyFileSystemEdit(FileEdit( -                filepath=self.filepath, -                range=Range.from_indices( -                    file_content, start_index, end_index - 1), -                replacement=self.replacement -            )) -            file_content = file_content[:start_index] + \ -                self.replacement + file_content[end_index:] +            await sdk.ide.applyFileSystemEdit( +                FileEdit( +                    filepath=self.filepath, +                    range=Range.from_indices(file_content, start_index, end_index - 1), +                    replacement=self.replacement, +                ) +            ) +            file_content = ( +                file_content[:start_index] + self.replacement + file_content[end_index:] +            )              await sdk.ide.saveFile(self.filepath) diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 82f885d6..148dddb8 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -1,9 +1,11 @@  from textwrap import dedent +  from ...core.main import ChatMessage, Step  from ...core.sdk import ContinueSDK  from ...libs.util.telemetry import posthog_logger -help = dedent("""\ +help = dedent( +    """\          Continue is an open-source coding autopilot. It is a VS Code extension that brings the power of ChatGPT to your IDE.          It gathers context for you and stores your interactions automatically, so that you can avoid copy/paste now and benefit from a customized Large Language Model (LLM) later. @@ -23,11 +25,11 @@ help = dedent("""\          If Continue is stuck loading, try using `cmd+shift+p` to open the command palette, search "Reload Window", and then select it. This will reload VS Code and Continue and often fixes issues. -        If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""") +        If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""" +)  class HelpStep(Step): -      name: str = "Help"      user_input: str      manage_own_chat_context: bool = True @@ -40,7 +42,8 @@ class HelpStep(Step):              self.description = help          else:              self.description = "The following output is generated by a language model, which may hallucinate. Type just '/help'to see a fixed answer. You can also learn more by reading [the docs](https://continue.dev/docs).\n\n" -            prompt = dedent(f"""                     +            prompt = dedent( +                f"""                                              Information:                          {help} @@ -49,13 +52,12 @@ class HelpStep(Step):                          Please us the information below to provide a succinct answer to the following question: {question} -                        Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback. Never refer or link to any URL.""") +                        Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback. Never refer or link to any URL.""" +            ) -            self.chat_context.append(ChatMessage( -                role="user", -                content=prompt, -                summary="Help" -            )) +            self.chat_context.append( +                ChatMessage(role="user", content=prompt, summary="Help") +            )              messages = await sdk.get_chat_context()              generator = sdk.models.default.stream_chat(messages)              async for chunk in generator: @@ -64,4 +66,5 @@ class HelpStep(Step):                      await sdk.update_ui()          posthog_logger.capture_event( -            "help", {"question": question, "answer": self.description}) +            "help", {"question": question, "answer": self.description} +        ) diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index 3d8d96fb..721f1306 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -1,7 +1,8 @@  from typing import List, Union -from ..core.core import WaitForUserInputStep +  from ....core.main import Step  from ....core.sdk import ContinueSDK +from ..core.core import WaitForUserInputStep  class NLMultiselectStep(Step): @@ -11,7 +12,9 @@ class NLMultiselectStep(Step):      options: List[str]      async def run(self, sdk: ContinueSDK): -        user_response = (await sdk.run_step(WaitForUserInputStep(prompt=self.prompt))).text +        user_response = ( +            await sdk.run_step(WaitForUserInputStep(prompt=self.prompt)) +        ).text          def extract_option(text: str) -> Union[str, None]:              for option in self.options: @@ -24,5 +27,6 @@ class NLMultiselectStep(Step):              return first_try          gpt_parsed = await sdk.models.default.complete( -            f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:") +            f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:" +        )          return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index d2d6f4dd..da9cf5b2 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -1,19 +1,19 @@  import os -from typing import Coroutine, List, Union  from textwrap import dedent +from typing import Coroutine, List, Union +  from pydantic import BaseModel, Field -from ...models.main import Traceback, Range -from ...models.filesystem_edit import EditDiff, FileEdit -from ...models.filesystem import RangeInFile, RangeInFileWithContents -from ...core.observation import Observation -from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder  from ...core.main import ContinueCustomException, Step -from ...core.sdk import ContinueSDK, Models  from ...core.observation import Observation -from .core.core import DefaultModelEditCodeStep +from ...core.sdk import ContinueSDK, Models +from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder  from ...libs.util.calculate_diff import calculate_diff2  from ...libs.util.logging import logger +from ...models.filesystem import RangeInFile, RangeInFileWithContents +from ...models.filesystem_edit import EditDiff, FileEdit +from ...models.main import Range, Traceback +from .core.core import DefaultModelEditCodeStep  class Policy(BaseModel): @@ -36,7 +36,8 @@ class FasterEditHighlightedCodeStep(Step):      hide = True      _completion: str = "Edit Code"      _edit_diffs: Union[List[EditDiff], None] = None -    _prompt: str = dedent("""\ +    _prompt: str = dedent( +        """\          You will be given code to edit in order to perfectly satisfy the user request. All the changes you make must be described as replacements, which you should format in the following way:          FILEPATH          <FILE_TO_EDIT> @@ -75,7 +76,8 @@ class FasterEditHighlightedCodeStep(Step):          This is the user request: "{user_input}"          Here is the description of changes to make: -""") +""" +    )      async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Editing highlighted code" @@ -89,13 +91,14 @@ class FasterEditHighlightedCodeStep(Step):              for file in files:                  contents[file] = await sdk.ide.readFile(file) -            range_in_files = [RangeInFileWithContents.from_entire_file( -                filepath, content) for filepath, content in contents.items()] +            range_in_files = [ +                RangeInFileWithContents.from_entire_file(filepath, content) +                for filepath, content in contents.items() +            ]          enc_dec = MarkdownStyleEncoderDecoder(range_in_files)          code_string = enc_dec.encode() -        prompt = self._prompt.format( -            code=code_string, user_input=self.user_input) +        prompt = self._prompt.format(code=code_string, user_input=self.user_input)          rif_dict = {}          for rif in range_in_files: @@ -145,7 +148,14 @@ class FasterEditHighlightedCodeStep(Step):              replace_me = edit["replace_me"]              replace_with = edit["replace_with"]              file_edits.append( -                FileEdit(filepath=filepath, range=Range.from_lines_snippet_in_file(content=rif_dict[filepath], snippet=replace_me), replacement=replace_with)) +                FileEdit( +                    filepath=filepath, +                    range=Range.from_lines_snippet_in_file( +                        content=rif_dict[filepath], snippet=replace_me +                    ), +                    replacement=replace_with, +                ) +            )          # ------------------------------          self._edit_diffs = [] @@ -169,7 +179,9 @@ class StarCoderEditHighlightedCodeStep(Step):      _prompt_and_completion: str = ""      async def describe(self, models: Models) -> Coroutine[str, None, None]: -        return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") +        return await models.medium.complete( +            f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:" +        )      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:          range_in_files = await sdk.get_code_context(only_editing=True) @@ -181,8 +193,10 @@ class StarCoderEditHighlightedCodeStep(Step):              for file in files:                  contents[file] = await sdk.ide.readFile(file) -            range_in_files = [RangeInFileWithContents.from_entire_file( -                filepath, content) for filepath, content in contents.items()] +            range_in_files = [ +                RangeInFileWithContents.from_entire_file(filepath, content) +                for filepath, content in contents.items() +            ]          rif_dict = {}          for rif in range_in_files: @@ -190,7 +204,8 @@ class StarCoderEditHighlightedCodeStep(Step):          for rif in range_in_files:              prompt = self._prompt.format( -                code=rif.contents, user_request=self.user_input) +                code=rif.contents, user_request=self.user_input +            )              if found_highlighted_code:                  full_file_contents = await sdk.ide.readFile(rif.filepath) @@ -208,7 +223,8 @@ class StarCoderEditHighlightedCodeStep(Step):              self._prompt_and_completion += prompt + completion              edits = calculate_diff2( -                rif.filepath, rif.contents, completion.removesuffix("\n")) +                rif.filepath, rif.contents, completion.removesuffix("\n") +            )              for edit in edits:                  await sdk.ide.applyFileSystemEdit(edit) @@ -220,7 +236,10 @@ class StarCoderEditHighlightedCodeStep(Step):  class EditHighlightedCodeStep(Step):      user_input: str = Field( -        ..., title="User Input", description="The natural language request describing how to edit the code") +        ..., +        title="User Input", +        description="The natural language request describing how to edit the code", +    )      hide = True      description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user asks seems to be asking for a code change." @@ -235,28 +254,44 @@ class EditHighlightedCodeStep(Step):              highlighted_code = await sdk.ide.getHighlightedCode()              if highlighted_code is not None:                  for rif in highlighted_code: -                    if os.path.dirname(rif.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")): +                    if os.path.dirname(rif.filepath) == os.path.expanduser( +                        os.path.join("~", ".continue", "diffs") +                    ):                          raise ContinueCustomException( -                            message="Please accept or reject the change before making another edit in this file.", title="Accept/Reject First") +                            message="Please accept or reject the change before making another edit in this file.", +                            title="Accept/Reject First", +                        )                      if rif.range.start == rif.range.end:                          range_in_files.append( -                            RangeInFileWithContents.from_range_in_file(rif, "")) +                            RangeInFileWithContents.from_range_in_file(rif, "") +                        )          # If still no highlighted code, raise error          if len(range_in_files) == 0:              raise ContinueCustomException( -                message="Please highlight some code and try again.", title="No Code Selected") +                message="Please highlight some code and try again.", +                title="No Code Selected", +            ) -        range_in_files = list(map(lambda x: RangeInFile( -            filepath=x.filepath, range=x.range -        ), range_in_files)) +        range_in_files = list( +            map( +                lambda x: RangeInFile(filepath=x.filepath, range=x.range), +                range_in_files, +            ) +        )          for range_in_file in range_in_files: -            if os.path.dirname(range_in_file.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")): +            if os.path.dirname(range_in_file.filepath) == os.path.expanduser( +                os.path.join("~", ".continue", "diffs") +            ):                  self.description = "Please accept or reject the change before making another edit in this file."                  return -        await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files)) +        await sdk.run_step( +            DefaultModelEditCodeStep( +                user_input=self.user_input, range_in_files=range_in_files +            ) +        )  class UserInputStep(Step): @@ -270,7 +305,8 @@ class SolveTracebackStep(Step):          return f"```\n{self.traceback.full_traceback}\n```"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        prompt = dedent("""I ran into this problem with my Python code: +        prompt = dedent( +            """I ran into this problem with my Python code:                  {traceback} @@ -279,15 +315,17 @@ class SolveTracebackStep(Step):                  {code}                  This is what the code should be in order to avoid the problem: -            """).format(traceback=self.traceback.full_traceback, code="{code}") +            """ +        ).format(traceback=self.traceback.full_traceback, code="{code}")          range_in_files = []          for frame in self.traceback.frames:              content = await sdk.ide.readFile(frame.filepath) -            range_in_files.append( -                RangeInFile.from_entire_file(frame.filepath, content)) +            range_in_files.append(RangeInFile.from_entire_file(frame.filepath, content)) -        await sdk.run_step(DefaultModelEditCodeStep(range_in_files=range_in_files, user_input=prompt)) +        await sdk.run_step( +            DefaultModelEditCodeStep(range_in_files=range_in_files, user_input=prompt) +        )          return None diff --git a/continuedev/src/continuedev/plugins/steps/open_config.py b/continuedev/src/continuedev/plugins/steps/open_config.py index 64ead547..d6283af2 100644 --- a/continuedev/src/continuedev/plugins/steps/open_config.py +++ b/continuedev/src/continuedev/plugins/steps/open_config.py @@ -1,15 +1,16 @@  from textwrap import dedent +  from ...core.main import Step  from ...core.sdk import ContinueSDK  from ...libs.util.paths import getConfigFilePath -import os  class OpenConfigStep(Step):      name: str = "Open config"      async def describe(self, models): -        return dedent("""\ +        return dedent( +            """\              `\"config.py\"` is now open. You can add a custom slash command in the `\"custom_commands\"` section, like in this example:              ```python              config = ContinueConfig( @@ -23,7 +24,8 @@ class OpenConfigStep(Step):              ```              `name` is the command you will type.              `description` is the description displayed in the slash command menu. -            `prompt` is the instruction given to the model. The overall prompt becomes "Task: {prompt}, Additional info: {user_input}". For example, if you entered "/test exactly 5 assertions", the overall prompt would become "Task: Write a comprehensive...and sophisticated, Additional info: exactly 5 assertions".""") +            `prompt` is the instruction given to the model. The overall prompt becomes "Task: {prompt}, Additional info: {user_input}". For example, if you entered "/test exactly 5 assertions", the overall prompt would become "Task: Write a comprehensive...and sophisticated, Additional info: exactly 5 assertions".""" +        )      async def run(self, sdk: ContinueSDK):          await sdk.ide.setFileOpen(getConfigFilePath()) diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index da6acdbf..a2612731 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -1,5 +1,6 @@  from textwrap import dedent -from typing import List, Union, Tuple +from typing import List, Tuple, Union +  from ...core.main import Step  from ...core.sdk import ContinueSDK @@ -13,11 +14,11 @@ class NLDecisionStep(Step):      name: str = "Deciding what to do next"      async def run(self, sdk: ContinueSDK): -        step_descriptions = "\n".join([ -            f"- {step[0].name}: {step[1]}" -            for step in self.steps -        ]) -        prompt = dedent(f"""\ +        step_descriptions = "\n".join( +            [f"- {step[0].name}: {step[1]}" for step in self.steps] +        ) +        prompt = dedent( +            f"""\              The following steps are available, in the format "- [step name]: [step description]":              {step_descriptions} @@ -25,7 +26,8 @@ class NLDecisionStep(Step):              {self.user_input} -            Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""") +            Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""" +        )          resp = (await sdk.models.medium.complete(prompt)).lower() diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 456dba84..04fb98b7 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -1,14 +1,14 @@  import asyncio +import os +import re  from textwrap import dedent  from typing import List, Union -from ...models.filesystem import RangeInFile -from ...models.main import Range  from ...core.main import Step  from ...core.sdk import ContinueSDK  from ...libs.util.create_async_task import create_async_task -import os -import re +from ...models.filesystem import RangeInFile +from ...models.main import Range  # Already have some code for this somewhere  IGNORE_DIRS = ["env", "venv", ".venv"] @@ -29,8 +29,12 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]:                  file_content = f.read()                  results = re.finditer(pattern, file_content)                  range_in_files += [ -                    RangeInFile(filepath=os.path.join(root, file), range=Range.from_indices( -                        file_content, result.start(), result.end())) +                    RangeInFile( +                        filepath=os.path.join(root, file), +                        range=Range.from_indices( +                            file_content, result.start(), result.end() +                        ), +                    )                      for result in results                  ] @@ -42,12 +46,16 @@ class WriteRegexPatternStep(Step):      async def run(self, sdk: ContinueSDK):          # Ask the user for a regex pattern -        pattern = await sdk.models.medium.complete(dedent(f"""\ +        pattern = await sdk.models.medium.complete( +            dedent( +                f"""\              This is the user request:              {self.user_request} -            Please write either a regex pattern or just a string that be used with python's re module to find all matches requested by the user. It will be used as `re.findall(<PATTERN_YOU_WILL_WRITE>, file_content)`. Your output should be only the regex or string, nothing else:""")) +            Please write either a regex pattern or just a string that be used with python's re module to find all matches requested by the user. It will be used as `re.findall(<PATTERN_YOU_WILL_WRITE>, file_content)`. Your output should be only the regex or string, nothing else:""" +            ) +        )          return pattern @@ -59,11 +67,18 @@ class EditAllMatchesStep(Step):      async def run(self, sdk: ContinueSDK):          # Search all files for a given string -        range_in_files = find_all_matches_in_dir(self.pattern, self.directory or await sdk.ide.getWorkspaceDirectory()) +        range_in_files = find_all_matches_in_dir( +            self.pattern, self.directory or await sdk.ide.getWorkspaceDirectory() +        ) -        tasks = [create_async_task(sdk.edit_file( -            range=range_in_file.range, -            filename=range_in_file.filepath, -            prompt=self.user_request -        )) for range_in_file in range_in_files] +        tasks = [ +            create_async_task( +                sdk.edit_file( +                    range=range_in_file.range, +                    filename=range_in_file.filepath, +                    prompt=self.user_request, +                ) +            ) +            for range_in_file in range_in_files +        ]          await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/plugins/steps/share_session.py b/continuedev/src/continuedev/plugins/steps/share_session.py index de8659bd..1d68dc90 100644 --- a/continuedev/src/continuedev/plugins/steps/share_session.py +++ b/continuedev/src/continuedev/plugins/steps/share_session.py @@ -3,15 +3,13 @@ import os  import time  from typing import Optional - +from ...core.main import FullState, Step  from ...core.sdk import ContinueSDK -from ...core.main import Step, FullState -from ...libs.util.paths import getSessionFilePath, getGlobalFolderPath +from ...libs.util.paths import getGlobalFolderPath, getSessionFilePath  from ...server.session_manager import session_manager  class ShareSessionStep(Step): -      session_id: Optional[str] = None      async def run(self, sdk: ContinueSDK): @@ -23,12 +21,14 @@ class ShareSessionStep(Step):          # Load the session data and format as a markdown file          session_filepath = getSessionFilePath(self.session_id) -        with open(session_filepath, 'r') as f: +        with open(session_filepath, "r") as f:              session_state = FullState(**json.load(f))          import datetime +          date_created = datetime.datetime.fromtimestamp( -            float(session_state.session_info.date_created)).strftime('%Y-%m-%d %H:%M:%S') +            float(session_state.session_info.date_created) +        ).strftime("%Y-%m-%d %H:%M:%S")          content = f"This is a session transcript from [Continue](https://continue.dev) on {date_created}.\n\n"          for node in session_state.history.timeline[:-2]: @@ -40,9 +40,10 @@ class ShareSessionStep(Step):          # Save to a markdown file          save_filepath = os.path.join( -            getGlobalFolderPath(), f"{session_state.session_info.title}.md") +            getGlobalFolderPath(), f"{session_state.session_info.title}.md" +        ) -        with open(save_filepath, 'w') as f: +        with open(save_filepath, "w") as f:              f.write(content)          # Open the file diff --git a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py index 489cada3..d0058ffc 100644 --- a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py @@ -1,5 +1,5 @@  from ...core.main import Step -from ...core.sdk import Models, ContinueSDK +from ...core.sdk import ContinueSDK, Models  class StepsOnStartupStep(Step): diff --git a/continuedev/src/continuedev/plugins/steps/welcome.py b/continuedev/src/continuedev/plugins/steps/welcome.py index df3e9a8a..ef1acfc1 100644 --- a/continuedev/src/continuedev/plugins/steps/welcome.py +++ b/continuedev/src/continuedev/plugins/steps/welcome.py @@ -1,9 +1,9 @@ -from textwrap import dedent  import os +from textwrap import dedent -from ...models.filesystem_edit import AddFile  from ...core.main import Step  from ...core.sdk import ContinueSDK, Models +from ...models.filesystem_edit import AddFile  class WelcomeStep(Step): @@ -21,13 +21,20 @@ class WelcomeStep(Step):          if not os.path.exists(continue_dir):              os.mkdir(continue_dir) -        await sdk.ide.applyFileSystemEdit(AddFile(filepath=filepath, content=dedent("""\ +        await sdk.ide.applyFileSystemEdit( +            AddFile( +                filepath=filepath, +                content=dedent( +                    """\              \"\"\"              Welcome to Continue! To learn how to use it, delete this comment and try to use Continue for the following:              - "Write me a calculator class"              - Ask for a new method (e.g. "exp", "mod", "sqrt")              - Type /comment to write comments for the entire class              - Ask about how the class works, how to write it in another language, etc. -            \"\"\""""))) +            \"\"\"""" +                ), +            ) +        )          # await sdk.ide.setFileOpen(filepath=filepath) diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py index 990833be..d079475c 100644 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -1,7 +1,5 @@ -from typing import Any, Dict, List  from abc import ABC, abstractmethod - -from ..core.context import ContextItem +from typing import Any  class AbstractGUIProtocolServer(ABC): diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 72b410d4..f63fecf8 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -1,10 +1,10 @@ +from abc import ABC, abstractmethod  from typing import Any, List, Union -from abc import ABC, abstractmethod, abstractproperty +  from fastapi import WebSocket -from ..models.main import Traceback -from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff  from ..models.filesystem import RangeInFile, RangeInFileWithContents +from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit  class AbstractIdeProtocolServer(ABC): diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f0a3f094..00ded6f1 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,18 +1,16 @@ +import argparse  import asyncio -import time -import psutil -import os -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware  import atexit -import uvicorn -import argparse +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware -from .ide import router as ide_router -from .gui import router as gui_router -from .session_manager import session_manager, router as sessions_router  from ..libs.util.logging import logger +from .gui import router as gui_router +from .ide import router as ide_router +from .session_manager import router as sessions_router +from .session_manager import session_manager  app = FastAPI() @@ -39,8 +37,7 @@ def health():  try:      # add cli arg for server port      parser = argparse.ArgumentParser() -    parser.add_argument("-p", "--port", help="server port", -                        type=int, default=65432) +    parser.add_argument("-p", "--port", help="server port", type=int, default=65432)      args = parser.parse_args()  except Exception as e:      logger.debug(f"Error parsing command line arguments: {e}") diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index f47c08ca..98b48685 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -5,8 +5,8 @@ import subprocess  from meilisearch_python_async import Client -from ..libs.util.paths import getServerFolderPath  from ..libs.util.logging import logger +from ..libs.util.paths import getServerFolderPath  def ensure_meilisearch_installed() -> bool: @@ -43,7 +43,11 @@ def ensure_meilisearch_installed() -> bool:          # Download MeiliSearch          logger.debug("Downloading MeiliSearch...")          subprocess.run( -            f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath) +            "curl -L https://install.meilisearch.com | sh", +            shell=True, +            check=True, +            cwd=serverPath, +        )          return False @@ -56,7 +60,7 @@ async def check_meilisearch_running() -> bool:      """      try: -        async with Client('http://localhost:7700') as client: +        async with Client("http://localhost:7700") as client:              try:                  resp = await client.health()                  if resp.status != "available": @@ -96,5 +100,11 @@ async def start_meilisearch():      # Check if MeiliSearch is running      if not await check_meilisearch_running() or not was_already_installed:          logger.debug("Starting MeiliSearch...") -        subprocess.Popen(["./meilisearch", "--no-analytics"], cwd=serverPath, stdout=subprocess.DEVNULL, -                         stderr=subprocess.STDOUT, close_fds=True, start_new_session=True) +        subprocess.Popen( +            ["./meilisearch", "--no-analytics"], +            cwd=serverPath, +            stdout=subprocess.DEVNULL, +            stderr=subprocess.STDOUT, +            close_fds=True, +            start_new_session=True, +        ) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 88ac13f6..6f4e4a87 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,20 +1,23 @@ +import json  import os  import traceback -from fastapi import WebSocket, APIRouter  from typing import Any, Coroutine, Dict, Optional, Union  from uuid import uuid4 -import json +from fastapi import APIRouter, WebSocket  from fastapi.websockets import WebSocketState -from ..plugins.steps.core.core import MessageStep -from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath, getSessionsListFilePath -from ..core.main import FullState, HistoryNode, SessionInfo  from ..core.autopilot import Autopilot -from .ide_protocol import AbstractIdeProtocolServer +from ..core.main import FullState  from ..libs.util.create_async_task import create_async_task  from ..libs.util.errors import SessionNotFound  from ..libs.util.logging import logger +from ..libs.util.paths import ( +    getSessionFilePath, +    getSessionsFolderPath, +    getSessionsListFilePath, +) +from .ide_protocol import AbstractIdeProtocolServer  router = APIRouter(prefix="/sessions", tags=["sessions"]) @@ -42,14 +45,21 @@ class SessionManager:              # And only if the IDE is still alive              sessions_folder = getSessionsFolderPath()              session_files = os.listdir(sessions_folder) -            if f"{session_id}.json" in session_files and session_id in self.registered_ides: +            if ( +                f"{session_id}.json" in session_files +                and session_id in self.registered_ides +            ):                  if self.registered_ides[session_id].session_id is not None: -                    return await self.new_session(self.registered_ides[session_id], session_id=session_id) +                    return await self.new_session( +                        self.registered_ides[session_id], session_id=session_id +                    )              raise KeyError("Session ID not recognized", session_id)          return self.sessions[session_id] -    async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session: +    async def new_session( +        self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None +    ) -> Session:          logger.debug(f"New session: {session_id}")          # Load the persisted state (not being used right now) @@ -68,9 +78,9 @@ class SessionManager:          # Set up the autopilot to update the GUI          async def on_update(state: FullState): -            await session_manager.send_ws_data(session_id, "state_update", { -                "state": state.dict() -            }) +            await session_manager.send_ws_data( +                session_id, "state_update", {"state": state.dict()} +            )          autopilot.on_update(on_update) @@ -81,7 +91,7 @@ class SessionManager:              await ide.on_error(e)          def on_error(e: Exception) -> Coroutine: -            err_msg = '\n'.join(traceback.format_exception(e)) +            err_msg = "\n".join(traceback.format_exception(e))              return ide.showMessage(f"Error in Continue server: {err_msg}")          create_async_task(autopilot.run_policy(), on_error) @@ -90,9 +100,15 @@ class SessionManager:      async def remove_session(self, session_id: str):          logger.debug(f"Removing session: {session_id}")          if session_id in self.sessions: -            if session_id in self.registered_ides and self.registered_ides[session_id] is not None: +            if ( +                session_id in self.registered_ides +                and self.registered_ides[session_id] is not None +            ):                  ws_to_close = self.registered_ides[session_id].websocket -                if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED: +                if ( +                    ws_to_close is not None +                    and ws_to_close.client_state != WebSocketState.DISCONNECTED +                ):                      await self.sessions[session_id].autopilot.ide.websocket.close()              del self.sessions[session_id] @@ -117,7 +133,9 @@ class SessionManager:          with open(getSessionsListFilePath(), "w") as f:              json.dump(sessions_list, f) -    async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: +    async def load_session( +        self, old_session_id: str, new_session_id: Optional[str] = None +    ) -> str:          """Load the session's FullState from a json file"""          # First persist the current state @@ -142,10 +160,9 @@ class SessionManager:              # logger.debug(f"Session {session_id} has no websocket")              return -        await self.sessions[session_id].ws.send_json({ -            "messageType": message_type, -            "data": data -        }) +        await self.sessions[session_id].ws.send_json( +            {"messageType": message_type, "data": data} +        )  session_manager = SessionManager()  | 
