84 files changed, 2336 insertions, 1235 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index e048f877..98730d38 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -1,11 +1,10 @@
-from abc import ABC, abstractmethod, abstractproperty
+from abc import ABC, abstractmethod
 from typing import Coroutine, List, Union
 
-from .config import ContinueConfig
 from ..models.filesystem_edit import FileSystemEdit
+from .config import ContinueConfig
+from .main import ChatMessage, History, Step
 from .observation import Observation
-from .main import ChatMessage, History, Step, ChatMessageRole
-
 
 """
 [[Generate]]
diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py
index 2692c348..60b86538 100644
--- a/continuedev/src/continuedev/core/env.py
+++ b/continuedev/src/continuedev/core/env.py
@@ -1,6 +1,7 @@
-from dotenv import load_dotenv
 import os
 
+from dotenv import load_dotenv
+
 
 def get_env_var(var_name: str):
     load_dotenv()
@@ -8,21 +9,21 @@ def get_env_var(var_name: str):
 
 def make_sure_env_exists():
-    if not os.path.exists('.env'):
-        with open('.env', 'w') as f:
-            f.write('')
+    if not os.path.exists(".env"):
+        with open(".env", "w") as f:
+            f.write("")
 
 
 def save_env_var(var_name: str, var_value: str):
     make_sure_env_exists()
 
-    with open('.env', 'r') as f:
+    with open(".env", "r") as f:
         lines = f.readlines()
 
-    with open('.env', 'w') as f:
+    with open(".env", "w") as f:
         values = {}
         for line in lines:
-            key, value = line.split('=')
-            value = value.replace('"', '')
+            key, value = line.split("=")
+            value = value.replace('"', "")
             values[key] = value
 
         values[var_name] = var_value
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index a33d777e..53440dae 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -1,10 +1,10 @@
 import json
 from typing import Coroutine, Dict, List, Literal, Optional, Union
 
-from pydantic.schema import schema
+from pydantic import BaseModel, validator
+from pydantic.schema import schema
 
 from ..models.main import ContinueBaseModel
-from pydantic import BaseModel, validator
 from .observation import Observation
 
 ChatMessageRole = Literal["assistant", "user", "system", "function"]
@@ -27,8 +27,7 @@ class ChatMessage(ContinueBaseModel):
         d = self.dict()
         del d["summary"]
         if d["function_call"] is not None:
-            d["function_call"]["name"] = d["function_call"]["name"].replace(
-                " ", "")
+            d["function_call"]["name"] = d["function_call"]["name"].replace(" ", "")
 
         if d["content"] is None:
             d["content"] = ""
@@ -49,9 +48,9 @@ class ChatMessage(ContinueBaseModel):
 def resolve_refs(schema_data):
     def traverse(obj):
         if isinstance(obj, dict):
-            if '$ref' in obj:
-                ref = obj['$ref']
-                parts = ref.split('/')
+            if "$ref" in obj:
+                ref = obj["$ref"]
+                parts = ref.split("/")
                 ref_obj = schema_data
                 for part in parts[1:]:
                     ref_obj = ref_obj[part]
@@ -67,8 +66,14 @@ def resolve_refs(schema_data):
     return traverse(schema_data)
 
 
-unincluded_parameters = ["system_message", "chat_context",
-                         "manage_own_chat_context", "hide", "name", "description"]
+unincluded_parameters = [
+    "system_message",
+    "chat_context",
+    "manage_own_chat_context",
+    "hide",
+    "name",
+    "description",
+]
 
 
 def step_to_json_schema(step) -> str:
@@ -82,7 +87,7 @@ def step_to_json_schema(step) -> str:
     return {
         "name": step.name.replace(" ", ""),
         "description": step.description or "",
-        "parameters": parameters
+        "parameters": parameters,
     }
 
 
@@ -96,6 +101,7 @@ def step_to_fn_call_arguments(step: "Step") -> str:
 
 class HistoryNode(ContinueBaseModel):
     """A point in history, a list of which make up History"""
+
     step: "Step"
     observation: Union[Observation, None]
     depth: int
@@ -111,12 +117,14 @@ class HistoryNode(ContinueBaseModel):
                 role="assistant",
                 name=self.step.__class__.__name__,
                 content=self.step.description or f"Ran function {self.step.name}",
-                summary=f"Called function {self.step.name}"
-            )]
+                summary=f"Called function {self.step.name}",
+            )
+        ]
 
 
 class History(ContinueBaseModel):
     """A history of steps taken and their results"""
+
     timeline: List[HistoryNode]
     current_index: int
 
@@ -128,7 +136,7 @@ class History(ContinueBaseModel):
         return msgs
 
     def add_node(self, node: HistoryNode) -> int:
-        """ Add node and return the index where it was added """
+        """Add node and return the index where it was added"""
         self.timeline.insert(self.current_index + 1, node)
         self.current_index += 1
         return self.current_index
@@ -138,10 +146,15 @@ class History(ContinueBaseModel):
             return None
         return self.timeline[self.current_index]
 
-    def get_last_at_depth(self, depth: int, include_current: bool = False) -> Union[HistoryNode, None]:
+    def get_last_at_depth(
+        self, depth: int, include_current: bool = False
+    ) -> Union[HistoryNode, None]:
         i = self.current_index if include_current else self.current_index - 1
         while i >= 0:
-            if self.timeline[i].depth == depth and type(self.timeline[i].step).__name__ != "ManualEditStep":
+            if (
+                self.timeline[i].depth == depth
+                and type(self.timeline[i].step).__name__ != "ManualEditStep"
+            ):
                 return self.timeline[i]
             i -= 1
         return None
@@ -204,24 +217,27 @@ class ContextItemId(BaseModel):
     """
     A ContextItemId is a unique identifier for a ContextItem.
     """
+
     provider_title: str
     item_id: str
 
-    @validator('provider_title', 'item_id')
+    @validator("provider_title", "item_id")
     def must_be_valid_id(cls, v):
         import re
-        if not re.match(r'^[0-9a-zA-Z_-]*$', v):
+
+        if not re.match(r"^[0-9a-zA-Z_-]*$", v):
             raise ValueError(
-                "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _")
+                "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _"
+            )
         return v
 
     def to_string(self) -> str:
         return f"{self.provider_title}-{self.item_id}"
 
     @staticmethod
-    def from_string(string: str) -> 'ContextItemId':
-        provider_title, *rest = string.split('-')
-        item_id = '-'.join(rest)
+    def from_string(string: str) -> "ContextItemId":
+        provider_title, *rest = string.split("-")
+        item_id = "-".join(rest)
         return ContextItemId(provider_title=provider_title, item_id=item_id)
 
 
@@ -231,22 +247,24 @@ class ContextItemDescription(BaseModel):
 
     The id can be used to retrieve the ContextItem from the ContextManager.
     """
+
     name: str
     description: str
-    id: ContextItemId
+    id: ContextItemId
 
 
 class ContextItem(BaseModel):
     """
     A ContextItem is a single item that is stored in the ContextManager.
     """
+
     description: ContextItemDescription
     content: str
 
-    @validator('content', pre=True)
+    @validator("content", pre=True)
     def content_must_be_string(cls, v):
         if v is None:
-            return ''
+            return ""
         return v
 
     editing: bool = False
@@ -261,6 +279,7 @@ class SessionInfo(ContinueBaseModel):
 
 class FullState(ContinueBaseModel):
     """A full state of the program, including the history"""
+
     history: History
     active: bool
     user_input_queue: List[str]
@@ -286,7 +305,9 @@ class Policy(ContinueBaseModel):
     """A rule that determines which step to take next"""
 
     # Note that history is mutable, kinda sus
-    def next(self, config: ContinueConfig, history: History = History.from_empty()) -> "Step":
+    def next(
+        self, config: ContinueConfig, history: History = History.from_empty()
+    ) -> "Step":
         raise NotImplementedError
 
 
@@ -373,7 +394,12 @@ class ContinueCustomException(Exception):
     message: str
     with_step: Union[Step, None]
 
-    def __init__(self, message: str, title: str = "Error while running step:", with_step: Union[Step, None] = None):
+    def __init__(
+        self,
+        message: str,
+        title: str = "Error while running step:",
+        with_step: Union[Step, None] = None,
+    ):
         self.message = message
         self.title = title
         self.with_step = with_step
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index 900762b6..52a52b1d 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,10 +1,13 @@
-from typing import Optional, Any
-from pydantic import BaseModel, validator
+from typing import Any, Optional
+
+from pydantic import BaseModel
+
 from ..libs.llm import LLM
 
 
 class Models(BaseModel):
     """Main class that holds the current model configuration"""
+
     default: LLM
     small: Optional[LLM] = None
     medium: Optional[LLM] = None
diff --git a/continuedev/src/continuedev/core/observation.py b/continuedev/src/continuedev/core/observation.py
index 126cf19e..8a5e454e 100644
--- a/continuedev/src/continuedev/core/observation.py
+++ b/continuedev/src/continuedev/core/observation.py
@@ -1,4 +1,5 @@
 from pydantic import BaseModel, validator
+
 from ..models.main import Traceback
 
 
diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py
index dba4874f..d77cce49 100644
--- a/continuedev/src/continuedev/libs/chroma/query.py
+++ b/continuedev/src/continuedev/libs/chroma/query.py
@@ -1,12 +1,19 @@
 import json
+import os
 import subprocess
+from functools import cached_property
 from typing import List, Tuple
-from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_storage, Document
+
+from llama_index import (
+    Document,
+    GPTVectorStoreIndex,
+    StorageContext,
+    load_index_from_storage,
+)
 from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
-import os
-from .update import filter_ignored_files, load_gpt_index_documents
+
 from ..util.logging import logger
-from functools import cached_property
+from .update import filter_ignored_files, load_gpt_index_documents
 
 
 class ChromaIndexManager:
@@ -18,23 +25,42 @@ class ChromaIndexManager:
     @cached_property
     def current_commit(self) -> str:
         """Get the current commit."""
-        return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip()
+        return (
+            subprocess.check_output(
+                ["git", "rev-parse", "HEAD"], cwd=self.workspace_dir
+            )
+            .decode("utf-8")
+            .strip()
+        )
 
     @cached_property
     def current_branch(self) -> str:
         """Get the current branch."""
-        return subprocess.check_output(
-            ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip()
+        return (
+            subprocess.check_output(
+                ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir
+            )
+            .decode("utf-8")
+            .strip()
+        )
 
     @cached_property
     def index_dir(self) -> str:
-        return os.path.join(self.workspace_dir, ".continue", "chroma", self.current_branch)
+        return os.path.join(
+            self.workspace_dir, ".continue", "chroma", self.current_branch
+        )
 
     @cached_property
     def git_root_dir(self):
         """Get the root directory of a Git repository."""
        try:
-            return subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], cwd=self.workspace_dir).strip().decode()
+            return (
+                subprocess.check_output(
+                    ["git", "rev-parse", "--show-toplevel"], cwd=self.workspace_dir
+                )
+                .strip()
+                .decode()
+            )
         except subprocess.CalledProcessError:
             return None
 
@@ -57,8 +83,7 @@ class ChromaIndexManager:
             try:
                 text_chunks = text_splitter.split_text(doc.text)
             except:
-                logger.warning(
-                    f"ERROR (probably found special token): {doc.text}")
+                logger.warning(f"ERROR (probably found special token): {doc.text}")
                 continue  # lol
             filename = doc.extra_info["filename"]
             chunks[filename] = len(text_chunks)
@@ -66,8 +91,7 @@ class ChromaIndexManager:
                 doc_chunks.append(Document(text, doc_id=f"{filename}::{i}"))
 
         with open(f"{self.index_dir}/metadata.json", "w") as f:
-            json.dump({"commit": self.current_commit,
-                       "chunks": chunks}, f, indent=4)
+            json.dump({"commit": self.current_commit, "chunks": chunks}, f, indent=4)
 
         index = GPTVectorStoreIndex([])
 
@@ -89,17 +113,30 @@ class ChromaIndexManager:
         with open(metadata, "r") as f:
             previous_commit = json.load(f)["commit"]
 
-        modified_deleted_files = subprocess.check_output(
-            ["git", "diff", "--name-only", previous_commit, self.current_commit]).decode("utf-8").strip()
+        modified_deleted_files = (
+            subprocess.check_output(
+                ["git", "diff", "--name-only", previous_commit, self.current_commit]
+            )
+            .decode("utf-8")
+            .strip()
+        )
         modified_deleted_files = modified_deleted_files.split("\n")
         modified_deleted_files = [f for f in modified_deleted_files if f]
 
         deleted_files = [
-            f for f in modified_deleted_files if not os.path.exists(os.path.join(self.workspace_dir, f))]
+            f
+            for f in modified_deleted_files
+            if not os.path.exists(os.path.join(self.workspace_dir, f))
+        ]
         modified_files = [
-            f for f in modified_deleted_files if os.path.exists(os.path.join(self.workspace_dir, f))]
+            f
+            for f in modified_deleted_files
+            if os.path.exists(os.path.join(self.workspace_dir, f))
+        ]
 
-        return filter_ignored_files(modified_files, self.index_dir), filter_ignored_files(deleted_files, self.index_dir)
+        return filter_ignored_files(
+            modified_files, self.index_dir
+        ), filter_ignored_files(deleted_files, self.index_dir)
 
     def update_codebase_index(self):
         """Update the index with a list of files."""
@@ -108,15 +145,13 @@ class ChromaIndexManager:
         if not self.check_index_exists():
             self.create_codebase_index()
         else:
             # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index")
-            index = GPTVectorStoreIndex.load_from_disk(
-                f"{self.index_dir}/index.json")
+            index = GPTVectorStoreIndex.load_from_disk(f"{self.index_dir}/index.json")
             modified_files, deleted_files = self.get_modified_deleted_files()
             with open(f"{self.index_dir}/metadata.json", "r") as f:
                 metadata = json.load(f)
 
             for file in deleted_files:
-
                 num_chunks = metadata["chunks"][file]
                 for i in range(num_chunks):
                     index.delete(f"{file}::{i}")
@@ -126,9 +161,7 @@ class ChromaIndexManager:
                 logger.debug(f"Deleted {file}")
 
             for file in modified_files:
-
                 if file in metadata["chunks"]:
-
                     num_chunks = metadata["chunks"][file]
                     for i in range(num_chunks):
@@ -159,12 +192,10 @@ class ChromaIndexManager:
     def query_codebase_index(self, query: str) -> str:
         """Query the codebase index."""
         if not self.check_index_exists():
-            logger.debug(
-                f"No index found for the codebase at {self.index_dir}")
+            logger.debug(f"No index found for the codebase at {self.index_dir}")
             return ""
 
-        storage_context = StorageContext.from_defaults(
-            persist_dir=self.index_dir)
+        storage_context = StorageContext.from_defaults(persist_dir=self.index_dir)
         index = load_index_from_storage(storage_context)
         # index = GPTVectorStoreIndex.load_from_disk(path)
         engine = index.as_query_engine()
@@ -173,14 +204,15 @@ class ChromaIndexManager:
     def query_additional_index(self, query: str) -> str:
         """Query the additional index."""
         index = GPTVectorStoreIndex.load_from_disk(
-            os.path.join(self.index_dir, 'additional_index.json'))
+            os.path.join(self.index_dir, "additional_index.json")
+        )
         return index.query(query)
 
     def replace_additional_index(self, info: str):
         """Replace the additional index with the given info."""
-        with open(f'{self.index_dir}/additional_context.txt', 'w') as f:
+        with open(f"{self.index_dir}/additional_context.txt", "w") as f:
             f.write(info)
         documents = [Document(info)]
         index = GPTVectorStoreIndex(documents)
-        index.save_to_disk(f'{self.index_dir}/additional_index.json')
+        index.save_to_disk(f"{self.index_dir}/additional_index.json")
         logger.debug("Additional index replaced")
diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py
index d5326a06..7a1217f9 100644
--- a/continuedev/src/continuedev/libs/chroma/update.py
+++ b/continuedev/src/continuedev/libs/chroma/update.py
@@ -1,28 +1,24 @@
 # import faiss
 import os
 import subprocess
-
-from llama_index import SimpleDirectoryReader, Document
 from typing import List
+
 from dotenv import load_dotenv
+from llama_index import Document, SimpleDirectoryReader
 
 load_dotenv()
 
-FILE_TYPES_TO_IGNORE = [
-    '.pyc',
-    '.png',
-    '.jpg',
-    '.jpeg',
-    '.gif',
-    '.svg',
-    '.ico'
-]
+FILE_TYPES_TO_IGNORE = [".pyc", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico"]
 
 
 def filter_ignored_files(files: List[str], root_dir: str):
     """Further filter files before indexing."""
     for file in files:
-        if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'):
+        if (
+            file.endswith(tuple(FILE_TYPES_TO_IGNORE))
+            or file.startswith(".git")
+            or file.startswith("archive")
+        ):
             continue  # nice
         yield root_dir + "/" + file
 
@@ -30,9 +26,15 @@ def filter_ignored_files(files: List[str], root_dir: str):
 def get_git_ignored_files(root_dir: str):
     """Get the list of ignored files in a Git repository."""
     try:
-        output = subprocess.check_output(
-            ['git', 'ls-files', '--ignored', '--others', '--exclude-standard'], cwd=root_dir).strip().decode()
-        return output.split('\n')
+        output = (
+            subprocess.check_output(
+                ["git", "ls-files", "--ignored", "--others", "--exclude-standard"],
+                cwd=root_dir,
+            )
+            .strip()
+            .decode()
+        )
+        return output.split("\n")
     except subprocess.CalledProcessError:
         return []
 
@@ -57,4 +59,8 @@ def load_gpt_index_documents(root: str) -> List[Document]:
     # Get input files
     input_files = get_input_files(root)
     # Use SimpleDirectoryReader to load the files into Documents
-    return SimpleDirectoryReader(root, input_files=input_files, file_metadata=lambda filename: {"filename": filename}).load_data()
+    return SimpleDirectoryReader(
+        root,
+        input_files=input_files,
+        file_metadata=lambda filename: {"filename": filename},
+    ).load_data()
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 70c67856..4af6b8e2 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractproperty
-from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
 
 from ...core.main import ChatMessage
 from ...models.main import ContinueBaseModel
@@ -28,15 +28,21 @@ class LLM(ContinueBaseModel, ABC):
         """Stop the connection to the LLM."""
         raise NotImplementedError
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         """Return the completion of the text with the given temperature."""
         raise NotImplementedError
 
-    def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         """Stream the completion through generator."""
         raise NotImplementedError
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         """Stream the chat through generator."""
         raise NotImplementedError
 
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 9d7bc93f..9a7d0ac9 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,10 +1,15 @@
-from functools import cached_property
-import time
 from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+
+from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
+
 from ...core.main import ChatMessage
-from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
 from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages
+from ..util.count_tokens import (
+    DEFAULT_ARGS,
+    compile_chat_messages,
+    count_tokens,
+    format_chat_messages,
+)
 
 
 class AnthropicLLM(LLM):
@@ -19,7 +24,13 @@ class AnthropicLLM(LLM):
 
     write_log: Optional[Callable[[str], None]] = None
 
-    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):
+    async def start(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        write_log: Callable[[str], None],
+        **kwargs,
+    ):
         self.write_log = write_log
         self._async_client = AsyncAnthropic(api_key=self.api_key)
 
@@ -58,7 +69,11 @@ class AnthropicLLM(LLM):
         prompt = ""
 
         # Anthropic prompt must start with a Human turn
-        if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system":
+        if (
+            len(messages) > 0
+            and messages[0]["role"] != "user"
+            and messages[0]["role"] != "system"
+        ):
             prompt += f"{HUMAN_PROMPT} Hello."
         for msg in messages:
             prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "
@@ -66,7 +81,9 @@ class AnthropicLLM(LLM):
         prompt += AI_PROMPT
         return prompt
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
@@ -76,50 +93,62 @@ class AnthropicLLM(LLM):
         self.write_log(f"Prompt: \n\n{prompt}")
         completion = ""
         async for chunk in await self._async_client.completions.create(
-            prompt=prompt,
-            **args
+            prompt=prompt, **args
         ):
             yield chunk.completion
             completion += chunk.completion
 
         self.write_log(f"Completion: \n\n{completion}")
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
        args = self.default_args.copy()
        args.update(kwargs)
        args["stream"] = True
        args = self._transform_args(args)

        messages = compile_chat_messages(
-            args["model"], messages, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+            args["model"],
+            messages,
+            self.context_length,
+            args["max_tokens_to_sample"],
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
         completion = ""
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
         async for chunk in await self._async_client.completions.create(
-            prompt=self.__messages_to_prompt(messages),
-            **args
+            prompt=self.__messages_to_prompt(messages), **args
         ):
-            yield {
-                "role": "assistant",
-                "content": chunk.completion
-            }
+            yield {"role": "assistant", "content": chunk.completion}
             completion += chunk.completion
 
         self.write_log(f"Completion: \n\n{completion}")
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
         args = self._transform_args(args)
 
         messages = compile_chat_messages(
-            args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+            args["model"],
+            with_history,
+            self.context_length,
+            args["max_tokens_to_sample"],
+            prompt,
+            functions=None,
+            system_message=self.system_message,
+        )
 
-        completion = ""
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        resp = (await self._async_client.completions.create(
-            prompt=self.__messages_to_prompt(messages),
-            **args
-        )).completion
+        resp = (
+            await self._async_client.completions.create(
+                prompt=self.__messages_to_prompt(messages), **args
+            )
+        ).completion
 
         self.write_log(f"Completion: \n\n{resp}")
         return resp
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 25a61e63..2d60384b 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,12 +1,11 @@
-from functools import cached_property
 import json
 from typing import Any, Coroutine, Dict, Generator, List, Union
-from pydantic import ConfigDict
 
 import aiohttp
+
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
 
 
 class GGML(LLM):
@@ -22,7 +21,8 @@ class GGML(LLM):
 
     async def start(self, **kwargs):
         self._client_session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        )
 
     async def stop(self):
         await self._client_session.close()
@@ -42,19 +42,27 @@ class GGML(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
 
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
-
-        async with self._client_session.post(f"{self.server_url}/v1/completions", json={
-            "messages": messages,
-            **args
-        }) as resp:
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        async with self._client_session.post(
+            f"{self.server_url}/v1/completions", json={"messages": messages, **args}
+        ) as resp:
             async for line in resp.content.iter_any():
                 if line:
                     try:
@@ -62,40 +70,66 @@ class GGML(LLM):
                     except:
                         raise Exception(str(line))
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
         args["stream"] = True
 
-        async with self._client_session.post(f"{self.server_url}/v1/chat/completions", json={
-            "messages": messages,
-            **args
-        }) as resp:
+        async with self._client_session.post(
+            f"{self.server_url}/v1/chat/completions",
+            json={"messages": messages, **args},
+        ) as resp:
             # This is streaming application/json instaed of text/event-stream
             async for line in resp.content.iter_chunks():
                 if line[1]:
                     try:
                         json_chunk = line[0].decode("utf-8")
-                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+                        if json_chunk.startswith(": ping - ") or json_chunk.startswith(
+                            "data: [DONE]"
+                        ):
                             continue
                         chunks = json_chunk.split("\n")
                         for chunk in chunks:
                             if chunk.strip() != "":
                                 yield {
                                     "role": "assistant",
-                                    "content": json.loads(chunk[6:])["choices"][0]["delta"]
+                                    "content": json.loads(chunk[6:])["choices"][0][
+                                        "delta"
+                                    ],
                                 }
                     except:
                         raise Exception(str(line[0]))
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
-        async with self._client_session.post(f"{self.server_url}/v1/completions", json={
-            "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
-            **args
-        }) as resp:
+        async with self._client_session.post(
+            f"{self.server_url}/v1/completions",
+            json={
+                "messages": compile_chat_messages(
+                    args["model"],
+                    with_history,
+                    self.context_length,
+                    args["max_tokens"],
+                    prompt,
+                    functions=None,
+                    system_message=self.system_message,
+                ),
+                **args,
+            },
+        ) as resp:
             try:
                 return await resp.text()
             except:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 8945250c..76331a28 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,13 +1,13 @@
-from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Coroutine, Dict, Generator, List
+
 import aiohttp
 import requests
 
 from ...core.main import ChatMessage
-from ..util.count_tokens import DEFAULT_ARGS, count_tokens
-from ...core.main import ChatMessage
 from ..llm import LLM
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
 
-DEFAULT_MAX_TIME = 120.
+DEFAULT_MAX_TIME = 120.0
 
 
 class HuggingFaceInferenceAPI(LLM):
@@ -24,7 +24,8 @@ class HuggingFaceInferenceAPI(LLM):
 
     async def start(self, **kwargs):
         self._client_session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        )
 
     async def stop(self):
         await self._client_session.close()
@@ -44,35 +45,43 @@ class HuggingFaceInferenceAPI(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.name, text)
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
-        headers = {
-            "Authorization": f"Bearer {self.hf_token}"}
-
-        response = requests.post(API_URL, headers=headers, json={
-            "inputs": prompt, "parameters": {
-                "max_new_tokens": min(250, self.max_context_length - self.count_tokens(prompt)),
-                "max_time": DEFAULT_MAX_TIME,
-                "return_full_text": False,
-            }
-        })
+        headers = {"Authorization": f"Bearer {self.hf_token}"}
+
+        response = requests.post(
+            API_URL,
+            headers=headers,
+            json={
+                "inputs": prompt,
+                "parameters": {
+                    "max_new_tokens": min(
+                        250, self.max_context_length - self.count_tokens(prompt)
+                    ),
+                    "max_time": DEFAULT_MAX_TIME,
+                    "return_full_text": False,
+                },
+            },
+        )
         data = response.json()
 
         # Error if the response is not a list
         if not isinstance(data, list):
-            raise Exception(
-                "Hugging Face returned an error response: \n\n", data)
+            raise Exception("Hugging Face returned an error response: \n\n", data)
 
         return data[0]["generated_text"]
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:
         response = await self.complete(messages[-1].content, messages[:-1])
-        yield {
-            "content": response,
-            "role": "assistant"
-        }
+        yield {"content": response, "role": "assistant"}
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Any | List | Dict, None, None]:
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Any | List | Dict, None, None]:
         response = await self.complete(prompt, with_history)
         yield response
diff --git a/continuedev/src/continuedev/libs/llm/hugging_face.py b/continuedev/src/continuedev/libs/llm/hugging_face.py
index b0db585b..f246a43c 100644
--- a/continuedev/src/continuedev/libs/llm/hugging_face.py
+++ b/continuedev/src/continuedev/libs/llm/hugging_face.py
@@ -1,5 +1,6 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 from .llm import LLM
-from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 class HuggingFace(LLM):
@@ -12,6 +13,5 @@ class HuggingFace(LLM):
         args = {"max_tokens": 100}
         args.update(kwargs)
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
-        generated_ids = self.model.generate(
-            input_ids, max_length=args["max_tokens"])
+        generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"])
         return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index fbc2c43f..65e5db3a 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -1,9 +1,9 @@
-from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
 
 from ...core.main import ChatMessage
 from . import LLM
-from .proxy_server import ProxyServer
 from .openai import OpenAI
+from .proxy_server import ProxyServer
 
 
 class MaybeProxyOpenAI(LLM):
@@ -24,7 +24,13 @@ class MaybeProxyOpenAI(LLM):
     def context_length(self):
         return self.llm.context_length
 
-    async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]):
+    async def start(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        unique_id: str,
+        write_log: Callable[[str], None]
+    ):
         if self.api_key is None or self.api_key.strip() == "":
             self.llm = ProxyServer(model=self.model)
         else:
@@ -35,16 +41,21 @@ class MaybeProxyOpenAI(LLM):
     async def stop(self):
         await self.llm.stop()
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         return await self.llm.complete(prompt, with_history=with_history, **kwargs)
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
-        resp = self.llm.stream_complete(
-            prompt, with_history=with_history, **kwargs)
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
+        resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
         async for item in resp:
             yield item
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         resp = self.llm.stream_chat(messages=messages, **kwargs)
         async for item in resp:
             yield item
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index a9f9f7aa..ef3cdc66 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -1,11 +1,11 @@
-from functools import cached_property
 import json
 from typing import Any, Coroutine, Dict, Generator, List, Union
 
 import aiohttp
+
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
 
 
 class Ollama(LLM):
@@ -62,20 +62,32 @@ class Ollama(LLM):
             if msgs[i]["role"] == "user":
                 prompt += f"[INST] {msgs[i]['content']} [/INST]"
             else:
-                prompt += msgs[i]['content']
+                prompt += msgs[i]["content"]
 
         return prompt
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=None,
+            system_message=self.system_message,
+        )
         prompt = self.convert_to_chat(messages)
 
-        async with self._client_session.post(f"{self.server_url}/api/generate", json={
-            "prompt": prompt,
-            "model": self.model,
-        }) as resp:
+        async with self._client_session.post(
+            f"{self.server_url}/api/generate",
+            json={
+                "prompt": prompt,
+                "model": self.model,
+            },
+        ) as resp:
             async for line in resp.content.iter_any():
                 if line:
                     try:
@@ -89,16 +101,28 @@ class Ollama(LLM):
                     except:
                         raise Exception(str(line[0]))
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, messages, self.context_length, args["max_tokens"], None, functions=None, system_message=self.system_message)
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=None,
+            system_message=self.system_message,
+        )
         prompt = self.convert_to_chat(messages)
 
-        async with self._client_session.post(f"{self.server_url}/api/generate", json={
-            "prompt": prompt,
-            "model": self.model,
-        }) as resp:
+        async with self._client_session.post(
+            f"{self.server_url}/api/generate",
+            json={
+                "prompt": prompt,
+                "model": self.model,
+            },
+        ) as resp:
             # This is streaming application/json instaed of text/event-stream
             async for line in resp.content.iter_chunks():
                 if line[1]:
@@ -111,18 +135,23 @@ class Ollama(LLM):
                         if "response" in j:
                             yield {
                                 "role": "assistant",
-                                "content": j["response"]
+                                "content": j["response"],
                             }
                     except:
                         raise Exception(str(line[0]))
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         completion = ""
-        async with self._client_session.post(f"{self.server_url}/api/generate", json={
-            "prompt": prompt,
-            "model": self.model,
-        }) as resp:
+        async with self._client_session.post(
+            f"{self.server_url}/api/generate",
+            json={
+                "prompt": prompt,
+                "model": self.model,
+            },
+        ) as resp:
             async for line in resp.content.iter_any():
                 if line:
                     try:
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 7eb516a3..c2d86841 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,12 +1,28 @@
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Union,
+)
 
-from pydantic import BaseModel
+import certifi
 import openai
+from pydantic import BaseModel
 
 from ...core.main import ChatMessage
-from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
 from ..llm import LLM
-import certifi
+from ..util.count_tokens import (
+    DEFAULT_ARGS,
+    compile_chat_messages,
+    count_tokens,
+    format_chat_messages,
+    prune_raw_prompt_from_top,
+)
 
 
 class OpenAIServerInfo(BaseModel):
@@ -16,9 +32,7 @@ class OpenAIServerInfo(BaseModel):
     api_type: Literal["azure", "openai"] = "openai"
 
 
-CHAT_MODELS = {
-    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
-}
+CHAT_MODELS = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"}
 MAX_TOKENS_FOR_MODEL = {
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-0613": 4096,
@@ -27,7 +41,7 @@ MAX_TOKENS_FOR_MODEL = {
     "gpt-35-turbo-16k": 16_384,
     "gpt-35-turbo-0613": 4096,
     "gpt-35-turbo": 4096,
-    "gpt-4-32k": 32_768
+    "gpt-4-32k": 32_768,
 }
 
 
@@ -43,7 +57,13 @@ class OpenAI(LLM):
     system_message: Optional[str] = None
     write_log: Optional[Callable[[str], None]] = None
 
-    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):
+    async def start(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        write_log: Callable[[str], None],
+        **kwargs,
+    ):
         self.write_log = write_log
         openai.api_key = self.api_key
 
@@ -54,7 +74,7 @@ class OpenAI(LLM):
         if self.openai_server_info.api_version is not None:
             openai.api_version = self.openai_server_info.api_version
 
-        if self.verify_ssl == False:
+        if self.verify_ssl is False:
             openai.verify_ssl_certs = False
 
         openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
@@ -80,14 +100,23 @@ class OpenAI(LLM):
     def count_tokens(self, text: str):
         return count_tokens(self.model, text)
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
 
         if args["model"] in CHAT_MODELS:
             messages = compile_chat_messages(
-                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+                args["model"],
+                with_history,
+                self.context_length,
+                args["max_tokens"],
+                prompt,
+                functions=None,
+                system_message=self.system_message,
+            )
             self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
             completion = ""
             async for chunk in await openai.ChatCompletion.acreate(
@@ -110,17 +139,28 @@ class OpenAI(LLM):
             self.write_log(f"Completion:\n\n{completion}")
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream"] = True
 
         # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
-        args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+        args["model"] = (
+            self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+        )
         if not args["model"].endswith("0613") and "functions" in args:
             del args["functions"]
 
         messages = compile_chat_messages(
-            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            args["model"],
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
         completion = ""
         async for chunk in await openai.ChatCompletion.acreate(
@@ -132,26 +172,48 @@ class OpenAI(LLM):
                 completion += chunk.choices[0].delta.content
         self.write_log(f"Completion: \n\n{completion}")
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
         if args["model"] in CHAT_MODELS:
             messages = compile_chat_messages(
-                args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+                args["model"],
+                with_history,
+                self.context_length,
+                args["max_tokens"],
+                prompt,
+                functions=None,
+                system_message=self.system_message,
+            )
             self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-            resp = (await openai.ChatCompletion.acreate(
-                messages=messages,
-                **args,
-            )).choices[0].message.content
+            resp = (
+                (
+                    await openai.ChatCompletion.acreate(
+                        messages=messages,
+                        **args,
+                    )
+                )
+                .choices[0]
+                .message.content
+            )
             self.write_log(f"Completion: \n\n{resp}")
         else:
             prompt = prune_raw_prompt_from_top(
-                args["model"], self.context_length, prompt, args["max_tokens"])
+                args["model"], self.context_length, prompt, args["max_tokens"]
+            )
             self.write_log(f"Prompt:\n\n{prompt}")
-            resp = (await openai.Completion.acreate(
-                prompt=prompt,
-                **args,
-            )).choices[0].text
+            resp = (
+                (
+                    await openai.Completion.acreate(
+                        prompt=prompt,
+                        **args,
+                    )
+                )
+                .choices[0]
+                .text
+            )
             self.write_log(f"Completion:\n\n{resp}")
 
         return resp
diff --git a/continuedev/src/continuedev/libs/llm/prompt_utils.py b/continuedev/src/continuedev/libs/llm/prompt_utils.py
index 51b64732..930b5220 100644
--- a/continuedev/src/continuedev/libs/llm/prompt_utils.py
+++ b/continuedev/src/continuedev/libs/llm/prompt_utils.py
@@ -1,4 +1,5 @@
 from typing import Dict, List, Union
+
 from ...models.filesystem import RangeInFileWithContents
 from ...models.filesystem_edit import FileEdit
 
@@ -11,23 +12,28 @@ class MarkdownStyleEncoderDecoder:
         self.range_in_files = range_in_files
 
     def encode(self) -> str:
-        return "\n\n".join([
-            f"File ({rif.filepath})\n```\n{rif.contents}\n```"
-            for rif in self.range_in_files
-        ])
+        return "\n\n".join(
+            [
+                f"File ({rif.filepath})\n```\n{rif.contents}\n```"
+                for rif in self.range_in_files
+            ]
+        )
 
     def _suggestions_to_file_edits(self, suggestions: Dict[str, str]) -> List[FileEdit]:
         file_edits: List[FileEdit] = []
         for suggestion_filepath, suggestion in suggestions.items():
             matching_rifs = list(
-                filter(lambda r: r.filepath == suggestion_filepath, self.range_in_files))
-            if (len(matching_rifs) > 0):
+                filter(lambda r: r.filepath == suggestion_filepath, self.range_in_files)
+            )
+            if len(matching_rifs) > 0:
                 range_in_file = matching_rifs[0]
-                file_edits.append(FileEdit(
-                    range=range_in_file.range,
-                    filepath=range_in_file.filepath,
-                    replacement=suggestion
-                ))
+                file_edits.append(
+                    FileEdit(
+                        range=range_in_file.range,
+                        filepath=range_in_file.filepath,
+                        replacement=suggestion,
+                    )
+                )
 
         return file_edits
 
@@ -35,9 +41,9 @@ class MarkdownStyleEncoderDecoder:
         if len(self.range_in_files) == 0:
             return {}
 
-        if not '```' in completion:
+        if "```" not in completion:
             completion = "```\n" + completion + "\n```"
-        if completion.strip().splitlines()[0].strip() == '```':
+        if completion.strip().splitlines()[0].strip() == "```":
             first_filepath = self.range_in_files[0].filepath
             completion = f"File ({first_filepath})\n" + completion
 
@@ -56,8 +62,7 @@ class MarkdownStyleEncoderDecoder:
             elif inside_file:
                 if line.startswith("```"):
                     inside_file = False
-                    suggestions[current_filepath] = "\n".join(
-                        current_file_lines)
+                    suggestions[current_filepath] = "\n".join(current_file_lines)
                     current_file_lines = []
                     current_filepath = None
                 else:
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index fd8210e9..acc6653d 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,13 +1,29 @@
 import json
+import ssl
 import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Union,
+)
+
 import aiohttp
+import certifi
+
 from ...core.main import ChatMessage
 from ..llm import LLM
+from ..util.count_tokens import (
+    DEFAULT_ARGS,
+    compile_chat_messages,
+    count_tokens,
+    format_chat_messages,
+)
 from ..util.telemetry import posthog_logger
-from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens, format_chat_messages
-import certifi
-import ssl
 
 ca_bundle_path = certifi.where()
 ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
@@ -37,9 +53,17 @@ class ProxyServer(LLM):
     class Config:
         arbitrary_types_allowed = True
 
-    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs):
+    async def start(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        write_log: Callable[[str], None],
+        unique_id: str,
+        **kwargs,
+    ):
         self._client_session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(ssl_context=ssl_context))
+            connector=aiohttp.TCPConnector(ssl_context=ssl_context)
+        )
         self.write_log = write_log
         self.unique_id = unique_id
 
@@ -65,16 +89,26 @@ class ProxyServer(LLM):
         # headers with unique id
         return {"unique_id": self.unique_id}
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+            args["model"],
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=None,
+            system_message=self.system_message,
+        )
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session.post(f"{SERVER_URL}/complete", json={
-            "messages": messages,
-            **args
-        }, headers=self.get_headers()) as resp:
+        async with self._client_session.post(
+            f"{SERVER_URL}/complete",
+            json={"messages": messages, **args},
+            headers=self.get_headers(),
+        ) as resp:
             if resp.status != 200:
                 raise Exception(await resp.text())
 
             response_text = await resp.text()
             self.write_log(f"Completion: \n\n{response_text}")
             return response_text
 
@@ -82,16 +116,26 @@ class ProxyServer(LLM):
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            args["model"],
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={
-            "messages": messages,
-            **args
-        }, headers=self.get_headers()) as resp:
+        async with self._client_session.post(
+            f"{SERVER_URL}/stream_chat",
+            json={"messages": messages, **args},
+            headers=self.get_headers(),
+        ) as resp:
             # This is streaming application/json instaed of text/event-stream
             completion = ""
             if resp.status != 200:
@@ -109,23 +153,40 @@ class ProxyServer(LLM):
                         if "content" in loaded_chunk:
                             completion += loaded_chunk["content"]
                     except Exception as e:
-                        posthog_logger.capture_event("proxy_server_parse_error", {
-                            "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+                        posthog_logger.capture_event(
+                            "proxy_server_parse_error",
+                            {
+                                "error_title": "Proxy server stream_chat parsing failed",
+                                "error_message": "\n".join(
+                                    traceback.format_exception(e)
+                                ),
+                            },
+                        )
                 else:
                     break
 
         self.write_log(f"Completion: \n\n{completion}")
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+            self.model,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={
-            "messages": messages,
-            **args
-        }, headers=self.get_headers()) as resp:
+        async with self._client_session.post(
+            f"{SERVER_URL}/stream_complete",
+            json={"messages": messages, **args},
+            headers=self.get_headers(),
+        ) as resp:
             completion = ""
             if resp.status != 200:
                 raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index 0dd359e7..c4373185 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -1,10 +1,10 @@
-from abc import abstractproperty
-from typing import List, Optional
-import replicate
 import concurrent.futures
+from typing import List
+
+import replicate
 
-from ..util.count_tokens import DEFAULT_ARGS, count_tokens
 from ...core.main import ChatMessage
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
 from . import LLM
 
 
@@ -36,10 +36,12 @@ class ReplicateLLM(LLM):
     async def stop(self):
         pass
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ):
         def helper():
             output = self._client.run(self.model, input={"message": prompt})
-            completion = ''
+            completion = ""
             for item in output:
                 completion += item
 
@@ -51,13 +53,14 @@ class ReplicateLLM(LLM):
 
         return completion
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs):
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ):
         for item in self._client.run(self.model, input={"message": prompt}):
             yield item
 
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs):
-        for item in self._client.run(self.model, input={"message": messages[-1].content}):
-            yield {
-                "content": item,
-                "role": "assistant"
-            }
+        for item in self._client.run(
+            self.model, input={"message": messages[-1].content}
+        ):
+            yield {"content": item, "role": "assistant"}
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 874dea07..44f5030c 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -2,9 +2,10 @@ import json
 from typing import Any, Coroutine, Dict, Generator, List, Union
 
 import aiohttp
+
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
 
 
 class TogetherLLM(LLM):
@@ -19,7 +20,8 @@ class TogetherLLM(LLM):
 
     async def start(self, **kwargs):
         self._client_session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+        )
 
     async def stop(self):
         await self._client_session.close()
@@ -53,21 +55,29 @@ class TogetherLLM(LLM):
             prompt += "<bot>:"
         return prompt
 
-    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_complete(
+        self, prompt, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
         args.update(kwargs)
         args["stream_tokens"] = True
 
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
-
-        async with self._client_session.post(f"{self.base_url}/inference", json={
-            "prompt": self.convert_to_prompt(messages),
-            **args
-        }, headers={
-            "Authorization": f"Bearer {self.api_key}"
-        }) as resp:
+            self.name,
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
+
+        async with self._client_session.post(
+            f"{self.base_url}/inference",
+            json={"prompt": self.convert_to_prompt(messages), **args},
+            headers={"Authorization": f"Bearer {self.api_key}"},
+        ) as resp:
             async for line in resp.content.iter_any():
                 if line:
                     try:
@@ -75,22 +85,32 @@ class TogetherLLM(LLM):
                     except:
                         raise Exception(str(line))
 
-    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+    async def stream_chat(
+        self, messages: List[ChatMessage] = None, **kwargs
+    ) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+            self.name,
+            messages,
+            self.context_length,
+            args["max_tokens"],
+            None,
+            functions=args.get("functions", None),
+            system_message=self.system_message,
+        )
         args["stream_tokens"] = True
 
-        async with self._client_session.post(f"{self.base_url}/inference", json={
-            "prompt": self.convert_to_prompt(messages),
-            **args
-        }, headers={
-            "Authorization": f"Bearer {self.api_key}"
-        }) as resp:
+        async with self._client_session.post(
+            f"{self.base_url}/inference",
+            json={"prompt": self.convert_to_prompt(messages), **args},
+            headers={"Authorization": f"Bearer {self.api_key}"},
+        ) as resp:
             async for line in resp.content.iter_chunks():
                 if line[1]:
                     json_chunk = line[0].decode("utf-8")
-                    if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+                    if json_chunk.startswith(": ping - ") or json_chunk.startswith(
+                        "data: [DONE]"
+                    ):
                         continue
                     if json_chunk.startswith("data: "):
                         json_chunk = json_chunk[6:]
@@ -101,20 +121,28 @@ class TogetherLLM(LLM):
                 if "choices" in json_chunk:
                     yield {
                         "role": "assistant",
-                        "content": json_chunk["choices"][0]["text"]
+                        "content": json_chunk["choices"][0]["text"],
                     }
 
-    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+    async def complete(
+        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+    ) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
-        messages = compile_chat_messages(args["model"], with_history, self.context_length,
-                                         args["max_tokens"], prompt, functions=None, system_message=self.system_message)
-        async with self._client_session.post(f"{self.base_url}/inference", json={
-            "prompt": self.convert_to_prompt(messages),
-            **args
-        }, headers={
-            "Authorization": f"Bearer {self.api_key}"
-        }) as resp:
+        messages = compile_chat_messages(
+            args["model"],
+            with_history,
+            self.context_length,
+            args["max_tokens"],
+            prompt,
+            functions=None,
+            system_message=self.system_message,
+        )
+        async with self._client_session.post(
+            f"{self.base_url}/inference",
+            json={"prompt": self.convert_to_prompt(messages), **args},
+            headers={"Authorization": f"Bearer {self.api_key}"},
+        ) as resp:
             try:
                 text = await resp.text()
                 j = json.loads(text)
diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py
index e8e48839..99301ae7 100644
--- a/continuedev/src/continuedev/libs/util/calculate_diff.py
+++ b/continuedev/src/continuedev/libs/util/calculate_diff.py
@@ -1,7 +1,8 @@
 import difflib
 from typing import List
-from ...models.main import Position, Range
+
 from ...models.filesystem import FileEdit
+from ...models.main import Position, Range
 
 
 def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]:
@@ -14,16 +15,25 @@ def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]
         if tag == "equal":
             pass
         elif tag == "delete":
-            edits.append(FileEdit.from_deletion(
-                filepath, Range.from_indices(original, i1, i2)))
+            edits.append(
+                FileEdit.from_deletion(filepath, Range.from_indices(original, i1, i2))
+            )
             offset -= i2 - i1
         elif tag == "insert":
-            edits.append(FileEdit.from_insertion(
-                filepath, Position.from_index(original, i1), replacement))
+            edits.append(
+                FileEdit.from_insertion(
+                    filepath, Position.from_index(original, i1), replacement
+                )
+            )
             offset += j2 - j1
         elif tag == "replace":
-            edits.append(FileEdit(filepath=filepath, range=Range.from_indices(
-                original, i1, i2), replacement=replacement))
+            edits.append(
+                FileEdit(
+                    filepath=filepath,
+                    range=Range.from_indices(original, i1, i2),
+                    replacement=replacement,
+                )
+            )
             offset += (j2 - j1) - (i2 - i1)
         else:
             raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag)
@@ -59,17 +69,27 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit]
             if tag == "equal":
                 continue  # ;)
             elif tag == "delete":
-                edits.append(FileEdit.from_deletion(
-                    filepath, Range.from_indices(original, i1, i2)))
+                edits.append(
+                    FileEdit.from_deletion(
+                        filepath, Range.from_indices(original, i1, i2)
+                    )
+                )
             elif tag == "insert":
-                edits.append(FileEdit.from_insertion(
-                    filepath, Position.from_index(original, i1), replacement))
+                edits.append(
+                    FileEdit.from_insertion(
+                        filepath, Position.from_index(original, i1), replacement
+                    )
+                )
             elif tag == "replace":
-                edits.append(FileEdit(filepath=filepath, range=Range.from_indices(
-                    original, i1, i2), replacement=replacement))
+                edits.append(
+                    FileEdit(
+                        filepath=filepath,
+                        range=Range.from_indices(original, i1, i2),
+                        replacement=replacement,
+                    )
+                )
             else:
-                raise Exception(
-                    "Unexpected difflib.SequenceMatcher tag: " + tag)
+                raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag)
             break
 
         original = apply_edit_to_str(original, edits[-1])
@@ -82,17 +102,17 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit]
 
 def read_range_in_str(s: str, r: Range) -> str:
-    lines = s.splitlines()[r.start.line:r.end.line + 1]
+    lines = s.splitlines()[r.start.line : r.end.line + 1]
     if len(lines) == 0:
         return ""
 
-    lines[0] = lines[0][r.start.character:]
-    lines[-1] = lines[-1][:r.end.character + 1]
+    lines[0] = lines[0][r.start.character :]
+    lines[-1] = lines[-1][: r.end.character + 1]
 
     return "\n".join(lines)
 
 
 def apply_edit_to_str(s: str, edit: FileEdit) -> str:
-    original = read_range_in_str(s, edit.range)
+    read_range_in_str(s, edit.range)
 
     # Split lines and deal with some edge cases (could obviously be nicer)
     lines = s.splitlines()
@@ -104,27 +124,30 @@ def apply_edit_to_str(s: str, edit: FileEdit) -> str:
     if len(lines) == 0:
         lines = [""]
 
-    end = Position(line=edit.range.end.line,
-                   character=edit.range.end.character)
+    end = Position(line=edit.range.end.line, character=edit.range.end.character)
     if edit.range.end.line == len(lines) and edit.range.end.character == 0:
-        end = Position(line=edit.range.end.line - 1,
-                       character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)]))
+        end = Position(
+            line=edit.range.end.line - 1,
+            character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)]),
) - before_lines = lines[:edit.range.start.line] - after_lines = lines[end.line + 1:] - between_str = lines[min(len(lines) - 1, edit.range.start.line)][:edit.range.start.character] + \ - edit.replacement + \ - lines[min(len(lines) - 1, end.line)][end.character + 1:] + before_lines = lines[: edit.range.start.line] + after_lines = lines[end.line + 1 :] + between_str = ( + lines[min(len(lines) - 1, edit.range.start.line)][: edit.range.start.character] + + edit.replacement + + lines[min(len(lines) - 1, end.line)][end.character + 1 :] + ) - new_range = Range( + Range( start=edit.range.start, end=Position( - line=edit.range.start.line + - len(edit.replacement.splitlines()) - 1, - character=edit.range.start.character + - len(edit.replacement.splitlines() - [-1]) if edit.replacement != "" else 0 - ) + line=edit.range.start.line + len(edit.replacement.splitlines()) - 1, + character=edit.range.start.character + + len(edit.replacement.splitlines()[-1]) + if edit.replacement != "" + else 0, + ), ) lines = before_lines + between_str.splitlines() + after_lines diff --git a/continuedev/src/continuedev/libs/util/commonregex.py b/continuedev/src/continuedev/libs/util/commonregex.py index 55da7fc0..3c4fb38c 100644 --- a/continuedev/src/continuedev/libs/util/commonregex.py +++ b/continuedev/src/continuedev/libs/util/commonregex.py @@ -1,37 +1,54 @@ # coding: utf-8 -import json import re -from typing import Any, Dict +from typing import Any date = re.compile( - '(?:(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?\s+(?:of\s+)?(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)|(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)\s+(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?)(?:\,)?\s*(?:\d{4})?|[0-3]?\d[-\./][0-3]?\d[-\./]\d{2,4}', re.IGNORECASE) -time = re.compile( - '\d{1,2}:\d{2} ?(?:[ap]\.?m\.?)?|\d[ap]\.?m\.?', re.IGNORECASE) + "(?:(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?\s+(?:of\s+)?(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)|(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)\s+(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?)(?:\,)?\s*(?:\d{4})?|[0-3]?\d[-\./][0-3]?\d[-\./]\d{2,4}", + re.IGNORECASE, +) +time = re.compile("\d{1,2}:\d{2} ?(?:[ap]\.?m\.?)?|\d[ap]\.?m\.?", re.IGNORECASE) phone = re.compile( - '''((?:(?<![\d-])(?:\+?\d{1,3}[-.\s*]?)?(?:\(?\d{3}\)?[-.\s*]?)?\d{3}[-.\s*]?\d{4}(?![\d-]))|(?:(?<![\d-])(?:(?:\(\+?\d{2}\))|(?:\+?\d{2}))\s*\d{2}\s*\d{3}\s*\d{4}(?![\d-])))''') + """((?:(?<![\d-])(?:\+?\d{1,3}[-.\s*]?)?(?:\(?\d{3}\)?[-.\s*]?)?\d{3}[-.\s*]?\d{4}(?![\d-]))|(?:(?<![\d-])(?:(?:\(\+?\d{2}\))|(?:\+?\d{2}))\s*\d{2}\s*\d{3}\s*\d{4}(?![\d-])))""" +) phones_with_exts = re.compile( - '((?:(?:\+?1\s*(?:[.-]\s*)?)?(?:\(\s*(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9])\s*\)|(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9]))\s*(?:[.-]\s*)?)?(?:[2-9]1[02-9]|[2-9][02-9]1|[2-9][02-9]{2})\s*(?:[.-]\s*)?(?:[0-9]{4})(?:\s*(?:#|x\.?|ext\.?|extension)\s*(?:\d+)?))', re.IGNORECASE) -link = 
re.compile('(?i)((?:https?://|www\d{0,3}[.])?[a-z0-9.\-]+[.](?:(?:international)|(?:construction)|(?:contractors)|(?:enterprises)|(?:photography)|(?:immobilien)|(?:management)|(?:technology)|(?:directory)|(?:education)|(?:equipment)|(?:institute)|(?:marketing)|(?:solutions)|(?:builders)|(?:clothing)|(?:computer)|(?:democrat)|(?:diamonds)|(?:graphics)|(?:holdings)|(?:lighting)|(?:plumbing)|(?:training)|(?:ventures)|(?:academy)|(?:careers)|(?:company)|(?:domains)|(?:florist)|(?:gallery)|(?:guitars)|(?:holiday)|(?:kitchen)|(?:recipes)|(?:shiksha)|(?:singles)|(?:support)|(?:systems)|(?:agency)|(?:berlin)|(?:camera)|(?:center)|(?:coffee)|(?:estate)|(?:kaufen)|(?:luxury)|(?:monash)|(?:museum)|(?:photos)|(?:repair)|(?:social)|(?:tattoo)|(?:travel)|(?:viajes)|(?:voyage)|(?:build)|(?:cheap)|(?:codes)|(?:dance)|(?:email)|(?:glass)|(?:house)|(?:ninja)|(?:photo)|(?:shoes)|(?:solar)|(?:today)|(?:aero)|(?:arpa)|(?:asia)|(?:bike)|(?:buzz)|(?:camp)|(?:club)|(?:coop)|(?:farm)|(?:gift)|(?:guru)|(?:info)|(?:jobs)|(?:kiwi)|(?:land)|(?:limo)|(?:link)|(?:menu)|(?:mobi)|(?:moda)|(?:name)|(?:pics)|(?:pink)|(?:post)|(?:rich)|(?:ruhr)|(?:sexy)|(?:tips)|(?:wang)|(?:wien)|(?:zone)|(?:biz)|(?:cab)|(?:cat)|(?:ceo)|(?:com)|(?:edu)|(?:gov)|(?:int)|(?:mil)|(?:net)|(?:onl)|(?:org)|(?:pro)|(?:red)|(?:tel)|(?:uno)|(?:xxx)|(?:ac)|(?:ad)|(?:ae)|(?:af)|(?:ag)|(?:ai)|(?:al)|(?:am)|(?:an)|(?:ao)|(?:aq)|(?:ar)|(?:as)|(?:at)|(?:au)|(?:aw)|(?:ax)|(?:az)|(?:ba)|(?:bb)|(?:bd)|(?:be)|(?:bf)|(?:bg)|(?:bh)|(?:bi)|(?:bj)|(?:bm)|(?:bn)|(?:bo)|(?:br)|(?:bs)|(?:bt)|(?:bv)|(?:bw)|(?:by)|(?:bz)|(?:ca)|(?:cc)|(?:cd)|(?:cf)|(?:cg)|(?:ch)|(?:ci)|(?:ck)|(?:cl)|(?:cm)|(?:cn)|(?:co)|(?:cr)|(?:cu)|(?:cv)|(?:cw)|(?:cx)|(?:cy)|(?:cz)|(?:de)|(?:dj)|(?:dk)|(?:dm)|(?:do)|(?:dz)|(?:ec)|(?:ee)|(?:eg)|(?:er)|(?:es)|(?:et)|(?:eu)|(?:fi)|(?:fj)|(?:fk)|(?:fm)|(?:fo)|(?:fr)|(?:ga)|(?:gb)|(?:gd)|(?:ge)|(?:gf)|(?:gg)|(?:gh)|(?:gi)|(?:gl)|(?:gm)|(?:gn)|(?:gp)|(?:gq)|(?:gr)|(?:gs)|(?:gt)|(?:gu)|(?:gw)|(?:gy)|(?:hk)|(?:hm)|(?:hn)|(?:hr)|(?:ht)|(?:hu)|(?:id)|(?:ie)|(?:il)|(?:im)|(?:in)|(?:io)|(?:iq)|(?:ir)|(?:is)|(?:it)|(?:je)|(?:jm)|(?:jo)|(?:jp)|(?:ke)|(?:kg)|(?:kh)|(?:ki)|(?:km)|(?:kn)|(?:kp)|(?:kr)|(?:kw)|(?:ky)|(?:kz)|(?:la)|(?:lb)|(?:lc)|(?:li)|(?:lk)|(?:lr)|(?:ls)|(?:lt)|(?:lu)|(?:lv)|(?:ly)|(?:ma)|(?:mc)|(?:md)|(?:me)|(?:mg)|(?:mh)|(?:mk)|(?:ml)|(?:mm)|(?:mn)|(?:mo)|(?:mp)|(?:mq)|(?:mr)|(?:ms)|(?:mt)|(?:mu)|(?:mv)|(?:mw)|(?:mx)|(?:my)|(?:mz)|(?:na)|(?:nc)|(?:ne)|(?:nf)|(?:ng)|(?:ni)|(?:nl)|(?:no)|(?:np)|(?:nr)|(?:nu)|(?:nz)|(?:om)|(?:pa)|(?:pe)|(?:pf)|(?:pg)|(?:ph)|(?:pk)|(?:pl)|(?:pm)|(?:pn)|(?:pr)|(?:ps)|(?:pt)|(?:pw)|(?:py)|(?:qa)|(?:re)|(?:ro)|(?:rs)|(?:ru)|(?:rw)|(?:sa)|(?:sb)|(?:sc)|(?:sd)|(?:se)|(?:sg)|(?:sh)|(?:si)|(?:sj)|(?:sk)|(?:sl)|(?:sm)|(?:sn)|(?:so)|(?:sr)|(?:st)|(?:su)|(?:sv)|(?:sx)|(?:sy)|(?:sz)|(?:tc)|(?:td)|(?:tf)|(?:tg)|(?:th)|(?:tj)|(?:tk)|(?:tl)|(?:tm)|(?:tn)|(?:to)|(?:tp)|(?:tr)|(?:tt)|(?:tv)|(?:tw)|(?:tz)|(?:ua)|(?:ug)|(?:uk)|(?:us)|(?:uy)|(?:uz)|(?:va)|(?:vc)|(?:ve)|(?:vg)|(?:vi)|(?:vn)|(?:vu)|(?:wf)|(?:ws)|(?:ye)|(?:yt)|(?:za)|(?:zm)|(?:zw))(?:/[^\s()<>]+[^\s`!()\[\]{};:\'".,<>?\xab\xbb\u201c\u201d\u2018\u2019])?)', re.IGNORECASE) + "((?:(?:\+?1\s*(?:[.-]\s*)?)?(?:\(\s*(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9])\s*\)|(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9]))\s*(?:[.-]\s*)?)?(?:[2-9]1[02-9]|[2-9][02-9]1|[2-9][02-9]{2})\s*(?:[.-]\s*)?(?:[0-9]{4})(?:\s*(?:#|x\.?|ext\.?|extension)\s*(?:\d+)?))", + re.IGNORECASE, +) +link = re.compile( + 
"(?i)((?:https?://|www\d{0,3}[.])?[a-z0-9.\-]+[.](?:(?:international)|(?:construction)|(?:contractors)|(?:enterprises)|(?:photography)|(?:immobilien)|(?:management)|(?:technology)|(?:directory)|(?:education)|(?:equipment)|(?:institute)|(?:marketing)|(?:solutions)|(?:builders)|(?:clothing)|(?:computer)|(?:democrat)|(?:diamonds)|(?:graphics)|(?:holdings)|(?:lighting)|(?:plumbing)|(?:training)|(?:ventures)|(?:academy)|(?:careers)|(?:company)|(?:domains)|(?:florist)|(?:gallery)|(?:guitars)|(?:holiday)|(?:kitchen)|(?:recipes)|(?:shiksha)|(?:singles)|(?:support)|(?:systems)|(?:agency)|(?:berlin)|(?:camera)|(?:center)|(?:coffee)|(?:estate)|(?:kaufen)|(?:luxury)|(?:monash)|(?:museum)|(?:photos)|(?:repair)|(?:social)|(?:tattoo)|(?:travel)|(?:viajes)|(?:voyage)|(?:build)|(?:cheap)|(?:codes)|(?:dance)|(?:email)|(?:glass)|(?:house)|(?:ninja)|(?:photo)|(?:shoes)|(?:solar)|(?:today)|(?:aero)|(?:arpa)|(?:asia)|(?:bike)|(?:buzz)|(?:camp)|(?:club)|(?:coop)|(?:farm)|(?:gift)|(?:guru)|(?:info)|(?:jobs)|(?:kiwi)|(?:land)|(?:limo)|(?:link)|(?:menu)|(?:mobi)|(?:moda)|(?:name)|(?:pics)|(?:pink)|(?:post)|(?:rich)|(?:ruhr)|(?:sexy)|(?:tips)|(?:wang)|(?:wien)|(?:zone)|(?:biz)|(?:cab)|(?:cat)|(?:ceo)|(?:com)|(?:edu)|(?:gov)|(?:int)|(?:mil)|(?:net)|(?:onl)|(?:org)|(?:pro)|(?:red)|(?:tel)|(?:uno)|(?:xxx)|(?:ac)|(?:ad)|(?:ae)|(?:af)|(?:ag)|(?:ai)|(?:al)|(?:am)|(?:an)|(?:ao)|(?:aq)|(?:ar)|(?:as)|(?:at)|(?:au)|(?:aw)|(?:ax)|(?:az)|(?:ba)|(?:bb)|(?:bd)|(?:be)|(?:bf)|(?:bg)|(?:bh)|(?:bi)|(?:bj)|(?:bm)|(?:bn)|(?:bo)|(?:br)|(?:bs)|(?:bt)|(?:bv)|(?:bw)|(?:by)|(?:bz)|(?:ca)|(?:cc)|(?:cd)|(?:cf)|(?:cg)|(?:ch)|(?:ci)|(?:ck)|(?:cl)|(?:cm)|(?:cn)|(?:co)|(?:cr)|(?:cu)|(?:cv)|(?:cw)|(?:cx)|(?:cy)|(?:cz)|(?:de)|(?:dj)|(?:dk)|(?:dm)|(?:do)|(?:dz)|(?:ec)|(?:ee)|(?:eg)|(?:er)|(?:es)|(?:et)|(?:eu)|(?:fi)|(?:fj)|(?:fk)|(?:fm)|(?:fo)|(?:fr)|(?:ga)|(?:gb)|(?:gd)|(?:ge)|(?:gf)|(?:gg)|(?:gh)|(?:gi)|(?:gl)|(?:gm)|(?:gn)|(?:gp)|(?:gq)|(?:gr)|(?:gs)|(?:gt)|(?:gu)|(?:gw)|(?:gy)|(?:hk)|(?:hm)|(?:hn)|(?:hr)|(?:ht)|(?:hu)|(?:id)|(?:ie)|(?:il)|(?:im)|(?:in)|(?:io)|(?:iq)|(?:ir)|(?:is)|(?:it)|(?:je)|(?:jm)|(?:jo)|(?:jp)|(?:ke)|(?:kg)|(?:kh)|(?:ki)|(?:km)|(?:kn)|(?:kp)|(?:kr)|(?:kw)|(?:ky)|(?:kz)|(?:la)|(?:lb)|(?:lc)|(?:li)|(?:lk)|(?:lr)|(?:ls)|(?:lt)|(?:lu)|(?:lv)|(?:ly)|(?:ma)|(?:mc)|(?:md)|(?:me)|(?:mg)|(?:mh)|(?:mk)|(?:ml)|(?:mm)|(?:mn)|(?:mo)|(?:mp)|(?:mq)|(?:mr)|(?:ms)|(?:mt)|(?:mu)|(?:mv)|(?:mw)|(?:mx)|(?:my)|(?:mz)|(?:na)|(?:nc)|(?:ne)|(?:nf)|(?:ng)|(?:ni)|(?:nl)|(?:no)|(?:np)|(?:nr)|(?:nu)|(?:nz)|(?:om)|(?:pa)|(?:pe)|(?:pf)|(?:pg)|(?:ph)|(?:pk)|(?:pl)|(?:pm)|(?:pn)|(?:pr)|(?:ps)|(?:pt)|(?:pw)|(?:py)|(?:qa)|(?:re)|(?:ro)|(?:rs)|(?:ru)|(?:rw)|(?:sa)|(?:sb)|(?:sc)|(?:sd)|(?:se)|(?:sg)|(?:sh)|(?:si)|(?:sj)|(?:sk)|(?:sl)|(?:sm)|(?:sn)|(?:so)|(?:sr)|(?:st)|(?:su)|(?:sv)|(?:sx)|(?:sy)|(?:sz)|(?:tc)|(?:td)|(?:tf)|(?:tg)|(?:th)|(?:tj)|(?:tk)|(?:tl)|(?:tm)|(?:tn)|(?:to)|(?:tp)|(?:tr)|(?:tt)|(?:tv)|(?:tw)|(?:tz)|(?:ua)|(?:ug)|(?:uk)|(?:us)|(?:uy)|(?:uz)|(?:va)|(?:vc)|(?:ve)|(?:vg)|(?:vi)|(?:vn)|(?:vu)|(?:wf)|(?:ws)|(?:ye)|(?:yt)|(?:za)|(?:zm)|(?:zw))(?:/[^\s()<>]+[^\s`!()\[\]{};:'\".,<>?\xab\xbb\u201c\u201d\u2018\u2019])?)", + re.IGNORECASE, +) email = re.compile( - "([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)", re.IGNORECASE) -ip = re.compile('(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)', re.IGNORECASE) + 
"([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)", + re.IGNORECASE, +) +ip = re.compile( + "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)", + re.IGNORECASE, +) ipv6 = re.compile( - '\s*(?!.*::.*::)(?:(?!:)|:(?=:))(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)){6}(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)[0-9a-f]{0,4}(?:(?<=::)|(?<!:)|(?<=:)(?<!::):)|(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)){3})\s*', re.VERBOSE | re.IGNORECASE | re.DOTALL) -price = re.compile('[$]\s?[+-]?[0-9]{1,3}(?:(?:,?[0-9]{3}))*(?:\.[0-9]{1,2})?') -hex_color = re.compile('(#(?:[0-9a-fA-F]{8})|#(?:[0-9a-fA-F]{3}){1,2})\\b') -credit_card = re.compile('((?:(?:\\d{4}[- ]?){3}\\d{4}|\\d{15,16}))(?![\\d])') + "\s*(?!.*::.*::)(?:(?!:)|:(?=:))(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)){6}(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)[0-9a-f]{0,4}(?:(?<=::)|(?<!:)|(?<=:)(?<!::):)|(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)){3})\s*", + re.VERBOSE | re.IGNORECASE | re.DOTALL, +) +price = re.compile("[$]\s?[+-]?[0-9]{1,3}(?:(?:,?[0-9]{3}))*(?:\.[0-9]{1,2})?") +hex_color = re.compile("(#(?:[0-9a-fA-F]{8})|#(?:[0-9a-fA-F]{3}){1,2})\\b") +credit_card = re.compile("((?:(?:\\d{4}[- ]?){3}\\d{4}|\\d{15,16}))(?![\\d])") btc_address = re.compile( - '(?<![a-km-zA-HJ-NP-Z0-9])[13][a-km-zA-HJ-NP-Z0-9]{26,33}(?![a-km-zA-HJ-NP-Z0-9])') + "(?<![a-km-zA-HJ-NP-Z0-9])[13][a-km-zA-HJ-NP-Z0-9]{26,33}(?![a-km-zA-HJ-NP-Z0-9])" +) street_address = re.compile( - '\d{1,4} [\w\s]{1,20}(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\W?(?=\s|$)', re.IGNORECASE) -zip_code = re.compile(r'\b\d{5}(?:[-\s]\d{4})?\b') -po_box = re.compile(r'P\.? ?O\.? Box \d+', re.IGNORECASE) + "\d{1,4} [\w\s]{1,20}(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\W?(?=\s|$)", + re.IGNORECASE, +) +zip_code = re.compile(r"\b\d{5}(?:[-\s]\d{4})?\b") +po_box = re.compile(r"P\.? ?O\.? 
Box \d+", re.IGNORECASE) ssn = re.compile( - '(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}') + "(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}" +) win_absolute_filepath = re.compile( - r'^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+', re.IGNORECASE) -unix_absolute_filepath = re.compile( - r'^\/(?:[\/\w]+\/)*\w([\w.])+', re.IGNORECASE) + r"^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+", re.IGNORECASE +) +unix_absolute_filepath = re.compile(r"^\/(?:[\/\w]+\/)*\w([\w.])+", re.IGNORECASE) regexes = { "win_absolute_filepath": win_absolute_filepath, @@ -77,7 +94,6 @@ placeholders = { class regex: - def __init__(self, obj, regex): self.obj = obj self.regex = regex @@ -85,11 +101,11 @@ class regex: def __call__(self, *args): def regex_method(text=None): return [x.strip() for x in self.regex.findall(text or self.obj.text)] + return regex_method class CommonRegex(object): - def __init__(self, text=""): self.text = text diff --git a/continuedev/src/continuedev/libs/util/copy_codebase.py b/continuedev/src/continuedev/libs/util/copy_codebase.py index 97143faf..aafb435c 100644 --- a/continuedev/src/continuedev/libs/util/copy_codebase.py +++ b/continuedev/src/continuedev/libs/util/copy_codebase.py @@ -1,14 +1,24 @@ import os +import shutil from pathlib import Path from typing import Iterable, List, Union -from watchdog.observers import Observer + from watchdog.events import PatternMatchingEventHandler -from ...models.main import FileEdit, DeleteDirectory, DeleteFile, AddDirectory, AddFile, FileSystemEdit, RenameFile, RenameDirectory, SequentialFileSystemEdit -from ...models.filesystem import FileSystem +from watchdog.observers import Observer + from ...core.autopilot import Autopilot +from ...models.filesystem import FileSystem +from ...models.main import ( + AddDirectory, + AddFile, + DeleteDirectory, + DeleteFile, + FileSystemEdit, + RenameDirectory, + RenameFile, + SequentialFileSystemEdit, +) from .map_path import map_path -from ...core.sdk import ManualEditStep -import shutil def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = []): @@ -24,8 +34,7 @@ def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = [ if os.path.isdir(child): if child not in ignore: os.mkdir(map_path(child)) - create_copy(Path(orig_root) / child, - Path(copy_root) / child, ignore) + create_copy(Path(orig_root) / child, Path(copy_root) / child, ignore) else: os.symlink(child, map_path(child)) else: @@ -37,8 +46,18 @@ def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = [ # The whole usage of watchdog here should only be specific to RealFileSystem, you want to have a different "Observer" class for VirtualFileSystem, which would depend on being sent notifications class CopyCodebaseEventHandler(PatternMatchingEventHandler): - def __init__(self, ignore_directories: List[str], ignore_patterns: List[str], autopilot: Autopilot, orig_root: str, copy_root: str, filesystem: FileSystem): - super().__init__(ignore_directories=ignore_directories, ignore_patterns=ignore_patterns) + def __init__( + self, + ignore_directories: List[str], + ignore_patterns: List[str], + autopilot: Autopilot, + orig_root: str, + copy_root: str, + filesystem: FileSystem, + ): + super().__init__( + ignore_directories=ignore_directories, ignore_patterns=ignore_patterns + ) self.autopilot = autopilot self.orig_root = orig_root self.copy_root = 
copy_root @@ -85,10 +104,13 @@ class CopyCodebaseEventHandler(PatternMatchingEventHandler): self.autopilot.act(action) -def maintain_copy_workspace(autopilot: Autopilot, filesystem: FileSystem, orig_root: str, copy_root: str): +def maintain_copy_workspace( + autopilot: Autopilot, filesystem: FileSystem, orig_root: str, copy_root: str +): observer = Observer() event_handler = CopyCodebaseEventHandler( - [".git"], [], autopilot, orig_root, copy_root, filesystem) + [".git"], [], autopilot, orig_root, copy_root, filesystem + ) observer.schedule(event_handler, orig_root, recursive=True) observer.start() try: diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 3b594036..a339c288 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -1,10 +1,10 @@ import json from typing import Dict, List, Union + +import tiktoken + from ...core.main import ChatMessage from .templating import render_templated_string -from ...libs.llm import LLM -from tiktoken_ext import openai_public -import tiktoken # TODO move many of these into specific LLM.properties() function that # contains max tokens, if its a chat model or not, default args (not all models @@ -15,8 +15,13 @@ aliases = { "claude-2": "gpt-3.5-turbo", } DEFAULT_MAX_TOKENS = 2048 -DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1, - "frequency_penalty": 0, "presence_penalty": 0} +DEFAULT_ARGS = { + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.5, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, +} def encoding_for_model(model_name: str): @@ -41,7 +46,9 @@ def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE -def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int): +def prune_raw_prompt_from_top( + model_name: str, context_length: int, prompt: str, tokens_for_completion: int +): max_tokens = context_length - tokens_for_completion encoding = encoding_for_model(model_name) tokens = encoding.encode(prompt, disallowed_special=()) @@ -51,10 +58,15 @@ def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, return encoding.decode(tokens[-max_tokens:]) -def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int): - total_tokens = tokens_for_completion + \ - sum(count_chat_message_tokens(model_name, message) - for message in chat_history) +def prune_chat_history( + model_name: str, + chat_history: List[ChatMessage], + context_length: int, + tokens_for_completion: int, +): + total_tokens = tokens_for_completion + sum( + count_chat_message_tokens(model_name, message) for message in chat_history + ) # 1. Replace beyond last 5 messages with summary i = 0 @@ -66,13 +78,21 @@ def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context i += 1 # 2. Remove entire messages until the last 5 - while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0: + while ( + len(chat_history) > 5 + and total_tokens > context_length + and len(chat_history) > 0 + ): message = chat_history.pop(0) total_tokens -= count_tokens(model_name, message.content) # 3. 
Truncate message in the last 5, except last 1 i = 0 - while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1: + while ( + total_tokens > context_length + and len(chat_history) > 0 + and i < len(chat_history) - 1 + ): message = chat_history[i] total_tokens -= count_tokens(model_name, message.content) total_tokens += count_tokens(model_name, message.summary) @@ -88,7 +108,8 @@ def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context if total_tokens > context_length and len(chat_history) > 0: message = chat_history[0] message.content = prune_raw_prompt_from_top( - model_name, context_length, message.content, tokens_for_completion) + model_name, context_length, message.content, tokens_for_completion + ) total_tokens = context_length return chat_history @@ -98,12 +119,19 @@ def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context TOKEN_BUFFER_FOR_SAFETY = 100 -def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages( + model_name: str, + msgs: Union[List[ChatMessage], None], + context_length: int, + max_tokens: int, + prompt: Union[str, None] = None, + functions: Union[List, None] = None, + system_message: Union[str, None] = None, +) -> List[Dict]: """ The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message """ - msgs_copy = [msg.copy(deep=True) - for msg in msgs] if msgs is not None else [] + msgs_copy = [msg.copy(deep=True) for msg in msgs] if msgs is not None else [] if prompt is not None: prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt) @@ -114,7 +142,10 @@ def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], # but move back to start after processing rendered_system_message = render_templated_string(system_message) system_chat_msg = ChatMessage( - role="system", content=rendered_system_message, summary=rendered_system_message) + role="system", + content=rendered_system_message, + summary=rendered_system_message, + ) # insert at second-to-last position msgs_copy.insert(-1, system_chat_msg) @@ -125,13 +156,20 @@ def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], function_tokens += count_tokens(model_name, json.dumps(function)) msgs_copy = prune_chat_history( - model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) + model_name, + msgs_copy, + context_length, + function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY, + ) - history = [msg.to_dict(with_functions=functions is not None) - for msg in msgs_copy] + history = [msg.to_dict(with_functions=functions is not None) for msg in msgs_copy] # Move system message back to start - if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system": + if ( + system_message is not None + and len(history) >= 2 + and history[-2]["role"] == "system" + ): system_message_dict = history.pop(-2) history.insert(0, system_message_dict) diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 4c6d3c95..232d3fa1 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -1,13 +1,18 @@ -from typing import Callable, 
Coroutine, Optional, Union -import traceback -from .telemetry import posthog_logger -from .logging import logger import asyncio +import traceback +from typing import Callable, Coroutine, Optional + import nest_asyncio + +from .logging import logger +from .telemetry import posthog_logger + nest_asyncio.apply() -def create_async_task(coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None): +def create_async_task( + coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None +): """asyncio.create_task and log errors by adding a callback""" task = asyncio.create_task(coro) @@ -15,12 +20,15 @@ def create_async_task(coro: Coroutine, on_error: Optional[Callable[[Exception], try: future.result() except Exception as e: - formatted_tb = '\n'.join(traceback.format_exception(e)) - logger.critical( - f"Exception caught from async task: {formatted_tb}") - posthog_logger.capture_event("async_task_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) - }) + formatted_tb = "\n".join(traceback.format_exception(e)) + logger.critical(f"Exception caught from async task: {formatted_tb}") + posthog_logger.capture_event( + "async_task_error", + { + "error_title": e.__str__() or e.__repr__(), + "error_message": "\n".join(traceback.format_exception(e)), + }, + ) # Log the error to the GUI if on_error is not None: diff --git a/continuedev/src/continuedev/libs/util/logging.py b/continuedev/src/continuedev/libs/util/logging.py index 668d313f..4a550168 100644 --- a/continuedev/src/continuedev/libs/util/logging.py +++ b/continuedev/src/continuedev/libs/util/logging.py @@ -15,8 +15,7 @@ console_handler = logging.StreamHandler() console_handler.setLevel(logging.DEBUG) # Create a formatter -formatter = logging.Formatter( - '[%(asctime)s] [%(levelname)s] %(message)s') +formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s") # Add the formatter to the handlers file_handler.setFormatter(formatter) @@ -27,4 +26,4 @@ logger.addHandler(file_handler) logger.addHandler(console_handler) # Log a test message -logger.debug('Testing logs') +logger.debug("Testing logs") diff --git a/continuedev/src/continuedev/libs/util/map_path.py b/continuedev/src/continuedev/libs/util/map_path.py index 8eb57712..1dddc2e9 100644 --- a/continuedev/src/continuedev/libs/util/map_path.py +++ b/continuedev/src/continuedev/libs/util/map_path.py @@ -1,16 +1,16 @@ from pathlib import Path -def map_path(path: str, orig_root: str, copy_root: str) -> Path: - path = Path(path) - if path.is_relative_to(orig_root): - if path.is_absolute(): - return Path(copy_root) / path.relative_to(orig_root) - else: - return path - else: - if path.is_absolute(): - return path - else: - # For this one, you need to know the directory from which the relative path is being used. - return Path(orig_root) / path +def map_path(path: str, orig_root: str, copy_root: str) -> Path: + path = Path(path) + if path.is_relative_to(orig_root): + if path.is_absolute(): + return Path(copy_root) / path.relative_to(orig_root) + else: + return path + else: + if path.is_absolute(): + return path + else: + # For this one, you need to know the directory from which the relative path is being used. 
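A minimal sketch of the map_path behavior rewritten just above, assuming map_path is importable from libs/util/map_path.py (import path assumed) and using purely hypothetical roots: an absolute path under orig_root is re-based onto copy_root, any other absolute path passes through unchanged, and a relative path is resolved against orig_root.

from pathlib import Path

from continuedev.src.continuedev.libs.util.map_path import map_path  # assumed import path

orig_root = "/workspace/project"   # hypothetical original tree
copy_root = "/tmp/project-copy"    # hypothetical mirror tree

# Absolute path inside the original tree -> re-based onto the copy
assert map_path("/workspace/project/src/app.py", orig_root, copy_root) == Path(
    "/tmp/project-copy/src/app.py"
)

# Absolute path outside the original tree -> returned unchanged
assert map_path("/etc/hosts", orig_root, copy_root) == Path("/etc/hosts")

# Relative path -> resolved against orig_root (see the comment in the function above)
assert map_path("src/app.py", orig_root, copy_root) == Path(
    "/workspace/project/src/app.py"
)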
+ return Path(orig_root) / path diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index 01b594cf..483f6b63 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -1,6 +1,11 @@ import os -from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER + from ..constants.default_config import default_config +from ..constants.main import ( + CONTINUE_GLOBAL_FOLDER, + CONTINUE_SERVER_FOLDER, + CONTINUE_SESSIONS_FOLDER, +) def find_data_file(filename): @@ -36,7 +41,7 @@ def getSessionsListFilePath(): path = os.path.join(getSessionsFolderPath(), "sessions.json") os.makedirs(os.path.dirname(path), exist_ok=True) if not os.path.exists(path): - with open(path, 'w') as f: + with open(path, "w") as f: f.write("[]") return path @@ -46,19 +51,22 @@ def getConfigFilePath() -> str: os.makedirs(os.path.dirname(path), exist_ok=True) if not os.path.exists(path): - with open(path, 'w') as f: + with open(path, "w") as f: f.write(default_config) else: - with open(path, 'r') as f: + with open(path, "r") as f: existing_content = f.read() if existing_content.strip() == "": - with open(path, 'w') as f: + with open(path, "w") as f: f.write(default_config) elif " continuedev.core" in existing_content: - with open(path, 'w') as f: - f.write(existing_content.replace(" continuedev.", - " continuedev.src.continuedev.")) + with open(path, "w") as f: + f.write( + existing_content.replace( + " continuedev.", " continuedev.src.continuedev." + ) + ) return path diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py index ed1e79b7..0cca261f 100644 --- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -1,20 +1,22 @@ from typing import Dict from ...core.main import Step -from ...plugins.steps.core.core import UserInputStep -from ...plugins.steps.main import EditHighlightedCodeStep -from ...plugins.steps.chat import SimpleChatStep -from ...plugins.steps.comment_code import CommentCodeStep -from ...plugins.steps.feedback import FeedbackStep +from ...libs.util.logging import logger from ...plugins.recipes.AddTransformRecipe.main import AddTransformRecipe from ...plugins.recipes.CreatePipelineRecipe.main import CreatePipelineRecipe from ...plugins.recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ...plugins.recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe -from ...plugins.steps.on_traceback import DefaultOnTracebackStep +from ...plugins.recipes.DeployPipelineAirflowRecipe.main import ( + DeployPipelineAirflowRecipe, +) +from ...plugins.steps.chat import SimpleChatStep from ...plugins.steps.clear_history import ClearHistoryStep -from ...plugins.steps.open_config import OpenConfigStep +from ...plugins.steps.comment_code import CommentCodeStep +from ...plugins.steps.core.core import UserInputStep +from ...plugins.steps.feedback import FeedbackStep from ...plugins.steps.help import HelpStep -from ...libs.util.logging import logger +from ...plugins.steps.main import EditHighlightedCodeStep +from ...plugins.steps.on_traceback import DefaultOnTracebackStep +from ...plugins.steps.open_config import OpenConfigStep # This mapping is used to convert from string in ContinueConfig json to corresponding Step class. 
# Used for example in slash_commands and steps_on_startup @@ -40,5 +42,6 @@ def get_step_from_name(step_name: str, params: Dict) -> Step: return step_name_to_step_class[step_name](**params) except: logger.error( - f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") + f"Incorrect parameters for step {step_name}. Parameters provided were: {params}" + ) raise diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index 285c1e47..d33c46c4 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -43,7 +43,9 @@ def remove_quotes_and_escapes(output: str) -> str: output = output.replace("\\n", "\n") output = output.replace("\\t", "\t") output = output.replace("\\\\", "\\") - if (output.startswith('"') and output.endswith('"')) or (output.startswith("'") and output.endswith("'")): + if (output.startswith('"') and output.endswith('"')) or ( + output.startswith("'") and output.endswith("'") + ): output = output[1:-1] return output diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py index 2c76f4a3..ab5ec328 100644 --- a/continuedev/src/continuedev/libs/util/telemetry.py +++ b/continuedev/src/continuedev/libs/util/telemetry.py @@ -1,15 +1,16 @@ -from typing import Any -from posthog import Posthog import os +from typing import Any + from dotenv import load_dotenv +from posthog import Posthog + +from ..constants.main import CONTINUE_SERVER_VERSION_FILE from .commonregex import clean_pii_from_any -from .logging import logger from .paths import getServerFolderPath -from ..constants.main import CONTINUE_SERVER_VERSION_FILE load_dotenv() in_codespaces = os.getenv("CODESPACES") == "true" -POSTHOG_API_KEY = 'phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs' +POSTHOG_API_KEY = "phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs" class PostHogLogger: @@ -20,16 +21,14 @@ class PostHogLogger: self.api_key = api_key # The personal API key is necessary only if you want to use local evaluation of feature flags. 
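Going by the strings.py hunk above alone (the top of remove_quotes_and_escapes falls outside this diff), the function first turns escaped sequences such as \\n and \\t back into real characters and then strips one layer of matching outer quotes; a hypothetical before/after:

raw = '"line one\\nline two"'        # quoted model output containing a literal \n escape
out = remove_quotes_and_escapes(raw)
# "\\n" becomes a real newline first, then the matching outer double quotes are dropped:
# out == 'line one\nline two'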
- self.posthog = Posthog(self.api_key, host='https://app.posthog.com') + self.posthog = Posthog(self.api_key, host="https://app.posthog.com") def setup(self, unique_id: str, allow_anonymous_telemetry: bool): self.unique_id = unique_id or "NO_UNIQUE_ID" self.allow_anonymous_telemetry = allow_anonymous_telemetry or True # Capture initial event - self.capture_event("session_start", { - "os": os.name - }) + self.capture_event("session_start", {"os": os.name}) def capture_event(self, event_name: str, event_properties: Any): # logger.debug( @@ -53,13 +52,14 @@ class PostHogLogger: # Add additional properties that are on every event if in_codespaces: - event_properties['codespaces'] = True + event_properties["codespaces"] = True server_version_file = os.path.join( - getServerFolderPath(), CONTINUE_SERVER_VERSION_FILE) + getServerFolderPath(), CONTINUE_SERVER_VERSION_FILE + ) if os.path.exists(server_version_file): with open(server_version_file, "r") as f: - event_properties['server_version'] = f.read() + event_properties["server_version"] = f.read() # Send event to PostHog self.posthog.capture(self.unique_id, event_name, event_properties) diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py index bb922ad7..edcf2884 100644 --- a/continuedev/src/continuedev/libs/util/templating.py +++ b/continuedev/src/continuedev/libs/util/templating.py @@ -1,4 +1,5 @@ import os + import chevron @@ -6,14 +7,18 @@ def get_vars_in_template(template): """ Get the variables in a template """ - return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable'] + return [ + token[1] + for token in chevron.tokenizer.tokenize(template) + if token[0] == "variable" + ] def escape_var(var: str) -> str: """ Escape a variable so it can be used in a template """ - return var.replace(os.path.sep, '').replace('.', '') + return var.replace(os.path.sep, "").replace(".", "") def render_templated_string(template: str) -> str: @@ -28,12 +33,11 @@ def render_templated_string(template: str) -> str: if var.startswith(os.path.sep): # Escape vars which are filenames, because mustache doesn't allow / in variable names escaped_var = escape_var(var) - template = template.replace( - var, escaped_var) + template = template.replace(var, escaped_var) if os.path.exists(var): - args[escaped_var] = open(var, 'r').read() + args[escaped_var] = open(var, "r").read() else: - args[escaped_var] = '' + args[escaped_var] = "" return chevron.render(template, args) diff --git a/continuedev/src/continuedev/libs/util/traceback_parsers.py b/continuedev/src/continuedev/libs/util/traceback_parsers.py index a2e94c26..2b164c0f 100644 --- a/continuedev/src/continuedev/libs/util/traceback_parsers.py +++ b/continuedev/src/continuedev/libs/util/traceback_parsers.py @@ -15,11 +15,16 @@ def get_javascript_traceback(output: str) -> str: first_line = None for i in range(len(lines) - 1): segs = lines[i].split(":") - if len(segs) > 1 and segs[0] != "" and segs[1].startswith(" ") and lines[i + 1].strip().startswith("at"): + if ( + len(segs) > 1 + and segs[0] != "" + and segs[1].startswith(" ") + and lines[i + 1].strip().startswith("at") + ): first_line = lines[i] break if first_line is not None: - return "\n".join(lines[lines.index(first_line):]) + return "\n".join(lines[lines.index(first_line) :]) else: return None diff --git a/continuedev/src/continuedev/models/filesystem.py b/continuedev/src/continuedev/models/filesystem.py index ca12579c..de426282 100644 --- 
a/continuedev/src/continuedev/models/filesystem.py +++ b/continuedev/src/continuedev/models/filesystem.py @@ -1,9 +1,22 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple import os -from ..models.main import Position, Range, AbstractModel +from abc import abstractmethod +from typing import Dict, List, Tuple + from pydantic import BaseModel -from .filesystem_edit import FileSystemEdit, FileEdit, AddFile, DeleteFile, RenameDirectory, RenameFile, AddDirectory, DeleteDirectory, EditDiff, SequentialFileSystemEdit + +from ..models.main import AbstractModel, Position, Range +from .filesystem_edit import ( + AddDirectory, + AddFile, + DeleteDirectory, + DeleteFile, + EditDiff, + FileEdit, + FileSystemEdit, + RenameDirectory, + RenameFile, + SequentialFileSystemEdit, +) class RangeInFile(BaseModel): @@ -16,14 +29,12 @@ class RangeInFile(BaseModel): @staticmethod def from_entire_file(filepath: str, content: str) -> "RangeInFile": range = Range.from_entire_file(content) - return RangeInFile( - filepath=filepath, - range=range - ) + return RangeInFile(filepath=filepath, range=range) class RangeInFileWithContents(RangeInFile): """A range in a file with the contents of the range.""" + contents: str def __hash__(self): @@ -42,13 +53,15 @@ class RangeInFileWithContents(RangeInFile): # Calculate union of contents num_overlapping_lines = first.range.end.line - second.range.start.line + 1 - union_lines = first.contents.splitlines()[:-num_overlapping_lines] + \ - second.contents.splitlines() + union_lines = ( + first.contents.splitlines()[:-num_overlapping_lines] + + second.contents.splitlines() + ) return RangeInFileWithContents( filepath=first.filepath, range=first.range.union(second.range), - contents="\n".join(union_lines) + contents="\n".join(union_lines), ) @staticmethod @@ -56,28 +69,24 @@ class RangeInFileWithContents(RangeInFile): lines = content.splitlines() if not lines: return RangeInFileWithContents( - filepath=filepath, - range=Range.from_shorthand(0, 0, 0, 0), - contents="" + filepath=filepath, range=Range.from_shorthand(0, 0, 0, 0), contents="" ) return RangeInFileWithContents( filepath=filepath, - range=Range.from_shorthand( - 0, 0, len(lines) - 1, len(lines[-1]) - 1), - contents=content + range=Range.from_shorthand(0, 0, len(lines) - 1, len(lines[-1]) - 1), + contents=content, ) @staticmethod def from_range_in_file(rif: RangeInFile, content: str) -> "RangeInFileWithContents": return RangeInFileWithContents( - filepath=rif.filepath, - range=rif.range, - contents=content + filepath=rif.filepath, range=rif.range, contents=content ) class FileSystem(AbstractModel): """An abstract filesystem that can read/write from a set of files.""" + @abstractmethod def read(self, path) -> str: raise NotImplementedError @@ -129,12 +138,12 @@ class FileSystem(AbstractModel): @classmethod def read_range_in_str(self, s: str, r: Range) -> str: - lines = s.split("\n")[r.start.line:r.end.line + 1] + lines = s.split("\n")[r.start.line : r.end.line + 1] if len(lines) == 0: return "" - lines[0] = lines[0][r.start.character:] - lines[-1] = lines[-1][:r.end.character + 1] + lines[0] = lines[0][r.start.character :] + lines[-1] = lines[-1][: r.end.character + 1] return "\n".join(lines) @classmethod @@ -151,37 +160,40 @@ class FileSystem(AbstractModel): if len(lines) == 0: lines = [""] - end = Position(line=edit.range.end.line, - character=edit.range.end.character) + end = Position(line=edit.range.end.line, character=edit.range.end.character) if edit.range.end.line == len(lines) and 
edit.range.end.character == 0: - end = Position(line=edit.range.end.line - 1, - character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)])) + end = Position( + line=edit.range.end.line - 1, + character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)]), + ) - before_lines = lines[:edit.range.start.line] - after_lines = lines[end.line + 1:] - between_str = lines[min(len(lines) - 1, edit.range.start.line)][:edit.range.start.character] + \ - edit.replacement + \ - lines[min(len(lines) - 1, end.line)][end.character + 1:] + before_lines = lines[: edit.range.start.line] + after_lines = lines[end.line + 1 :] + between_str = ( + lines[min(len(lines) - 1, edit.range.start.line)][ + : edit.range.start.character + ] + + edit.replacement + + lines[min(len(lines) - 1, end.line)][end.character + 1 :] + ) new_range = Range( start=edit.range.start, end=Position( - line=edit.range.start.line + - len(edit.replacement.splitlines()) - 1, - character=edit.range.start.character + - len(edit.replacement.splitlines() - [-1]) if edit.replacement != "" else 0 - ) + line=edit.range.start.line + len(edit.replacement.splitlines()) - 1, + character=edit.range.start.character + + len(edit.replacement.splitlines()[-1]) + if edit.replacement != "" + else 0, + ), ) lines = before_lines + between_str.splitlines() + after_lines return "\n".join(lines), EditDiff( forward=edit, backward=FileEdit( - filepath=edit.filepath, - range=new_range, - replacement=original - ) + filepath=edit.filepath, range=new_range, replacement=original + ), ) def reverse_edit_on_str(self, s: str, diff: EditDiff) -> str: @@ -194,15 +206,17 @@ class FileSystem(AbstractModel): start=diff.edit.range.start, end=Position( line=diff.edit.range.start + replacement_d_lines, - character=diff.edit.range.start.character + replacement_d_chars - ) + character=diff.edit.range.start.character + replacement_d_chars, + ), ) - before_lines = lines[:replacement_range.start.line] - after_lines = lines[replacement_range.end.line + 1:] - between_str = lines[replacement_range.start.line][:replacement_range.start.character] + \ - diff.original + \ - lines[replacement_range.end.line][replacement_range.end.character + 1:] + before_lines = lines[: replacement_range.start.line] + after_lines = lines[replacement_range.end.line + 1 :] + between_str = ( + lines[replacement_range.start.line][: replacement_range.start.character] + + diff.original + + lines[replacement_range.end.line][replacement_range.end.character + 1 :] + ) lines = before_lines + between_str.splitlines() + after_lines return "\n".join(lines) @@ -221,8 +235,9 @@ class FileSystem(AbstractModel): self.delete_file(edit.filepath) elif isinstance(edit, RenameFile): self.rename_file(edit.filepath, edit.new_filepath) - backward = RenameFile(filepath=edit.new_filepath, - new_filepath=edit.filepath) + backward = RenameFile( + filepath=edit.new_filepath, new_filepath=edit.filepath + ) elif isinstance(edit, AddDirectory): self.add_directory(edit.path) backward = DeleteDirectory(edit.path) @@ -235,8 +250,7 @@ class FileSystem(AbstractModel): backward_edits.append(self.apply_edit(DeleteFile(path))) for d in dirs: path = os.path.join(root, d) - backward_edits.append( - self.apply_edit(DeleteDirectory(path))) + backward_edits.append(self.apply_edit(DeleteDirectory(path))) backward_edits.append(self.apply_edit(DeleteDirectory(edit.path))) backward_edits.reverse() @@ -253,10 +267,7 @@ class FileSystem(AbstractModel): else: raise TypeError("Unknown FileSystemEdit type: " + str(type(edit))) - return EditDiff( - 
forward=edit, - backward=backward - ) + return EditDiff(forward=edit, backward=backward) class RealFileSystem(FileSystem): @@ -304,6 +315,7 @@ class RealFileSystem(FileSystem): class VirtualFileSystem(FileSystem): """A simulated filesystem from a mapping of filepath to file contents.""" + files: Dict[str, str] def __init__(self, files: Dict[str, str]): @@ -331,7 +343,7 @@ class VirtualFileSystem(FileSystem): def rename_directory(self, path: str, new_path: str): for filepath in self.files: if filepath.startswith(path): - new_filepath = new_path + filepath[len(path):] + new_filepath = new_path + filepath[len(path) :] self.files[new_filepath] = self.files[filepath] del self.files[filepath] @@ -349,9 +361,7 @@ class VirtualFileSystem(FileSystem): old_content = self.read(edit.filepath) new_content, original = FileSystem.apply_edit_to_str(old_content, edit) self.write(edit.filepath, new_content) - return EditDiff( - edit=edit, - original=original - ) + return EditDiff(edit=edit, original=original) + # TODO: Uniform errors thrown by any FileSystem subclass. diff --git a/continuedev/src/continuedev/models/filesystem_edit.py b/continuedev/src/continuedev/models/filesystem_edit.py index b06ca2b3..9316ff46 100644 --- a/continuedev/src/continuedev/models/filesystem_edit.py +++ b/continuedev/src/continuedev/models/filesystem_edit.py @@ -1,9 +1,11 @@ +import os from abc import abstractmethod from typing import Generator, List + from pydantic import BaseModel -from .main import Position, Range + from ..libs.util.map_path import map_path -import os +from .main import Position, Range class FileSystemEdit(BaseModel): @@ -27,7 +29,9 @@ class FileEdit(AtomicFileSystemEdit): replacement: str def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": - return FileEdit(map_path(self.filepath, orig_root, copy_root), self.range, self.replacement) + return FileEdit( + map_path(self.filepath, orig_root, copy_root), self.range, self.replacement + ) @staticmethod def from_deletion(filepath: str, range: Range) -> "FileEdit": @@ -35,11 +39,23 @@ class FileEdit(AtomicFileSystemEdit): @staticmethod def from_insertion(filepath: str, position: Position, content: str) -> "FileEdit": - return FileEdit(filepath=filepath, range=Range.from_shorthand(position.line, position.character, position.line, position.character), replacement=content) + return FileEdit( + filepath=filepath, + range=Range.from_shorthand( + position.line, position.character, position.line, position.character + ), + replacement=content, + ) @staticmethod - def from_append(filepath: str, previous_content: str, appended_content: str) -> "FileEdit": - return FileEdit(filepath=filepath, range=Range.from_position(Position.from_end_of_file(previous_content)), replacement=appended_content) + def from_append( + filepath: str, previous_content: str, appended_content: str + ) -> "FileEdit": + return FileEdit( + filepath=filepath, + range=Range.from_position(Position.from_end_of_file(previous_content)), + replacement=appended_content, + ) class FileEditWithFullContents(BaseModel): @@ -52,7 +68,9 @@ class AddFile(AtomicFileSystemEdit): content: str def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": - return AddFile(self, map_path(self.filepath, orig_root, copy_root), self.content) + return AddFile( + self, map_path(self.filepath, orig_root, copy_root), self.content + ) class DeleteFile(AtomicFileSystemEdit): @@ -67,7 +85,10 @@ class RenameFile(AtomicFileSystemEdit): new_filepath: str def with_mapped_paths(self, 
orig_root: str, copy_root: str) -> "FileSystemEdit": - return RenameFile(map_path(self.filepath, orig_root, copy_root), map_path(self.new_filepath, orig_root, copy_root)) + return RenameFile( + map_path(self.filepath, orig_root, copy_root), + map_path(self.new_filepath, orig_root, copy_root), + ) class AddDirectory(AtomicFileSystemEdit): @@ -89,7 +110,10 @@ class RenameDirectory(AtomicFileSystemEdit): new_path: str def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": - return RenameDirectory(map_path(self.filepath, orig_root, copy_root), map_path(self.new_path, orig_root, copy_root)) + return RenameDirectory( + map_path(self.filepath, orig_root, copy_root), + map_path(self.new_path, orig_root, copy_root), + ) class DeleteDirectoryRecursive(FileSystemEdit): @@ -112,10 +136,9 @@ class SequentialFileSystemEdit(FileSystemEdit): edits: List[FileSystemEdit] def with_mapped_paths(self, orig_root: str, copy_root: str) -> "FileSystemEdit": - return SequentialFileSystemEdit([ - edit.with_mapped_paths(orig_root, copy_root) - for edit in self.edits - ]) + return SequentialFileSystemEdit( + [edit.with_mapped_paths(orig_root, copy_root) for edit in self.edits] + ) def next_edit(self) -> Generator["FileSystemEdit", None, None]: for edit in self.edits: @@ -124,6 +147,7 @@ class SequentialFileSystemEdit(FileSystemEdit): class EditDiff(BaseModel): """A reversible edit that can be applied to a file.""" + forward: FileSystemEdit backward: FileSystemEdit @@ -136,5 +160,5 @@ class EditDiff(BaseModel): backwards.insert(0, diff.backward) return cls( forward=SequentialFileSystemEdit(edits=forwards), - backward=SequentialFileSystemEdit(edits=backwards) + backward=SequentialFileSystemEdit(edits=backwards), ) diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 4262ac55..1c43f0a0 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -1,26 +1,22 @@ -from .main import * -from .filesystem import RangeInFile, FileEdit -from .filesystem_edit import FileEditWithFullContents -from ..core.main import History, HistoryNode, FullState, SessionInfo -from ..core.context import ContextItem -from pydantic import schema_json_of import os -MODELS_TO_GENERATE = [ - Position, Range, Traceback, TracebackFrame -] + [ - RangeInFile, FileEdit -] + [ - FileEditWithFullContents -] + [ - History, HistoryNode, FullState, SessionInfo -] + [ - ContextItem -] - -RENAMES = { - "ExampleClass": "RenamedName" -} +from pydantic import schema_json_of + +from ..core.context import ContextItem +from ..core.main import FullState, History, HistoryNode, SessionInfo +from .filesystem import FileEdit, RangeInFile +from .filesystem_edit import FileEditWithFullContents +from .main import * + +MODELS_TO_GENERATE = ( + [Position, Range, Traceback, TracebackFrame] + + [RangeInFile, FileEdit] + + [FileEditWithFullContents] + + [History, HistoryNode, FullState, SessionInfo] + + [ContextItem] +) + +RENAMES = {"ExampleClass": "RenamedName"} SCHEMA_DIR = "../schema/json" diff --git a/continuedev/src/continuedev/models/main.py b/continuedev/src/continuedev/models/main.py index fa736772..d442a415 100644 --- a/continuedev/src/continuedev/models/main.py +++ b/continuedev/src/continuedev/models/main.py @@ -1,7 +1,8 @@ from abc import ABC -from typing import List, Union, Tuple -from pydantic import BaseModel, root_validator from functools import total_ordering +from typing 
import List, Tuple, Union + +from pydantic import BaseModel, root_validator class ContinueBaseModel(BaseModel): @@ -46,16 +47,19 @@ class Position(BaseModel): def to_index(self, string: str) -> int: """Convert line and character to index in string""" lines = string.splitlines() - return sum(map(len, lines[:self.line])) + self.character + return sum(map(len, lines[: self.line])) + self.character class Range(BaseModel): """A range in a file. 0-indexed.""" + start: Position end: Position def __lt__(self, other: "Range") -> bool: - return self.start < other.start or (self.start == other.start and self.end < other.end) + return self.start < other.start or ( + self.start == other.start and self.end < other.end + ) def __eq__(self, other: "Range") -> bool: return self.start == other.start and self.end == other.end @@ -78,10 +82,13 @@ class Range(BaseModel): if len(lines) == 0: return (0, 0) - start_index = sum( - [len(line) + 1 for line in lines[:self.start.line]]) + self.start.character - end_index = sum( - [len(line) + 1 for line in lines[:self.end.line]]) + self.end.character + start_index = ( + sum([len(line) + 1 for line in lines[: self.start.line]]) + + self.start.character + ) + end_index = ( + sum([len(line) + 1 for line in lines[: self.end.line]]) + self.end.character + ) return (start_index, end_index) def overlaps_with(self, other: "Range") -> bool: @@ -90,27 +97,23 @@ class Range(BaseModel): def to_full_lines(self) -> "Range": return Range( start=Position(line=self.start.line, character=0), - end=Position(line=self.end.line + 1, character=0) + end=Position(line=self.end.line + 1, character=0), ) @staticmethod def from_indices(string: str, start_index: int, end_index: int) -> "Range": return Range( start=Position.from_index(string, start_index), - end=Position.from_index(string, end_index) + end=Position.from_index(string, end_index), ) @staticmethod - def from_shorthand(start_line: int, start_char: int, end_line: int, end_char: int) -> "Range": + def from_shorthand( + start_line: int, start_char: int, end_line: int, end_char: int + ) -> "Range": return Range( - start=Position( - line=start_line, - character=start_char - ), - end=Position( - line=end_line, - character=end_char - ) + start=Position(line=start_line, character=start_char), + end=Position(line=end_line, character=end_char), ) @staticmethod @@ -148,7 +151,9 @@ class Range(BaseModel): if start_line == -1 or end_line == -1: raise ValueError("Snippet not found in content") - return Range.from_shorthand(start_line, 0, end_line, len(content_lines[end_line]) - 1) + return Range.from_shorthand( + start_line, 0, end_line, len(content_lines[end_line]) - 1 + ) @staticmethod def from_position(position: Position) -> "Range": @@ -160,7 +165,8 @@ class AbstractModel(ABC, BaseModel): def check_is_subclass(cls, values): if not issubclass(cls, AbstractModel): raise TypeError( - "AbstractModel subclasses must be subclasses of AbstractModel") + "AbstractModel subclasses must be subclasses of AbstractModel" + ) class TracebackFrame(BaseModel): @@ -170,7 +176,11 @@ class TracebackFrame(BaseModel): code: Union[str, None] def __eq__(self, other): - return self.filepath == other.filepath and self.lineno == other.lineno and self.function == other.function + return ( + self.filepath == other.filepath + and self.lineno == other.lineno + and self.function == other.function + ) class Traceback(BaseModel): diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py index 
7a53e87a..c8345d02 100644 --- a/continuedev/src/continuedev/plugins/context_providers/diff.py +++ b/continuedev/src/continuedev/plugins/context_providers/diff.py @@ -1,9 +1,8 @@ import subprocess from typing import List -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId class DiffContextProvider(ContextProvider): @@ -21,10 +20,9 @@ class DiffContextProvider(ContextProvider): name="Diff", description="Reference the output of 'git diff' for the current workspace", id=ContextItemId( - provider_title=self.title, - item_id=self.DIFF_CONTEXT_ITEM_ID - ) - ) + provider_title=self.title, item_id=self.DIFF_CONTEXT_ITEM_ID + ), + ), ) async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: @@ -35,8 +33,9 @@ class DiffContextProvider(ContextProvider): if not id.item_id == self.DIFF_CONTEXT_ITEM_ID: raise Exception("Invalid item id") - diff = subprocess.check_output( - ["git", "diff"], cwd=self.workspace_dir).decode("utf-8") + diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode( + "utf-8" + ) ctx_item = self.BASE_CONTEXT_ITEM.copy() ctx_item.content = diff diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py index 42d1f754..3f37232e 100644 --- a/continuedev/src/continuedev/plugins/context_providers/embeddings.py +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -1,13 +1,12 @@ import os -from typing import List, Optional import uuid +from typing import List, Optional + from pydantic import BaseModel -from ...core.main import ContextItemId from ...core.context import ContextProvider from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...libs.chroma.query import ChromaIndexManager -from .util import remove_meilisearch_disallowed_chars class EmbeddingResult(BaseModel): @@ -41,10 +40,9 @@ class EmbeddingsProvider(ContextProvider): name="Embedding Search", description="Enter a query to embedding search codebase", id=ContextItemId( - provider_title=self.title, - item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID - ) - ) + provider_title=self.title, item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID + ), + ), ) async def _get_query_results(self, query: str) -> str: @@ -53,9 +51,8 @@ class EmbeddingsProvider(ContextProvider): ret = [] for node in results.source_nodes: resource_name = list(node.node.relationships.values())[0] - filepath = resource_name[:resource_name.index("::")] - ret.append(EmbeddingResult( - filename=filepath, content=node.node.text)) + filepath = resource_name[: resource_name.index("::")] + ret.append(EmbeddingResult(filename=filepath, content=node.node.text)) return ret diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index b40092af..33e20662 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -1,11 +1,10 @@ import os -import re +from fnmatch import fnmatch from typing import List -from ...core.main import ContextItem, ContextItemDescription, ContextItemId + from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId from .util import remove_meilisearch_disallowed_chars 
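One detail worth noting in the models/filesystem.py hunk further up: FileSystem.read_range_in_str slices with end.character + 1, so the end position of a Range is inclusive. A small self-contained check of that semantics, assuming the import paths below (read_range_in_str is a classmethod, so no concrete filesystem instance is needed):

from continuedev.src.continuedev.models.filesystem import FileSystem  # assumed import path
from continuedev.src.continuedev.models.main import Range             # assumed import path

s = "hello\nworld"
r = Range.from_shorthand(0, 1, 1, 2)  # line 0 char 1 through line 1 char 2, inclusive
# lines[0][1:] == "ello" and lines[-1][:2 + 1] == "wor"
assert FileSystem.read_range_in_str(s, r) == "ello\nwor"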
-from fnmatch import fnmatch - MAX_SIZE_IN_BYTES = 1024 * 1024 * 1 @@ -18,7 +17,7 @@ def get_file_contents(filepath: str) -> str: with open(filepath, "r") as f: return f.read() - except Exception as e: + except Exception: # Some files cannot be read, e.g. binary files return None @@ -40,7 +39,7 @@ DEFAULT_IGNORE_DIRS = [ ".pytest_cache", ".vscode-test", ".continue", - "__pycache__" + "__pycache__", ] @@ -50,14 +49,18 @@ class FileContextProvider(ContextProvider): """ title = "file" - ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + \ - list(filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)) + ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + list( + map(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) + ) async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: absolute_filepaths: List[str] = [] for root, dir_names, file_names in os.walk(workspace_dir): - dir_names[:] = [d for d in dir_names if not any( - fnmatch(d, pattern) for pattern in self.ignore_patterns)] + dir_names[:] = [ + d + for d in dir_names + if not any(fnmatch(d, pattern) for pattern in self.ignore_patterns) + ] for file_name in file_names: absolute_filepaths.append(os.path.join(root, file_name)) @@ -72,20 +75,24 @@ class FileContextProvider(ContextProvider): content = get_file_contents(absolute_filepath) if content is None: continue # no pun intended - + relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir) - items.append(ContextItem( - content=content[:min(2000, len(content))], - description=ContextItemDescription( - name=os.path.basename(absolute_filepath), - # We should add the full path to the ContextItem - # It warrants a data modeling discussion and has no immediate use case - description=relative_to_workspace, - id=ContextItemId( - provider_title=self.title, - item_id=remove_meilisearch_disallowed_chars(absolute_filepath) - ) + items.append( + ContextItem( + content=content[: min(2000, len(content))], + description=ContextItemDescription( + name=os.path.basename(absolute_filepath), + # We should add the full path to the ContextItem + # It warrants a data modeling discussion and has no immediate use case + description=relative_to_workspace, + id=ContextItemId( + provider_title=self.title, + item_id=remove_meilisearch_disallowed_chars( + absolute_filepath + ), + ), + ), ) - )) + ) return items diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py index c7b4806b..ea86f214 100644 --- a/continuedev/src/continuedev/plugins/context_providers/filetree.py +++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py @@ -1,21 +1,19 @@ -import json -from typing import List import os -import aiohttp +from typing import List -from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId def format_file_tree(startpath) -> str: result = "" for root, dirs, files in os.walk(startpath): - level = root.replace(startpath, '').count(os.sep) - indent = ' ' * 4 * (level) - result += '{}{}/'.format(indent, os.path.basename(root)) + "\n" - subindent = ' ' * 4 * (level + 1) + level = root.replace(startpath, "").count(os.sep) + indent = " " * 4 * (level) + result += "{}{}/".format(indent, os.path.basename(root)) + "\n" + subindent = " " * 4 * (level + 1) for f in files: - result += '{}{}'.format(subindent, f) + "\n" + result += "{}{}".format(subindent, f) + "\n" return
result @@ -31,11 +29,8 @@ class FileTreeContextProvider(ContextProvider): description=ContextItemDescription( name="File Tree", description="Add a formatted file tree of this directory to the context", - id=ContextItemId( - provider_title=self.title, - item_id=self.title - ) - ) + id=ContextItemId(provider_title=self.title, item_id=self.title), + ), ) async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py index 2e7047f2..d394add1 100644 --- a/continuedev/src/continuedev/plugins/context_providers/github.py +++ b/continuedev/src/continuedev/plugins/context_providers/github.py @@ -1,8 +1,13 @@ from typing import List -from github import Github -from github import Auth -from ...core.context import ContextProvider, ContextItemDescription, ContextItem, ContextItemId +from github import Auth, Github + +from ...core.context import ( + ContextItem, + ContextItemDescription, + ContextItemId, + ContextProvider, +) class GitHubIssuesContextProvider(ContextProvider): @@ -22,14 +27,14 @@ class GitHubIssuesContextProvider(ContextProvider): repo = gh.get_repo(self.repo_name) issues = repo.get_issues().get_page(0) - return [ContextItem( - content=issue.body, - description=ContextItemDescription( - name=f"Issue #{issue.number}", - description=issue.title, - id=ContextItemId( - provider_title=self.title, - item_id=issue.id - ) + return [ + ContextItem( + content=issue.body, + description=ContextItemDescription( + name=f"Issue #{issue.number}", + description=issue.title, + id=ContextItemId(provider_title=self.title, item_id=issue.id), + ), ) - ) for issue in issues] + for issue in issues + ] diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index 4b0a59ec..d716c9d3 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -2,9 +2,10 @@ import json from typing import List import aiohttp -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId + from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from .util import remove_meilisearch_disallowed_chars class GoogleContextProvider(ContextProvider): @@ -22,22 +23,16 @@ class GoogleContextProvider(ContextProvider): name="Google Search", description="Enter a query to search google", id=ContextItemId( - provider_title=self.title, - item_id=self.GOOGLE_CONTEXT_ITEM_ID - ) - ) + provider_title=self.title, item_id=self.GOOGLE_CONTEXT_ITEM_ID + ), + ), ) async def _google_search(self, query: str) -> str: url = "https://google.serper.dev/search" - payload = json.dumps({ - "q": query - }) - headers = { - 'X-API-KEY': self.serper_api_key, - 'Content-Type': 'application/json' - } + payload = json.dumps({"q": query}) + headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"} async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, data=payload) as response: @@ -61,6 +56,5 @@ class GoogleContextProvider(ContextProvider): ctx_item = self.BASE_CONTEXT_ITEM.copy() ctx_item.content = content - ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( - query) + ctx_item.description.id.item_id = 
remove_meilisearch_disallowed_chars(query) return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 750775ac..ed293124 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -1,11 +1,17 @@ import os from typing import Any, Dict, List -from ...core.main import ChatMessage -from ...models.filesystem import RangeInFile, RangeInFileWithContents -from ...core.context import ContextItem, ContextItemDescription, ContextItemId, ContextProvider from pydantic import BaseModel +from ...core.context import ( + ContextItem, + ContextItemDescription, + ContextItemId, + ContextProvider, +) +from ...core.main import ChatMessage +from ...models.filesystem import RangeInFileWithContents + class HighlightedRangeContextItem(BaseModel): rif: RangeInFileWithContents @@ -40,12 +46,10 @@ class HighlightedCodeContextProvider(ContextProvider): visible_files = await self.ide.getVisibleFiles() if len(visible_files) > 0: content = await self.ide.readFile(visible_files[0]) - rif = RangeInFileWithContents.from_entire_file( - visible_files[0], content) + rif = RangeInFileWithContents.from_entire_file(visible_files[0], content) item = self._rif_to_context_item(rif, 0, True) - item.description.name = self._rif_to_name( - rif, show_line_nums=False) + item.description.name = self._rif_to_name(rif, show_line_nums=False) self.last_added_fallback = True return HighlightedRangeContextItem(rif=rif, item=item) @@ -55,21 +59,28 @@ class HighlightedCodeContextProvider(ContextProvider): async def get_selected_items(self) -> List[ContextItem]: items = [hr.item for hr in self.highlighted_ranges] - if len(items) == 0 and (fallback_item := await self._get_fallback_context_item()): + if len(items) == 0 and ( + fallback_item := await self._get_fallback_context_item() + ): items = [fallback_item.item] return items async def get_chat_messages(self) -> List[ContextItem]: ranges = self.highlighted_ranges - if len(ranges) == 0 and (fallback_item := await self._get_fallback_context_item()): + if len(ranges) == 0 and ( + fallback_item := await self._get_fallback_context_item() + ): ranges = [fallback_item] - return [ChatMessage( - role="user", - content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```", - summary=f"Code in this file is highlighted: {r.rif.filepath}" - ) for r in ranges] + return [ + ChatMessage( + role="user", + content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```", + summary=f"Code in this file is highlighted: {r.rif.filepath}", + ) + for r in ranges + ] def _make_sure_is_editing_range(self): """If none of the highlighted ranges are currently being edited, the first should be selected""" @@ -80,8 +91,9 @@ class HighlightedCodeContextProvider(ContextProvider): def _disambiguate_highlighted_ranges(self): """If any files have the same name, also display their folder name""" - name_status: Dict[str, set] = { - } # basename -> set of full paths with that basename + name_status: Dict[ + str, set + ] = {} # basename -> set of full paths with that basename for hr in self.highlighted_ranges: basename = os.path.basename(hr.rif.filepath) if basename in name_status: @@ -92,11 +104,16 @@ class HighlightedCodeContextProvider(ContextProvider): for hr in self.highlighted_ranges: basename = os.path.basename(hr.rif.filepath) if 
len(name_status[basename]) > 1: - hr.item.description.name = self._rif_to_name(hr.rif, display_filename=os.path.join( - os.path.basename(os.path.dirname(hr.rif.filepath)), basename)) + hr.item.description.name = self._rif_to_name( + hr.rif, + display_filename=os.path.join( + os.path.basename(os.path.dirname(hr.rif.filepath)), basename + ), + ) else: hr.item.description.name = self._rif_to_name( - hr.rif, display_filename=basename) + hr.rif, display_filename=basename + ) async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: return [] @@ -110,7 +127,9 @@ class HighlightedCodeContextProvider(ContextProvider): self.should_get_fallback_context_item = True self.last_added_fallback = False - async def delete_context_with_ids(self, ids: List[ContextItemId]) -> List[ContextItem]: + async def delete_context_with_ids( + self, ids: List[ContextItemId] + ) -> List[ContextItem]: ids_to_delete = [id.item_id for id in ids] kept_ranges = [] @@ -126,36 +145,57 @@ class HighlightedCodeContextProvider(ContextProvider): return [hr.item for hr in self.highlighted_ranges] - def _rif_to_name(self, rif: RangeInFileWithContents, display_filename: str = None, show_line_nums: bool = True) -> str: - line_nums = f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" if show_line_nums else "" + def _rif_to_name( + self, + rif: RangeInFileWithContents, + display_filename: str = None, + show_line_nums: bool = True, + ) -> str: + line_nums = ( + f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" + if show_line_nums + else "" + ) return f"{display_filename or os.path.basename(rif.filepath)}{line_nums}" - def _rif_to_context_item(self, rif: RangeInFileWithContents, idx: int, editing: bool) -> ContextItem: + def _rif_to_context_item( + self, rif: RangeInFileWithContents, idx: int, editing: bool + ) -> ContextItem: return ContextItem( description=ContextItemDescription( name=self._rif_to_name(rif), description=rif.filepath, - id=ContextItemId( - provider_title=self.title, - item_id=str(idx) - ) + id=ContextItemId(provider_title=self.title, item_id=str(idx)), ), content=rif.contents, editing=editing, - editable=True + editable=True, ) - async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): + async def handle_highlighted_code( + self, range_in_files: List[RangeInFileWithContents] + ): self.should_get_fallback_context_item = True self.last_added_fallback = False # Filter out rifs from ~/.continue/diffs folder range_in_files = [ - rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")] + rif + for rif in range_in_files + if not os.path.dirname(rif.filepath) + == os.path.expanduser("~/.continue/diffs") + ] # If not adding highlighted code if not self.adding_highlighted_code: - if len(self.highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end): + if ( + len(self.highlighted_ranges) == 1 + and len(range_in_files) <= 1 + and ( + len(range_in_files) == 0 + or range_in_files[0].range.start == range_in_files[0].range.end + ) + ): # If un-highlighting the range to edit, then remove the range self.highlighted_ranges = [] elif len(range_in_files) > 0: @@ -164,7 +204,9 @@ class HighlightedCodeContextProvider(ContextProvider): self.highlighted_ranges = [ HighlightedRangeContextItem( rif=range_in_files[0], - item=self._rif_to_context_item(range_in_files[0], 0, True))] + item=self._rif_to_context_item(range_in_files[0], 0, True), 
+ ) + ] return @@ -173,22 +215,36 @@ class HighlightedCodeContextProvider(ContextProvider): for i, hr in enumerate(self.highlighted_ranges): found_overlap = False for new_rif in range_in_files: - if hr.rif.filepath == new_rif.filepath and hr.rif.range.overlaps_with(new_rif.range): + if hr.rif.filepath == new_rif.filepath and hr.rif.range.overlaps_with( + new_rif.range + ): found_overlap = True break # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids # the bug where cmd+f causes repeated highlights - if hr.rif.filepath == new_rif.filepath and hr.rif.contents == new_rif.contents: + if ( + hr.rif.filepath == new_rif.filepath + and hr.rif.contents == new_rif.contents + ): found_overlap = True break if not found_overlap: - new_ranges.append(HighlightedRangeContextItem(rif=hr.rif, item=self._rif_to_context_item( - hr.rif, len(new_ranges), False))) + new_ranges.append( + HighlightedRangeContextItem( + rif=hr.rif, + item=self._rif_to_context_item(hr.rif, len(new_ranges), False), + ) + ) - self.highlighted_ranges = new_ranges + [HighlightedRangeContextItem(rif=rif, item=self._rif_to_context_item( - rif, len(new_ranges) + idx, False)) for idx, rif in enumerate(range_in_files)] + self.highlighted_ranges = new_ranges + [ + HighlightedRangeContextItem( + rif=rif, + item=self._rif_to_context_item(rif, len(new_ranges) + idx, False), + ) + for idx, rif in enumerate(range_in_files) + ] self._make_sure_is_editing_range() self._disambiguate_highlighted_ranges() @@ -197,5 +253,7 @@ class HighlightedCodeContextProvider(ContextProvider): for hr in self.highlighted_ranges: hr.item.editing = hr.item.description.id.to_string() in ids - async def add_context_item(self, id: ContextItemId, query: str, prev: List[ContextItem] = None) -> List[ContextItem]: + async def add_context_item( + self, id: ContextItemId, query: str, prev: List[ContextItem] = None + ) -> List[ContextItem]: raise NotImplementedError() diff --git a/continuedev/src/continuedev/plugins/context_providers/search.py b/continuedev/src/continuedev/plugins/context_providers/search.py index da991a78..6aecb5bf 100644 --- a/continuedev/src/continuedev/plugins/context_providers/search.py +++ b/continuedev/src/continuedev/plugins/context_providers/search.py @@ -1,10 +1,11 @@ import os from typing import List + from ripgrepy import Ripgrepy -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from .util import remove_meilisearch_disallowed_chars class SearchContextProvider(ContextProvider): @@ -22,17 +23,16 @@ class SearchContextProvider(ContextProvider): name="Search", description="Search the workspace for all matches of an exact string (e.g. 
'@search console.log')", id=ContextItemId( - provider_title=self.title, - item_id=self.SEARCH_CONTEXT_ITEM_ID - ) - ) + provider_title=self.title, item_id=self.SEARCH_CONTEXT_ITEM_ID + ), + ), ) def _get_rg_path(self): - if os.name == 'nt': + if os.name == "nt": rg_path = f"C:\\Users\\{os.getlogin()}\\AppData\\Local\\Programs\\Microsoft VS Code\\resources\\app\\node_modules.asar.unpacked\\vscode-ripgrep\\bin\\rg.exe" - elif os.name == 'posix': - if 'darwin' in os.sys.platform: + elif os.name == "posix": + if "darwin" in os.sys.platform: rg_path = "/Applications/Visual Studio Code.app/Contents/Resources/app/node_modules.asar.unpacked/@vscode/ripgrep/bin/rg" else: rg_path = "/usr/share/code/resources/app/node_modules.asar.unpacked/vscode-ripgrep/bin/rg" @@ -87,6 +87,5 @@ class SearchContextProvider(ContextProvider): ctx_item = self.BASE_CONTEXT_ITEM.copy() ctx_item.content = results ctx_item.description.name = f"Search: '{query}'" - ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( - query) + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query) return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/url.py b/continuedev/src/continuedev/plugins/context_providers/url.py index 5b67608d..a5ec8990 100644 --- a/continuedev/src/continuedev/plugins/context_providers/url.py +++ b/continuedev/src/continuedev/plugins/context_providers/url.py @@ -1,11 +1,11 @@ - from typing import List -from bs4 import BeautifulSoup + import requests +from bs4 import BeautifulSoup -from .util import remove_meilisearch_disallowed_chars -from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from .util import remove_meilisearch_disallowed_chars class URLContextProvider(ContextProvider): @@ -30,10 +30,9 @@ class URLContextProvider(ContextProvider): name="Dynamic URL", description="Reference the contents of a webpage (e.g. 
'@url https://www.w3schools.com/python/python_ref_functions.asp')", id=ContextItemId( - provider_title=self.title, - item_id=self.DYNAMIC_URL_CONTEXT_ITEM_ID - ) - ) + provider_title=self.title, item_id=self.DYNAMIC_URL_CONTEXT_ITEM_ID + ), + ), ) def static_url_context_item_from_url(self, url: str) -> ContextItem: @@ -45,30 +44,36 @@ class URLContextProvider(ContextProvider): description=f"Contents of {url}", id=ContextItemId( provider_title=self.title, - item_id=remove_meilisearch_disallowed_chars(url) - ) - ) + item_id=remove_meilisearch_disallowed_chars(url), + ), + ), ) def _get_url_text_contents_and_title(self, url: str) -> (str, str): response = requests.get(url) - soup = BeautifulSoup(response.text, 'html.parser') - title = url.replace( - "https://", "").replace("http://", "").replace("www.", "") + soup = BeautifulSoup(response.text, "html.parser") + title = url.replace("https://", "").replace("http://", "").replace("www.", "") if soup.title is not None: title = soup.title.string return soup.get_text(), title async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: self.static_url_context_items = [ - self.static_url_context_item_from_url(url) for url in self.preset_urls] + self.static_url_context_item_from_url(url) for url in self.preset_urls + ] return [self.DYNAMIC_CONTEXT_ITEM] + self.static_url_context_items async def get_item(self, id: ContextItemId, query: str) -> ContextItem: # Check if the item is a static item matching_static_item = next( - (item for item in self.static_url_context_items if item.description.id.item_id == id.item_id), None) + ( + item + for item in self.static_url_context_items + if item.description.id.item_id == id.item_id + ), + None, + ) if matching_static_item: return matching_static_item @@ -85,6 +90,5 @@ class URLContextProvider(ContextProvider): ctx_item = self.DYNAMIC_CONTEXT_ITEM.copy() ctx_item.content = content ctx_item.description.name = title - ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( - url) + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(url) return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py index da2e6b17..61bea8aa 100644 --- a/continuedev/src/continuedev/plugins/context_providers/util.py +++ b/continuedev/src/continuedev/plugins/context_providers/util.py @@ -2,4 +2,4 @@ import re def remove_meilisearch_disallowed_chars(id: str) -> str: - return re.sub(r'[^0-9a-zA-Z_-]', '', id) + return re.sub(r"[^0-9a-zA-Z_-]", "", id) diff --git a/continuedev/src/continuedev/plugins/policies/default.py b/continuedev/src/continuedev/plugins/policies/default.py index 2382f33a..ef88c8d6 100644 --- a/continuedev/src/continuedev/plugins/policies/default.py +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -1,15 +1,15 @@ from textwrap import dedent from typing import Type, Union -from ..steps.chat import SimpleChatStep -from ..steps.welcome import WelcomeStep from ...core.config import ContinueConfig -from ..steps.steps_on_startup import StepsOnStartupStep -from ...core.main import Step, History, Policy +from ...core.main import History, Policy, Step from ...core.observation import UserInputObservation +from ..steps.chat import SimpleChatStep from ..steps.core.core import MessageStep from ..steps.custom_command import CustomCommandStep from ..steps.main import EditHighlightedCodeStep +from ..steps.steps_on_startup import StepsOnStartupStep +from ..steps.welcome import 
WelcomeStep def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: @@ -28,7 +28,8 @@ def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: return slash_command.step(**params) except TypeError as e: raise Exception( - f"Incorrect params used for slash command '{command_name}': {e}") + f"Incorrect params used for slash command '{command_name}': {e}" + ) return None @@ -40,12 +41,17 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: slash_command = parse_slash_command(custom_cmd.prompt, config) if slash_command is not None: return slash_command - return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name) + return CustomCommandStep( + name=custom_cmd.name, + description=custom_cmd.description, + prompt=custom_cmd.prompt, + user_input=after_command, + slash_command=command_name, + ) return None class DefaultPolicy(Policy): - default_step: Type[Step] = SimpleChatStep default_params: dict = {} @@ -53,13 +59,19 @@ class DefaultPolicy(Policy): # At the very start, run initial Steps specified in the config if history.get_current() is None: return ( - MessageStep(name="Welcome to Continue", message=dedent("""\ + MessageStep( + name="Welcome to Continue", + message=dedent( + """\ - Highlight code section and ask a question or give instructions - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue - Use `/help` to ask questions about how to use Continue - - [Customize Continue](https://continue.dev/docs/customization) (e.g. use your own API key) by typing '/config'.""")) >> - WelcomeStep() >> - StepsOnStartupStep()) + - [Customize Continue](https://continue.dev/docs/customization) (e.g. use your own API key) by typing '/config'.""" + ), + ) + >> WelcomeStep() + >> StepsOnStartupStep() + ) observation = history.get_current().observation if observation is not None and isinstance(observation, UserInputObservation):
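The two parsers above resolve user input against the `slash_commands` and `custom_commands` lists on `ContinueConfig`. Here is a minimal sketch of a config that would exercise both code paths; it is a sketch, not the project's documented API: the exact `SlashCommand` and `CustomCommand` constructor shapes are assumptions, while the attribute names `step`, `prompt`, `name`, and `description` come from the parser code above.

```python
from continuedev.src.continuedev.core.config import (  # assumed import location
    ContinueConfig,
    CustomCommand,
    SlashCommand,
)
from continuedev.src.continuedev.plugins.steps.main import EditHighlightedCodeStep

config = ContinueConfig(
    slash_commands=[
        # "/edit make this faster" would resolve to the EditHighlightedCodeStep
        # class, which parse_slash_command instantiates via step(**params)
        SlashCommand(
            name="edit",
            description="Edit the highlighted code",
            step=EditHighlightedCodeStep,
        ),
    ],
    custom_commands=[
        # "/check" would be wrapped by parse_custom_command in a CustomCommandStep
        CustomCommand(
            name="check",
            description="Check for mistakes in my code",
            prompt="Read the highlighted code and point out any mistakes.",
        ),
    ],
)
```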
diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md index d735e0cd..78d603a2 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md @@ -3,6 +3,7 @@ Uses the Chess.com API example to show how to add map and filter Python transforms to a dlt pipeline.
Background
+
- https://dlthub.com/docs/general-usage/resource#filter-transform-and-pivot-data
- https://dlthub.com/docs/customizations/customizing-pipelines/renaming_columns
-- https://dlthub.com/docs/customizations/customizing-pipelines/pseudonymizing_columns
\ No newline at end of file +- https://dlthub.com/docs/customizations/customizing-pipelines/pseudonymizing_columns
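The dlt_transform_docs.md file changed below carries the full reference for these transforms; as a one-glance illustration of the map/filter pattern this recipe teaches, here is a minimal sketch. The `games` resource and its `result` field are invented stand-ins for what the generated chess.com pipeline actually yields; the `add_filter` idiom itself comes from the docs below.

```python
import dlt


@dlt.resource
def games():
    # Invented stand-in for the chess.com games resource
    yield {"game_id": 1, "result": "1-0"}
    yield {"game_id": 2, "result": "1/2-1/2"}  # a drawn game


# Attach a filter transform to drop drawn games before loading,
# mirroring the example prompt used by this recipe's main.py
for game in games().add_filter(lambda g: g["result"] != "1/2-1/2"):
    print(game)
```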
diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md index 658b285f..864aea87 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md @@ -1,15 +1,19 @@ # Customize resources + ## Filter, transform and pivot data You can attach any number of transformations that are evaluated on item per item basis to your resource. The available transformation types: + - map - transform the data item (resource.add_map) - filter - filter the data item (resource.add_filter) - yield map - a map that returns iterator (so single row may generate many rows - resource.add_yield_map) Example: We have a resource that loads a list of users from an api endpoint. We want to customize it so: + - we remove users with user_id == 'me' - we anonymize user data -Here's our resource: + Here's our resource: + ```python import dlt @@ -22,6 +26,7 @@ def users(): ``` Here's our script that defines transformations and loads the data. + ```python from pipedrive import users @@ -34,9 +39,9 @@ def anonymize_user(user_data): for user in users().add_filter(lambda user: user['user_id'] != 'me').add_map(anonymize_user): print(user) ``` - + Here is a more complex example of a filter transformation: - + # Renaming columns ## Renaming columns by replacing the special characters @@ -80,11 +85,13 @@ Here is a more complex example of a filter transformation: # {'Objekt_0': {'Groesse': 0, 'Aequivalenzpruefung': True}} # ... ``` - + Here is a more complex example of a map transformation: - + # Pseudonymizing columns + ## Pseudonymizing (or anonymizing) columns by replacing the special characters + Pseudonymization is a deterministic way to hide personally identifiable info (PII), enabling us to consistently achieve the same mapping. If instead you wish to anonymize, you can delete the data, or replace it with a constant. In the example below, we create a dummy source with a PII column called 'name', which we replace with deterministic hashes (i.e. replacing the German umlaut). ```python @@ -132,4 +139,4 @@ def pseudonymize_name(doc): pipeline = dlt.pipeline(pipeline_name='example', destination='bigquery', dataset_name='normalized_data') load_info = pipeline.run(data_source) -```
\ No newline at end of file +``` diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py index 5d242f7c..54207399 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py @@ -2,9 +2,8 @@ from textwrap import dedent from ....core.main import Step from ....core.sdk import ContinueSDK -from ....plugins.steps.core.core import WaitForUserInputStep -from ....plugins.steps.core.core import MessageStep -from .steps import SetUpChessPipelineStep, AddTransformStep +from ....plugins.steps.core.core import MessageStep, WaitForUserInputStep +from .steps import AddTransformStep, SetUpChessPipelineStep class AddTransformRecipe(Step): @@ -12,16 +11,21 @@ class AddTransformRecipe(Step): async def run(self, sdk: ContinueSDK): text_observation = await sdk.run_step( - MessageStep(message=dedent("""\ + MessageStep( + message=dedent( + """\ This recipe will walk you through the process of adding a transform to a dlt pipeline that uses the chess.com API source. With the help of Continue, you will: - Set up a dlt pipeline for the chess.com API - Add a filter or map transform to the pipeline - - Run the pipeline and view the transformed data in a Streamlit app"""), name="Add transformation to a dlt pipeline") >> - SetUpChessPipelineStep() >> - WaitForUserInputStep( - prompt="How do you want to transform the Chess.com API data before loading it? For example, you could filter out games that ended in a draw.") + - Run the pipeline and view the transformed data in a Streamlit app""" + ), + name="Add transformation to a dlt pipeline", + ) + >> SetUpChessPipelineStep() + >> WaitForUserInputStep( + prompt="How do you want to transform the Chess.com API data before loading it? For example, you could filter out games that ended in a draw." + ) ) await sdk.run_step( - AddTransformStep( - transform_description=text_observation.text) + AddTransformStep(transform_description=text_observation.text) ) diff --git a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py index e589fc36..0091af97 100644 --- a/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py @@ -1,12 +1,10 @@ import os from textwrap import dedent - -from ....plugins.steps.core.core import MessageStep -from ....libs.util.paths import find_data_file -from ....core.sdk import Models from ....core.main import Step -from ....core.sdk import ContinueSDK +from ....core.sdk import ContinueSDK, Models +from ....libs.util.paths import find_data_file +from ....plugins.steps.core.core import MessageStep AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -19,21 +17,26 @@ class SetUpChessPipelineStep(Step): return "This step will create a new dlt pipeline that loads data from the chess.com API." 
async def run(self, sdk: ContinueSDK): - # running commands to get started when creating a new dlt pipeline - await sdk.run([ - 'python3 -m venv .env', - 'source .env/bin/activate', - 'pip install dlt', - 'dlt --non-interactive init chess duckdb', - 'pip install -r requirements.txt', - 'pip install pandas streamlit' # Needed for the pipeline show step later - ], name="Set up Python environment", description=dedent(f"""\ + await sdk.run( + [ + "python3 -m venv .env", + "source .env/bin/activate", + "pip install dlt", + "dlt --non-interactive init chess duckdb", + "pip install -r requirements.txt", + "pip install pandas streamlit", # Needed for the pipeline show step later + ], + name="Set up Python environment", + description=dedent( + """\ - Create a Python virtual environment: `python3 -m venv .env` - Activate the virtual environment: `source .env/bin/activate` - Install dlt: `pip install dlt` - Create a new dlt pipeline called "chess" that loads data into a local DuckDB instance: `dlt init chess duckdb` - - Install the Python dependencies for the pipeline: `pip install -r requirements.txt`""")) + - Install the Python dependencies for the pipeline: `pip install -r requirements.txt`""" + ), + ) class AddTransformStep(Step): @@ -43,42 +46,61 @@ class AddTransformStep(Step): transform_description: str async def run(self, sdk: ContinueSDK): - source_name = 'chess' - filename = f'{source_name}_pipeline.py' + source_name = "chess" + filename = f"{source_name}_pipeline.py" abs_filepath = os.path.join(sdk.ide.workspace_directory, filename) # Open the file and highlight the function to be edited await sdk.ide.setFileOpen(abs_filepath) - await sdk.run_step(MessageStep(message=dedent("""\ + await sdk.run_step( + MessageStep( + message=dedent( + """\ This step will customize your resource function with a transform of your choice: - Add a filter or map transformation depending on your request - Load the data into a local DuckDB instance - - Open up a Streamlit app for you to view the data"""), name="Write transformation function")) + - Open up a Streamlit app for you to view the data""" + ), + name="Write transformation function", + ) + ) - with open(find_data_file('dlt_transform_docs.md')) as f: + with open(find_data_file("dlt_transform_docs.md")) as f: dlt_transform_docs = f.read() - prompt = dedent(f"""\ + prompt = dedent( + f"""\ Task: Write a transform function using the description below and then use `add_map` or `add_filter` from the `dlt` library to attach it to a resource. Description: {self.transform_description} Here are some docs pages that will help you better understand how to use `dlt`. - {dlt_transform_docs}""") + {dlt_transform_docs}""" + ) # edit the pipeline to add a transform function and attach it to a resource await sdk.edit_file( filename=filename, prompt=prompt, - name=f"Writing transform function {AI_ASSISTED_STRING}" + name=f"Writing transform function {AI_ASSISTED_STRING}", ) - await sdk.wait_for_user_confirmation("Press Continue to confirm that the changes are okay before we run the pipeline.") + await sdk.wait_for_user_confirmation( + "Press Continue to confirm that the changes are okay before we run the pipeline."
+ ) # run the pipeline and load the data - await sdk.run(f'python3 {filename}', name="Run the pipeline", description=f"Running `python3 {filename}` to load the data into a local DuckDB instance") + await sdk.run( + f"python3 {filename}", + name="Run the pipeline", + description=f"Running `python3 {filename}` to load the data into a local DuckDB instance", + ) # run a streamlit app to show the data - await sdk.run(f'dlt pipeline {source_name}_pipeline show', name="Show data in a Streamlit app", description=f"Running `dlt pipeline {source_name} show` to show the data in a Streamlit app, where you can view and play with the data.") + await sdk.run( + f"dlt pipeline {source_name}_pipeline show", + name="Show data in a Streamlit app", + description=f"Running `dlt pipeline {source_name} show` to show the data in a Streamlit app, where you can view and play with the data.", + ) diff --git a/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py index c0f9e7e3..e67ea557 100644 --- a/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py @@ -1,7 +1,8 @@ from textwrap import dedent -from ....plugins.steps.main import EditHighlightedCodeStep + from ....core.main import Step from ....core.sdk import ContinueSDK +from ....plugins.steps.main import EditHighlightedCodeStep class ContinueStepStep(Step): @@ -9,7 +10,10 @@ class ContinueStepStep(Step): prompt: str async def run(self, sdk: ContinueSDK): - await sdk.run_step(EditHighlightedCodeStep(user_input=dedent(f"""\ + await sdk.run_step( + EditHighlightedCodeStep( + user_input=dedent( + f"""\ Here is an example of a Step that runs a command and then edits a file. ```python @@ -33,4 +37,7 @@ class ContinueStepStep(Step): It should be a subclass of Step as above, implementing the `run` method, and using pydantic attributes to define the parameters. - """))) + """ + ) + ) + ) diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py index 84363e02..4b259769 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py @@ -1,10 +1,9 @@ from textwrap import dedent -from ....core.sdk import ContinueSDK from ....core.main import Step -from ....plugins.steps.core.core import WaitForUserInputStep -from ....plugins.steps.core.core import MessageStep -from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep +from ....core.sdk import ContinueSDK +from ....plugins.steps.core.core import MessageStep, WaitForUserInputStep +from .steps import RunQueryStep, SetupPipelineStep, ValidatePipelineStep class CreatePipelineRecipe(Step): @@ -12,7 +11,10 @@ class CreatePipelineRecipe(Step): async def run(self, sdk: ContinueSDK): text_observation = await sdk.run_step( - MessageStep(name="Building your first dlt pipeline", message=dedent("""\ + MessageStep( + name="Building your first dlt pipeline", + message=dedent( + """\ This recipe will walk you through the process of creating a dlt pipeline for your chosen data source. 
With the help of Continue, you will: - Create a Python virtual environment with dlt installed - Run `dlt init` to generate a pipeline template @@ -20,14 +22,19 @@ class CreatePipelineRecipe(Step): - Add any required API keys to the `secrets.toml` file - Test that the API call works - Load the data into a local DuckDB instance - - Write a query to view the data""")) >> - WaitForUserInputStep( - prompt="What API do you want to load data from? (e.g. weatherapi.com, chess.com)") + - Write a query to view the data""" + ), + ) + >> WaitForUserInputStep( + prompt="What API do you want to load data from? (e.g. weatherapi.com, chess.com)" + ) ) await sdk.run_step( - SetupPipelineStep(api_description=text_observation.text) >> - ValidatePipelineStep() >> - RunQueryStep() >> - MessageStep( - name="Congrats!", message="You've successfully created your first dlt pipeline! 🎉") + SetupPipelineStep(api_description=text_observation.text) + >> ValidatePipelineStep() + >> RunQueryStep() + >> MessageStep( + name="Congrats!", + message="You've successfully created your first dlt pipeline! 🎉", + ) ) diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 872f8d62..43a2b800 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -1,13 +1,13 @@ import os -from textwrap import dedent import time +from textwrap import dedent -from ....models.main import Range -from ....models.filesystem import RangeInFile -from ....plugins.steps.core.core import MessageStep -from ....models.filesystem_edit import AddFile, FileEdit from ....core.main import Step from ....core.sdk import ContinueSDK, Models +from ....models.filesystem import RangeInFile +from ....models.filesystem_edit import AddFile, FileEdit +from ....models.main import Range +from ....plugins.steps.core.core import MessageStep AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -19,50 +19,71 @@ class SetupPipelineStep(Step): api_description: str # e.g. 
"I want to load data from the weatherapi.com API" async def describe(self, models: Models): - return dedent(f"""\ + return dedent( + f"""\ This step will create a new dlt pipeline that loads data from an API, as per your request: {self.api_description} - """) + """ + ) async def run(self, sdk: ContinueSDK): sdk.context.set("api_description", self.api_description) - source_name = (await sdk.models.medium.complete( - f"Write a snake_case name for the data source described by {self.api_description}: ")).strip() - filename = f'{source_name}.py' + source_name = ( + await sdk.models.medium.complete( + f"Write a snake_case name for the data source described by {self.api_description}: " + ) + ).strip() + filename = f"{source_name}.py" # running commands to get started when creating a new dlt pipeline - await sdk.run([ - 'python3 -m venv .env', - 'source .env/bin/activate', - 'pip install dlt', - f'dlt --non-interactive init {source_name} duckdb', - 'pip install -r requirements.txt' - ], description=dedent(f"""\ + await sdk.run( + [ + "python3 -m venv .env", + "source .env/bin/activate", + "pip install dlt", + f"dlt --non-interactive init {source_name} duckdb", + "pip install -r requirements.txt", + ], + description=dedent( + f"""\ Running the following commands: - `python3 -m venv .env`: Create a Python virtual environment - `source .env/bin/activate`: Activate the virtual environment - `pip install dlt`: Install dlt - `dlt init {source_name} duckdb`: Create a new dlt pipeline called {source_name} that loads data into a local DuckDB instance - - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline"""), name="Setup Python environment") + - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""" + ), + name="Setup Python environment", + ) # editing the resource function to call the requested API resource_function_range = Range.from_shorthand(15, 0, 30, 0) - await sdk.ide.highlightCode(RangeInFile(filepath=os.path.join(await sdk.ide.getWorkspaceDirectory(), filename), range=resource_function_range), "#ffa50033") + await sdk.ide.highlightCode( + RangeInFile( + filepath=os.path.join(await sdk.ide.getWorkspaceDirectory(), filename), + range=resource_function_range, + ), + "#ffa50033", + ) # sdk.set_loading_message("Writing code to call the API...") await sdk.edit_file( range=resource_function_range, filename=filename, - prompt=f'Edit the resource function to call the API described by this: {self.api_description}. Do not move or remove the exit() call in __main__.', - name=f"Edit the resource function to call the API {AI_ASSISTED_STRING}" + prompt=f"Edit the resource function to call the API described by this: {self.api_description}. Do not move or remove the exit() call in __main__.", + name=f"Edit the resource function to call the API {AI_ASSISTED_STRING}", ) time.sleep(1) # wait for user to put API key in secrets.toml - await sdk.ide.setFileOpen(await sdk.ide.getWorkspaceDirectory() + "/.dlt/secrets.toml") - await sdk.wait_for_user_confirmation("If this service requires an API key, please add it to the `secrets.toml` file and then press `Continue`.") + await sdk.ide.setFileOpen( + await sdk.ide.getWorkspaceDirectory() + "/.dlt/secrets.toml" + ) + await sdk.wait_for_user_confirmation( + "If this service requires an API key, please add it to the `secrets.toml` file and then press `Continue`." 
+ ) sdk.context.set("source_name", source_name) @@ -73,7 +94,7 @@ class ValidatePipelineStep(Step): async def run(self, sdk: ContinueSDK): workspace_dir = await sdk.ide.getWorkspaceDirectory() source_name = sdk.context.get("source_name") - filename = f'{source_name}.py' + filename = f"{source_name}.py" # await sdk.run_step(MessageStep(name="Validate the pipeline", message=dedent("""\ # Next, we will validate that your dlt pipeline is working as expected: @@ -83,13 +104,20 @@ class ValidatePipelineStep(Step): # """))) # test that the API call works - output = await sdk.run(f'python3 {filename}', name="Test the pipeline", description=f"Running `python3 {filename}` to test loading data from the API", handle_error=False) + output = await sdk.run( + f"python3 {filename}", + name="Test the pipeline", + description=f"Running `python3 {filename}` to test loading data from the API", + handle_error=False, + ) # If it fails, return the error if "Traceback" in output or "SyntaxError" in output: output = "Traceback" + output.split("Traceback")[-1] file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) - suggestion = await sdk.models.medium.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete( + dedent( + f"""\ ```python {file_content} ``` @@ -99,59 +127,98 @@ class ValidatePipelineStep(Step): {output} ``` - This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""")) + This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""" + ) + ) - api_documentation_url = await sdk.models.medium.complete(dedent(f"""\ + api_documentation_url = await sdk.models.medium.complete( + dedent( + f"""\ The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this: ```python {file_content} ``` - What is the URL for the API documentation that will help me learn how to make this call? Please format in markdown so I can click the link.""")) + What is the URL for the API documentation that will help me learn how to make this call? 
Please format in markdown so I can click the link.""" + ) + ) sdk.raise_exception( - title=f"Error while running pipeline.\nFix the resource function in {filename} and rerun this step", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=dedent(f"""\ + title=f"Error while running pipeline.\nFix the resource function in {filename} and rerun this step", + message=output, + with_step=MessageStep( + name=f"Suggestion to solve error {AI_ASSISTED_STRING}", + message=dedent( + f"""\ {suggestion} {api_documentation_url} - After you've fixed the code, click the retry button at the top of the Validate Pipeline step above."""))) + After you've fixed the code, click the retry button at the top of the Validate Pipeline step above.""" + ), + ), + ) # remove exit() from the main main function - await sdk.run_step(MessageStep(name="Remove early exit() from main function", message="Remove the early exit() from the main function now that we are done testing and want the pipeline to load the data into DuckDB.")) + await sdk.run_step( + MessageStep( + name="Remove early exit() from main function", + message="Remove the early exit() from the main function now that we are done testing and want the pipeline to load the data into DuckDB.", + ) + ) contents = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) replacement = "\n".join( - list(filter(lambda line: line.strip() != "exit()", contents.split("\n")))) - await sdk.ide.applyFileSystemEdit(FileEdit( - filepath=os.path.join(workspace_dir, filename), - replacement=replacement, - range=Range.from_entire_file(contents) - )) + list(filter(lambda line: line.strip() != "exit()", contents.split("\n"))) + ) + await sdk.ide.applyFileSystemEdit( + FileEdit( + filepath=os.path.join(workspace_dir, filename), + replacement=replacement, + range=Range.from_entire_file(contents), + ) + ) # load the data into the DuckDB instance - await sdk.run(f'python3 {filename}', name="Load data into DuckDB", description=f"Running python3 {filename} to load data into DuckDB") + await sdk.run( + f"python3 {filename}", + name="Load data into DuckDB", + description=f"Running python3 {filename} to load data into DuckDB", + ) - tables_query_code = dedent(f'''\ + tables_query_code = dedent( + f"""\ import duckdb # connect to DuckDB instance conn = duckdb.connect(database="{source_name}.duckdb") # list all tables - print(conn.sql("DESCRIBE"))''') + print(conn.sql("DESCRIBE"))""" + ) query_filename = os.path.join(workspace_dir, "query.py") - await sdk.apply_filesystem_edit(AddFile(filepath=query_filename, content=tables_query_code), name="Add query.py file", description="Adding a file called `query.py` to the workspace that will run a test query on the DuckDB instance") + await sdk.apply_filesystem_edit( + AddFile(filepath=query_filename, content=tables_query_code), + name="Add query.py file", + description="Adding a file called `query.py` to the workspace that will run a test query on the DuckDB instance", + ) class RunQueryStep(Step): hide: bool = True async def run(self, sdk: ContinueSDK): - output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False) + output = await sdk.run( + ".env/bin/python3 query.py", + name="Run test query", + description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", + handle_error=False, + ) if "Traceback" in output or 
"SyntaxError" in output: - suggestion = await sdk.models.medium.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete( + dedent( + f"""\ ```python {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))} ``` @@ -161,8 +228,16 @@ class RunQueryStep(Step): {output} ``` - This is a brief summary of the error followed by a suggestion on how it can be fixed:""")) + This is a brief summary of the error followed by a suggestion on how it can be fixed:""" + ) + ) sdk.raise_exception( - title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=suggestion + "\n\nIt is also very likely that no duckdb table was created, which can happen if the resource function did not yield any data. Please make sure that it is yielding data and then rerun this step.") + title="Error while running query", + message=output, + with_step=MessageStep( + name=f"Suggestion to solve error {AI_ASSISTED_STRING}", + message=suggestion + + "\n\nIt is also very likely that no duckdb table was created, which can happen if the resource function did not yield any data. Please make sure that it is yielding data and then rerun this step.", + ), ) diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md index c4981e56..d50324f7 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md @@ -1,3 +1,3 @@ # DDtoBQRecipe -Move from using DuckDB to Google BigQuery as the destination for your `dlt` pipeline
\ No newline at end of file +Move from using DuckDB to Google BigQuery as the destination for your `dlt` pipeline diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py index 5b6aa8f0..5348321d 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py @@ -3,7 +3,7 @@ from textwrap import dedent from ....core.main import Step from ....core.sdk import ContinueSDK from ....plugins.steps.core.core import MessageStep -from .steps import SetUpChessPipelineStep, SwitchDestinationStep, LoadDataStep +from .steps import LoadDataStep, SetUpChessPipelineStep, SwitchDestinationStep # Based on the following guide: # https://github.com/dlt-hub/dlt/pull/392 @@ -14,13 +14,18 @@ class DDtoBQRecipe(Step): async def run(self, sdk: ContinueSDK): await sdk.run_step( - MessageStep(name="Move from using DuckDB to Google BigQuery as the destination", message=dedent("""\ + MessageStep( + name="Move from using DuckDB to Google BigQuery as the destination", + message=dedent( + """\ This recipe will walk you through the process of moving from using DuckDB to Google BigQuery as the destination for your dlt pipeline. With the help of Continue, you will: - Set up a dlt pipeline for the chess.com API - Switch destination from DuckDB to Google BigQuery - Add BigQuery credentials to your secrets.toml file - - Run the pipeline again to load data to BigQuery""")) >> - SetUpChessPipelineStep() >> - SwitchDestinationStep() >> - LoadDataStep() + - Run the pipeline again to load data to BigQuery""" + ), + ) + >> SetUpChessPipelineStep() + >> SwitchDestinationStep() + >> LoadDataStep() ) diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py index 14972142..d6769148 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py @@ -1,12 +1,11 @@ import os from textwrap import dedent -from ....plugins.steps.find_and_replace import FindAndReplaceStep -from ....plugins.steps.core.core import MessageStep -from ....core.sdk import Models from ....core.main import Step -from ....core.sdk import ContinueSDK +from ....core.sdk import ContinueSDK, Models from ....libs.util.paths import find_data_file +from ....plugins.steps.core.core import MessageStep +from ....plugins.steps.find_and_replace import FindAndReplaceStep AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -19,49 +18,61 @@ class SetUpChessPipelineStep(Step): return "This step will create a new dlt pipeline that loads data from the chess.com API." 
async def run(self, sdk: ContinueSDK): - # running commands to get started when creating a new dlt pipeline - await sdk.run([ - 'python3 -m venv .env', - 'source .env/bin/activate', - 'pip install dlt', - 'dlt --non-interactive init chess duckdb', - 'pip install -r requirements.txt', - ], name="Set up Python environment", description=dedent(f"""\ + await sdk.run( + [ + "python3 -m venv .env", + "source .env/bin/activate", + "pip install dlt", + "dlt --non-interactive init chess duckdb", + "pip install -r requirements.txt", + ], + name="Set up Python environment", + description=dedent( + """\ Running the following commands: - `python3 -m venv .env`: Create a Python virtual environment - `source .env/bin/activate`: Activate the virtual environment - `pip install dlt`: Install dlt - `dlt init chess duckdb`: Create a new dlt pipeline called "chess" that loads data into a local DuckDB instance - - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""")) + - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""" + ), + ) class SwitchDestinationStep(Step): hide: bool = True async def run(self, sdk: ContinueSDK): - # Switch destination from DuckDB to Google BigQuery - filepath = os.path.join( - sdk.ide.workspace_directory, 'chess_pipeline.py') - await sdk.run_step(FindAndReplaceStep(filepath=filepath, pattern="destination='duckdb'", replacement="destination='bigquery'")) + filepath = os.path.join(sdk.ide.workspace_directory, "chess_pipeline.py") + await sdk.run_step( + FindAndReplaceStep( + filepath=filepath, + pattern="destination='duckdb'", + replacement="destination='bigquery'", + ) + ) # Add BigQuery credentials to your secrets.toml file - template = dedent(f"""\ + template = dedent( + """\ [destination.bigquery.credentials] location = "US" # change the location of the data project_id = "project_id" # please set me up! private_key = "private_key" # please set me up! 
- client_email = "client_email" # please set me up!""") + client_email = "client_email" # please set me up!""" + ) # wait for user to put API key in secrets.toml - secrets_path = os.path.join( - sdk.ide.workspace_directory, ".dlt/secrets.toml") + secrets_path = os.path.join(sdk.ide.workspace_directory, ".dlt/secrets.toml") await sdk.ide.setFileOpen(secrets_path) await sdk.append_to_file(secrets_path, template) # append template to bottom of secrets.toml - await sdk.wait_for_user_confirmation("Please add your GCP credentials to `secrets.toml` file and then press `Continue`") + await sdk.wait_for_user_confirmation( + "Please add your GCP credentials to `secrets.toml` file and then press `Continue`" + ) class LoadDataStep(Step): @@ -70,14 +81,20 @@ class LoadDataStep(Step): async def run(self, sdk: ContinueSDK): # Run the pipeline again to load data to BigQuery - output = await sdk.run('.env/bin/python3 chess_pipeline.py', name="Load data to BigQuery", description="Running `.env/bin/python3 chess_pipeline.py` to load data to Google BigQuery") + output = await sdk.run( + ".env/bin/python3 chess_pipeline.py", + name="Load data to BigQuery", + description="Running `.env/bin/python3 chess_pipeline.py` to load data to Google BigQuery", + ) if "Traceback" in output or "SyntaxError" in output: with open(find_data_file("dlt_duckdb_to_bigquery_docs.md"), "r") as f: docs = f.read() output = "Traceback" + output.split("Traceback")[-1] - suggestion = await sdk.models.default.complete(dedent(f"""\ + suggestion = await sdk.models.default.complete( + dedent( + f"""\ When trying to load data into BigQuery, the following error occurred: ```ascii @@ -88,8 +105,15 @@ class LoadDataStep(Step): {docs} - This is a brief summary of the error followed by a suggestion on how it can be fixed:""")) + This is a brief summary of the error followed by a suggestion on how it can be fixed:""" + ) + ) sdk.raise_exception( - title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=suggestion) + title="Error while running query", + message=output, + with_step=MessageStep( + name=f"Suggestion to solve error {AI_ASSISTED_STRING}", + message=suggestion, + ), ) diff --git a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py index 54cba45f..8f16cb34 100644 --- a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py @@ -1,11 +1,10 @@ from textwrap import dedent -from ....plugins.steps.input.nl_multiselect import NLMultiselectStep from ....core.main import Step from ....core.sdk import ContinueSDK from ....plugins.steps.core.core import MessageStep -from .steps import SetupPipelineStep, DeployAirflowStep, RunPipelineStep - +from ....plugins.steps.input.nl_multiselect import NLMultiselectStep +from .steps import DeployAirflowStep, RunPipelineStep, SetupPipelineStep # https://github.com/dlt-hub/dlt-deploy-template/blob/master/airflow-composer/dag_template.py # https://www.notion.so/dlthub/Deploy-a-pipeline-with-Airflow-245fd1058652479494307ead0b5565f3 @@ -20,14 +19,20 @@ class DeployPipelineAirflowRecipe(Step): async def run(self, sdk: ContinueSDK): source_name = await sdk.run_step( - MessageStep(name="Deploying a pipeline to Airflow", message=dedent("""\ + MessageStep( + name="Deploying a pipeline to Airflow", + message=dedent( + """\ 
This recipe will show you how to deploy a pipeline to Airflow. With the help of Continue, you will: - Select a dlt-verified pipeline - Set up the pipeline - Deploy it to Airflow - - Optionally, setup Airflow locally""")) >> - NLMultiselectStep( - prompt=dedent("""\ + - Optionally, set up Airflow locally""" + ), + ) + >> NLMultiselectStep( + prompt=dedent( + """\ Which verified pipeline do you want to deploy with Airflow? The options are: - Asana - Chess.com @@ -48,14 +53,34 @@ class DeployPipelineAirflowRecipe(Step): - Stripe - SQL Database - Workable - - Zendesk"""), + - Zendesk""" + ), options=[ - "asana_dlt", "chess", "github", "google_analytics", "google_sheets", "hubspot", "matomo", "pipedrive", "shopify_dlt", "strapi", "zendesk", - "facebook_ads", "jira", "mux", "notion", "pokemon", "salesforce", "stripe_analytics", "sql_database", "workable" - ]) + "asana_dlt", + "chess", + "github", + "google_analytics", + "google_sheets", + "hubspot", + "matomo", + "pipedrive", + "shopify_dlt", + "strapi", + "zendesk", + "facebook_ads", + "jira", + "mux", + "notion", + "pokemon", + "salesforce", + "stripe_analytics", + "sql_database", + "workable", + ], + ) ) await sdk.run_step( - SetupPipelineStep(source_name=source_name) >> - RunPipelineStep(source_name=source_name) >> - DeployAirflowStep(source_name=source_name) + SetupPipelineStep(source_name=source_name) + >> RunPipelineStep(source_name=source_name) + >> DeployAirflowStep(source_name=source_name) ) diff --git a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py index 83067d52..d09cf8bb 100644 --- a/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py @@ -1,10 +1,9 @@ import os from textwrap import dedent -from ....plugins.steps.core.core import MessageStep -from ....core.sdk import Models from ....core.main import Step -from ....core.sdk import ContinueSDK +from ....core.sdk import ContinueSDK, Models +from ....plugins.steps.core.core import MessageStep from ....plugins.steps.find_and_replace import FindAndReplaceStep AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -20,19 +19,25 @@ class SetupPipelineStep(Step): pass async def run(self, sdk: ContinueSDK): - await sdk.run([ - 'python3 -m venv .env', - 'source .env/bin/activate', - 'pip install dlt', - f'dlt --non-interactive init {self.source_name} duckdb', - 'pip install -r requirements.txt' - ], description=dedent(f"""\ + await sdk.run( + [ + "python3 -m venv .env", + "source .env/bin/activate", + "pip install dlt", + f"dlt --non-interactive init {self.source_name} duckdb", + "pip install -r requirements.txt", + ], + description=dedent( + f"""\ Running the following commands: - `python3 -m venv .env`: Create a Python virtual environment - `source .env/bin/activate`: Activate the virtual environment - `pip install dlt`: Install dlt - `dlt init {self.source_name} duckdb`: Create a new dlt pipeline called {self.source_name} that loads data into a local DuckDB instance - - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline"""), name="Setup Python environment") + - `pip install -r requirements.txt`: Install the Python dependencies for the pipeline""" + ), + name="Setup Python environment", + ) class RunPipelineStep(Step): @@ -45,10 +50,16 @@ class RunPipelineStep(Step): pass async def run(self, sdk: ContinueSDK): - await sdk.run([ - f'python3 
{self.source_name}_pipeline.py', - ], description=dedent(f"""\ - Running the command `python3 {self.source_name}_pipeline.py to run the pipeline: """), name="Run dlt pipeline") + await sdk.run( + [ + f"python3 {self.source_name}_pipeline.py", + ], + description=dedent( + f"""\ + Running the command `python3 {self.source_name}_pipeline.py` to run the pipeline: """ + ), + name="Run dlt pipeline", + ) class DeployAirflowStep(Step): @@ -56,26 +67,47 @@ class DeployAirflowStep(Step): source_name: str async def run(self, sdk: ContinueSDK): - # Run dlt command to deploy pipeline to Airflow await sdk.run( - ['git init', - f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer'], - description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow") + [ + "git init", + f"dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer", + ], + description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", + name="Deploy dlt pipeline to Airflow", + ) # Get filepaths, open the DAG file directory = await sdk.ide.getWorkspaceDirectory() - pipeline_filepath = os.path.join( - directory, f"{self.source_name}_pipeline.py") + pipeline_filepath = os.path.join(directory, f"{self.source_name}_pipeline.py") dag_filepath = os.path.join( - directory, f"dags/dag_{self.source_name}_pipeline.py") + directory, f"dags/dag_{self.source_name}_pipeline.py" + ) await sdk.ide.setFileOpen(dag_filepath) # Replace the pipeline name and dataset name - await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="'pipeline_name'", replacement=f"'{self.source_name}_pipeline'")) - await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="'dataset_name'", replacement=f"'{self.source_name}_data'")) - await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="pipeline_or_source_script", replacement=f"{self.source_name}_pipeline")) + await sdk.run_step( + FindAndReplaceStep( + filepath=pipeline_filepath, + pattern="'pipeline_name'", + replacement=f"'{self.source_name}_pipeline'", + ) + ) + await sdk.run_step( + FindAndReplaceStep( + filepath=pipeline_filepath, + pattern="'dataset_name'", + replacement=f"'{self.source_name}_data'", + ) + ) + await sdk.run_step( + FindAndReplaceStep( + filepath=pipeline_filepath, + pattern="pipeline_or_source_script", + replacement=f"{self.source_name}_pipeline", + ) + ) # Prompt the user for the DAG schedule # edit_dag_range = Range.from_shorthand(18, 0, 23, 0) @@ -85,4 +117,9 @@ class DeployAirflowStep(Step): # range=edit_dag_range) # Tell the user to check the schedule and fill in owner, email, other default_args - await sdk.run_step(MessageStep(message="Fill in the owner, email, and other default_args in the DAG file with your own personal information. Then the DAG will be ready to run!", name="Fill in default_args")) + await sdk.run_step( + MessageStep( + message="Fill in the owner, email, and other default_args in the DAG file with your own personal information. 
Then the DAG will be ready to run!", + name="Fill in default_args", + ) + ) diff --git a/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py index 197abe85..2ca65b8e 100644 --- a/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py @@ -1,7 +1,7 @@ from typing import Coroutine -from ....core.main import Step, Observation -from ....core.sdk import ContinueSDK -from ....core.sdk import Models + +from ....core.main import Observation, Step +from ....core.sdk import ContinueSDK, Models class TemplateRecipe(Step): @@ -25,5 +25,5 @@ class TemplateRecipe(Step): visible_files = await sdk.ide.getVisibleFiles() await sdk.edit_file( filename=visible_files[0], - prompt=f"Append a statement to print `Hello, {self.name}!` at the end of the file." + prompt=f"Append a statement to print `Hello, {self.name}!` at the end of the file.", ) diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index c66cd629..e2712746 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -1,9 +1,10 @@ +import os from textwrap import dedent from typing import Union -from ....models.filesystem_edit import AddDirectory, AddFile + from ....core.main import Step from ....core.sdk import ContinueSDK -import os +from ....models.filesystem_edit import AddDirectory, AddFile class WritePytestsRecipe(Step): @@ -30,7 +31,8 @@ class WritePytestsRecipe(Step): for_file_contents = await sdk.ide.readFile(self.for_filepath) - prompt = dedent(f"""\ + prompt = dedent( + f"""\ This is the file you will write unit tests for: ```python @@ -41,7 +43,8 @@ class WritePytestsRecipe(Step): "{self.user_input}" - Here is a complete set of pytest unit tests:""") + Here is a complete set of pytest unit tests:""" + ) tests = await sdk.models.medium.complete(prompt) await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index 658cc7f3..25633942 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -1,9 +1,9 @@ from textwrap import dedent from typing import Coroutine, Union -from ...core.observation import Observation, TextObservation + from ...core.main import Step +from ...core.observation import Observation from ...core.sdk import ContinueSDK -from .core.core import EditFileStep from ...libs.chroma.query import ChromaIndexManager from .core.core import EditFileStep @@ -42,11 +42,12 @@ class AnswerQuestionChroma(Step): files = [] for node in results.source_nodes: resource_name = list(node.node.relationships.values())[0] - filepath = resource_name[:resource_name.index("::")] + filepath = resource_name[: resource_name.index("::")] files.append(filepath) code_snippets += f"""{filepath}```\n{node.node.text}\n```\n\n""" - prompt = dedent(f"""Here are a few snippets of code that might be useful in answering the question: + prompt = dedent( + f"""Here are a few snippets of code that might be useful in answering the question: {code_snippets} @@ -54,7 +55,8 @@ class AnswerQuestionChroma(Step): {self.question} - Here is the answer:""") + Here is the answer:""" + ) answer = await 
sdk.models.medium.complete(prompt) # Make paths relative to the workspace directory @@ -73,8 +75,12 @@ class EditFileChroma(Step): index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory()) results = index.query_codebase_index(self.request) - resource_name = list( - results.source_nodes[0].node.relationships.values())[0] - filepath = resource_name[:resource_name.index("::")] + resource_name = list(results.source_nodes[0].node.relationships.values())[0] + filepath = resource_name[: resource_name.index("::")] - await sdk.run_step(EditFileStep(filepath=filepath, prompt=f"Here is the code:\n\n{{code}}\n\nHere is the user request:\n\n{self.request}\n\nHere is the code after making the requested changes:\n")) + await sdk.run_step( + EditFileStep( + filepath=filepath, + prompt=f"Here is the code:\n\n{{code}}\n\nHere is the user request:\n\n{self.request}\n\nHere is the code after making the requested changes:\n", + ) + ) diff --git a/continuedev/src/continuedev/plugins/steps/comment_code.py b/continuedev/src/continuedev/plugins/steps/comment_code.py index 3e34ab52..1eee791d 100644 --- a/continuedev/src/continuedev/plugins/steps/comment_code.py +++ b/continuedev/src/continuedev/plugins/steps/comment_code.py @@ -9,4 +9,8 @@ class CommentCodeStep(Step): return "Writing comments" async def run(self, sdk: ContinueSDK): - await sdk.run_step(EditHighlightedCodeStep(user_input="Write comprehensive comments in the canonical format for every class and function")) + await sdk.run_step( + EditHighlightedCodeStep( + user_input="Write comprehensive comments in the canonical format for every class and function" + ) + ) diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 1ef201da..86569dcb 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -1,25 +1,36 @@ # These steps are depended upon by ContinueSDK -import os -import json import difflib -from textwrap import dedent +import os import traceback +from textwrap import dedent from typing import Any, Coroutine, List, Union -import difflib from pydantic import validator +from ....core.main import ChatMessage, ContinueCustomException, Step +from ....core.observation import ( + Observation, + TextObservation, + UserInputObservation, +) from ....libs.llm.ggml import GGML -# from ....libs.llm.replicate import ReplicateLLM -from ....models.main import Range from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI -from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit -from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation -from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS -from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes +from ....libs.util.strings import ( + dedent_and_get_common_whitespace, + remove_quotes_and_escapes, +) from ....libs.util.telemetry import posthog_logger +from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents +from ....models.filesystem_edit import ( + EditDiff, + FileEdit, + FileEditWithFullContents, + FileSystemEdit, +) + +# from ....libs.llm.replicate import ReplicateLLM +from ....models.main import Range class ContinueSDK: @@ -56,7 +67,7 @@ 
class DisplayErrorStep(Step): @validator("e", pre=True, always=True) def validate_e(cls, v): if isinstance(v, Exception): - return '\n'.join(traceback.format_exception(v)) + return "\n".join(traceback.format_exception(v)) async def describe(self, models: Models) -> Coroutine[str, None, None]: return self.e @@ -100,25 +111,41 @@ class ShellCommandsStep(Step): return f"Error when running shell commands:\n```\n{self._err_text}\n```" cmds_str = "\n".join(self.cmds) - return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") + return await models.medium.complete( + f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:" + ) async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd + cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd for cmd in self.cmds: output = await sdk.ide.runCommand(cmd) - if self.handle_error and output is not None and output_contains_error(output): - suggestion = await sdk.models.medium.complete(dedent(f"""\ + if ( + self.handle_error + and output is not None + and output_contains_error(output) + ): + suggestion = await sdk.models.medium.complete( + dedent( + f"""\ While running the command `{cmd}`, the following error occurred: ```ascii {output} ``` - This is a brief summary of the error followed by a suggestion on how it can be fixed:"""), with_history=await sdk.get_chat_context()) + This is a brief summary of the error followed by a suggestion on how it can be fixed:""" + ), + with_history=await sdk.get_chat_context(), + ) sdk.raise_exception( - title="Error while running query", message=output, with_step=MessageStep(name=f"Suggestion to solve error {AI_ASSISTED_STRING}", message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.") + title="Error while running query", + message=output, + with_step=MessageStep( + name=f"Suggestion to solve error {AI_ASSISTED_STRING}", + message=f"{suggestion}\n\nYou can click the retry button on the failed step to try again.", + ), ) return TextObservation(text=output) @@ -143,7 +170,8 @@ class DefaultModelEditCodeStep(Step): name: str = "Editing Code" hide = False description: str = "" - _prompt: str = dedent("""\ + _prompt: str = dedent( + """\ Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. The code you write in modified_code_to_edit will replace the code between the code_to_edit tags. Do NOT preface your answer or write anything other than code. The </modified_code_to_edit> tag should be written to indicate the end of the modified code section. Do not ever use nested tags. 
Example: @@ -177,7 +205,8 @@ class DefaultModelEditCodeStep(Step): </modified_code_to_edit> Main task: - """) + """ + ) _previous_contents: str = "" _new_contents: str = "" _prompt_and_completion: str = "" @@ -186,22 +215,34 @@ class DefaultModelEditCodeStep(Step): if self._previous_contents.strip() == self._new_contents.strip(): description = "No edits were made" else: - changes = '\n'.join(difflib.ndiff( - self._previous_contents.splitlines(), self._new_contents.splitlines())) - description = await models.medium.complete(dedent(f"""\ + changes = "\n".join( + difflib.ndiff( + self._previous_contents.splitlines(), + self._new_contents.splitlines(), + ) + ) + description = await models.medium.complete( + dedent( + f"""\ Diff summary: "{self.user_input}" ```diff {changes} ``` - Please give brief a description of the changes made above using markdown bullet points. Be concise:""")) - name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") + Please give a brief description of the changes made above using markdown bullet points. Be concise:""" + ) + ) + name = await models.medium.complete( + f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:" + ) self.name = remove_quotes_and_escapes(name) return f"{remove_quotes_and_escapes(description)}" - async def get_prompt_parts(self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str): + async def get_prompt_parts( + self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str + ): # We don't know here all of the functions being passed in. # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. @@ -209,18 +250,27 @@ class DefaultModelEditCodeStep(Step): max_tokens = int(model_to_use.context_length / 2) TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 - if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: + if ( + model_to_use.count_tokens(rif.contents) + > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE + ): self.description += "\n\n**It looks like you've selected a large range to edit, which may take a while to complete. If you'd like to cancel, click the 'X' button above. 
If you highlight a more specific range, Continue will only edit within it.**" # At this point, we also increase the max_tokens parameter so it doesn't stop in the middle of generation # Increase max_tokens to be double the size of the range # But don't exceed twice default max tokens - max_tokens = int(min(model_to_use.count_tokens( - rif.contents), DEFAULT_MAX_TOKENS) * 2.5) + max_tokens = int( + min(model_to_use.count_tokens(rif.contents), DEFAULT_MAX_TOKENS) * 2.5 + ) BUFFER_FOR_FUNCTIONS = 400 - total_tokens = model_to_use.count_tokens( - full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + max_tokens + total_tokens = ( + model_to_use.count_tokens( + full_file_contents + self._prompt + self.user_input + ) + + BUFFER_FOR_FUNCTIONS + + max_tokens + ) # If using 3.5 and overflows, upgrade to 3.5.16k if model_to_use.name == "gpt-3.5-turbo": @@ -239,7 +289,8 @@ class DefaultModelEditCodeStep(Step): if total_tokens > model_to_use.context_length: while cur_end_line > min_end_line: total_tokens -= model_to_use.count_tokens( - full_file_contents_lst[cur_end_line]) + full_file_contents_lst[cur_end_line] + ) cur_end_line -= 1 if total_tokens < model_to_use.context_length: break @@ -248,33 +299,30 @@ class DefaultModelEditCodeStep(Step): while cur_start_line < max_start_line: cur_start_line += 1 total_tokens -= model_to_use.count_tokens( - full_file_contents_lst[cur_start_line]) + full_file_contents_lst[cur_start_line] + ) if total_tokens < model_to_use.context_length: break # Now use the found start/end lines to get the prefix and suffix strings - file_prefix = "\n".join( - full_file_contents_lst[cur_start_line:max_start_line]) - file_suffix = "\n".join( - full_file_contents_lst[min_end_line:cur_end_line - 1]) + file_prefix = "\n".join(full_file_contents_lst[cur_start_line:max_start_line]) + file_suffix = "\n".join(full_file_contents_lst[min_end_line : cur_end_line - 1]) # Move any surrounding blank line in rif.contents to the prefix/suffix # TODO: Keep track of start line of the range, because it's needed below for offset stuff - rif_start_line = rif.range.start.line if len(rif.contents) > 0: lines = rif.contents.splitlines(keepends=True) first_line = lines[0] if lines else None while first_line and first_line.strip() == "": file_prefix += first_line - rif.contents = rif.contents[len(first_line):] + rif.contents = rif.contents[len(first_line) :] lines = rif.contents.splitlines(keepends=True) first_line = lines[0] if lines else None last_line = lines[-1] if lines else None while last_line and last_line.strip() == "": file_suffix = last_line + file_suffix - rif.contents = rif.contents[:len( - rif.contents) - len(last_line)] + rif.contents = rif.contents[: len(rif.contents) - len(last_line)] lines = rif.contents.splitlines(keepends=True) last_line = lines[-1] if lines else None @@ -287,10 +335,13 @@ class DefaultModelEditCodeStep(Step): return file_prefix, rif.contents, file_suffix, model_to_use, max_tokens - def compile_prompt(self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK) -> str: + def compile_prompt( + self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK + ) -> str: if contents.strip() == "": # Separate prompt for insertion at the cursor, the other tends to cause it to repeat whole file - prompt = dedent(f"""\ + prompt = dedent( + f"""\ <file_prefix> {file_prefix} </file_prefix> @@ -302,30 +353,39 @@ class DefaultModelEditCodeStep(Step): {self.user_input} </user_request> -Please output the code to be inserted at the cursor 
in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code. Make sure to correctly indent the code:""") +Please output the code to be inserted at the cursor in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code. Make sure to correctly indent the code:""" + ) return prompt prompt = self._prompt if file_prefix.strip() != "": - prompt += dedent(f""" + prompt += dedent( + f""" <file_prefix> {file_prefix} -</file_prefix>""") - prompt += dedent(f""" +</file_prefix>""" + ) + prompt += dedent( + f""" <code_to_edit> {contents} -</code_to_edit>""") +</code_to_edit>""" + ) if file_suffix.strip() != "": - prompt += dedent(f""" + prompt += dedent( + f""" <file_suffix> {file_suffix} -</file_suffix>""") - prompt += dedent(f""" +</file_suffix>""" + ) + prompt += dedent( + f""" <user_request> {self.user_input} </user_request> <modified_code_to_edit> -""") +""" + ) return prompt @@ -333,28 +393,44 @@ Please output the code to be inserted at the cursor in order to fulfill the user return "</modified_code_to_edit>" in line or "</code_to_edit>" in line def line_to_be_ignored(self, line: str, is_first_line: bool = False) -> bool: - return "```" in line or "<modified_code_to_edit>" in line or "<file_prefix>" in line or "</file_prefix>" in line or "<file_suffix>" in line or "</file_suffix>" in line or "<user_request>" in line or "</user_request>" in line or "<code_to_edit>" in line + return ( + "```" in line + or "<modified_code_to_edit>" in line + or "<file_prefix>" in line + or "</file_prefix>" in line + or "<file_suffix>" in line + or "</file_suffix>" in line + or "<user_request>" in line + or "</user_request>" in line + or "<code_to_edit>" in line + ) async def stream_rif(self, rif: RangeInFileWithContents, sdk: ContinueSDK): await sdk.ide.saveFile(rif.filepath) full_file_contents = await sdk.ide.readFile(rif.filepath) - file_prefix, contents, file_suffix, model_to_use, max_tokens = await self.get_prompt_parts( - rif, sdk, full_file_contents) - contents, common_whitespace = dedent_and_get_common_whitespace( - contents) + ( + file_prefix, + contents, + file_suffix, + model_to_use, + max_tokens, + ) = await self.get_prompt_parts(rif, sdk, full_file_contents) + contents, common_whitespace = dedent_and_get_common_whitespace(contents) prompt = self.compile_prompt(file_prefix, contents, file_suffix, sdk) full_file_contents_lines = full_file_contents.split("\n") lines_to_display = [] - async def sendDiffUpdate(lines: List[str], sdk: ContinueSDK, final: bool = False): + async def sendDiffUpdate( + lines: List[str], sdk: ContinueSDK, final: bool = False + ): nonlocal full_file_contents_lines, rif, lines_to_display completion = "\n".join(lines) - full_prefix_lines = full_file_contents_lines[:rif.range.start.line] - full_suffix_lines = full_file_contents_lines[rif.range.end.line:] + full_prefix_lines = full_file_contents_lines[: rif.range.start.line] + full_suffix_lines = full_file_contents_lines[rif.range.end.line :] # Don't do this at the very end, just show the inserted code if final: @@ -365,13 +441,29 @@ Please output the code to be inserted at the cursor in order to fulfill the user rewritten_lines = 0 for line in lines: for i in range(rewritten_lines, len(contents_lines)): - if difflib.SequenceMatcher(None, line, contents_lines[i]).ratio() > 0.7 and contents_lines[i].strip() != "": + if ( + difflib.SequenceMatcher( + None, line, 
contents_lines[i] + ).ratio() + > 0.7 + and contents_lines[i].strip() != "" + ): rewritten_lines = i + 1 break lines_to_display = contents_lines[rewritten_lines:] - new_file_contents = "\n".join( - full_prefix_lines) + "\n" + completion + "\n" + ("\n".join(lines_to_display) + "\n" if len(lines_to_display) > 0 else "") + "\n".join(full_suffix_lines) + new_file_contents = ( + "\n".join(full_prefix_lines) + + "\n" + + completion + + "\n" + + ( + "\n".join(lines_to_display) + "\n" + if len(lines_to_display) > 0 + else "" + ) + + "\n".join(full_suffix_lines) + ) step_index = sdk.history.current_index @@ -403,30 +495,61 @@ Please output the code to be inserted at the cursor in order to fulfill the user # Highlight the line to show progress line_to_highlight = current_line_in_file - len(current_block_lines) if False: - await sdk.ide.highlightCode(RangeInFile(filepath=rif.filepath, range=Range.from_shorthand( - line_to_highlight, 0, line_to_highlight, 0)), "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022") + await sdk.ide.highlightCode( + RangeInFile( + filepath=rif.filepath, + range=Range.from_shorthand( + line_to_highlight, 0, line_to_highlight, 0 + ), + ), + "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022", + ) if len(current_block_lines) == 0: # Set this as the start of the next block - current_block_start = rif.range.start.line + len(original_lines) - len( - original_lines_below_previous_blocks) + offset_from_blocks - if len(original_lines_below_previous_blocks) > 0 and line == original_lines_below_previous_blocks[0]: + current_block_start = ( + rif.range.start.line + + len(original_lines) + - len(original_lines_below_previous_blocks) + + offset_from_blocks + ) + if ( + len(original_lines_below_previous_blocks) > 0 + and line == original_lines_below_previous_blocks[0] + ): # Line is equal to the next line in file, move past this line - original_lines_below_previous_blocks = original_lines_below_previous_blocks[ - 1:] + original_lines_below_previous_blocks = ( + original_lines_below_previous_blocks[1:] + ) return # In a block, and have already matched at least one line # Check if the next line matches, for each of the candidates matches_found = [] first_valid_match = None - for index_of_last_matched_line, num_lines_matched in indices_of_last_matched_lines: - if index_of_last_matched_line + 1 < len(original_lines_below_previous_blocks) and line == original_lines_below_previous_blocks[index_of_last_matched_line + 1]: + for ( + index_of_last_matched_line, + num_lines_matched, + ) in indices_of_last_matched_lines: + if ( + index_of_last_matched_line + 1 + < len(original_lines_below_previous_blocks) + and line + == original_lines_below_previous_blocks[ + index_of_last_matched_line + 1 + ] + ): matches_found.append( - (index_of_last_matched_line + 1, num_lines_matched + 1)) - if first_valid_match is None and num_lines_matched + 1 >= LINES_TO_MATCH_BEFORE_ENDING_BLOCK: + (index_of_last_matched_line + 1, num_lines_matched + 1) + ) + if ( + first_valid_match is None + and num_lines_matched + 1 >= LINES_TO_MATCH_BEFORE_ENDING_BLOCK + ): first_valid_match = ( - index_of_last_matched_line + 1, num_lines_matched + 1) + index_of_last_matched_line + 1, + num_lines_matched + 1, + ) indices_of_last_matched_lines = matches_found if first_valid_match is not None: @@ -436,7 +559,13 @@ Please output the code to be inserted at the cursor in order to fulfill the user # So here we will strip all matching lines from the end of current_block_lines lines_stripped = [] index_of_last_line_in_block 
= first_valid_match[0] - while len(current_block_lines) > 0 and current_block_lines[-1] == original_lines_below_previous_blocks[index_of_last_line_in_block - 1]: + while ( + len(current_block_lines) > 0 + and current_block_lines[-1] + == original_lines_below_previous_blocks[ + index_of_last_line_in_block - 1 + ] + ): lines_stripped.append(current_block_lines.pop()) index_of_last_line_in_block -= 1 @@ -455,18 +584,22 @@ Please output the code to be inserted at the cursor in order to fulfill the user end_line = current_block_start + index_of_last_line_in_block if False: - await sdk.ide.showSuggestion(FileEdit( - filepath=rif.filepath, - range=Range.from_shorthand( - start_line, 0, end_line, 0), - replacement=replacement - )) + await sdk.ide.showSuggestion( + FileEdit( + filepath=rif.filepath, + range=Range.from_shorthand(start_line, 0, end_line, 0), + replacement=replacement, + ) + ) # Reset current block / update variables current_line_in_file += 1 offset_from_blocks += len(current_block_lines) - original_lines_below_previous_blocks = original_lines_below_previous_blocks[ - index_of_last_line_in_block + 1:] + original_lines_below_previous_blocks = ( + original_lines_below_previous_blocks[ + index_of_last_line_in_block + 1 : + ] + ) current_block_lines = [] current_block_start = -1 indices_of_last_matched_lines = [] @@ -485,7 +618,8 @@ Please output the code to be inserted at the cursor in order to fulfill the user # Make sure they are sorted by index indices_of_last_matched_lines = sorted( - indices_of_last_matched_lines, key=lambda x: x[0]) + indices_of_last_matched_lines, key=lambda x: x[0] + ) current_block_lines.append(line) @@ -498,11 +632,9 @@ Please output the code to be inserted at the cursor in order to fulfill the user messages.pop(i) deleted += 1 i -= 1 - messages.append(ChatMessage( - role="user", - content=prompt, - summary=self.user_input - )) + messages.append( + ChatMessage(role="user", content=prompt, summary=self.user_input) + ) lines_of_prefix_copied = 0 lines = [] @@ -512,18 +644,22 @@ Please output the code to be inserted at the cursor in order to fulfill the user line_below_highlighted_range = file_suffix.lstrip().split("\n")[0] if isinstance(model_to_use, GGML): - messages = [ChatMessage( - role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)] + messages = [ + ChatMessage( + role="user", + content=f'```\n{rif.contents}\n```\n\nUser request: "{self.user_input}"\n\nThis is the code after changing to perfectly comply with the user request. 
It does not include any placeholder code, only real implementations:\n\n```\n', + summary=self.user_input, + ) + ] # elif isinstance(model_to_use, ReplicateLLM): # messages = [ChatMessage( # role="user", content=f"// Previous implementation\n\n{rif.contents}\n\n// Updated implementation (after following directions: {self.user_input})\n\n", summary=self.user_input)] generator = model_to_use.stream_chat( - messages, temperature=sdk.config.temperature, max_tokens=max_tokens) + messages, temperature=sdk.config.temperature, max_tokens=max_tokens + ) - posthog_logger.capture_event("model_use", { - "model": model_to_use.name - }) + posthog_logger.capture_event("model_use", {"model": model_to_use.name}) try: async for chunk in generator: @@ -555,16 +691,31 @@ Please output the code to be inserted at the cursor in order to fulfill the user if self.is_end_line(chunk_lines[i]): break # Lines that should be ignored, like the <> tags - elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): + elif self.line_to_be_ignored( + chunk_lines[i], completion_lines_covered == 0 + ): continue # noice # Check if we are currently just copying the prefix - elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: + elif ( + (lines_of_prefix_copied > 0 or completion_lines_covered == 0) + and lines_of_prefix_copied < len(file_prefix.splitlines()) + and chunk_lines[i] + == full_file_contents_lines[lines_of_prefix_copied] + ): # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line lines_of_prefix_copied += 1 continue # also nice # Because really short lines might be expected to be repeated, this is only a !heuristic! 
# Stop when it starts copying the file_suffix - elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): + elif ( + chunk_lines[i].strip() == line_below_highlighted_range.strip() + and len(chunk_lines[i].strip()) > 4 + and not ( + len(original_lines_below_previous_blocks) > 0 + and chunk_lines[i].strip() + == original_lines_below_previous_blocks[0].strip() + ) + ): repeating_file_suffix = True break @@ -576,11 +727,25 @@ Please output the code to be inserted at the cursor in order to fulfill the user completion_lines_covered += 1 current_line_in_file += 1 - await sendDiffUpdate(lines + [common_whitespace if unfinished_line.startswith("<") else (common_whitespace + unfinished_line)], sdk) + await sendDiffUpdate( + lines + + [ + common_whitespace + if unfinished_line.startswith("<") + else (common_whitespace + unfinished_line) + ], + sdk, + ) finally: await generator.aclose() # Add the unfinished line - if unfinished_line != "" and not self.line_to_be_ignored(unfinished_line, completion_lines_covered == 0) and not self.is_end_line(unfinished_line): + if ( + unfinished_line != "" + and not self.line_to_be_ignored( + unfinished_line, completion_lines_covered == 0 + ) + and not self.is_end_line(unfinished_line) + ): unfinished_line = common_whitespace + unfinished_line lines.append(unfinished_line) await handle_generated_line(unfinished_line) @@ -598,13 +763,19 @@ Please output the code to be inserted at the cursor in order to fulfill the user for i in range(-1, -len(current_block_lines) - 1, -1): if len(original_lines_below_previous_blocks) == 0: break - if current_block_lines[i] == original_lines_below_previous_blocks[-1]: + if ( + current_block_lines[i] + == original_lines_below_previous_blocks[-1] + ): num_to_remove += 1 original_lines_below_previous_blocks.pop() else: break - current_block_lines = current_block_lines[:- - num_to_remove] if num_to_remove > 0 else current_block_lines + current_block_lines = ( + current_block_lines[:-num_to_remove] + if num_to_remove > 0 + else current_block_lines + ) # It's also possible that some lines match at the beginning of the block # while len(current_block_lines) > 0 and len(original_lines_below_previous_blocks) > 0 and current_block_lines[0] == original_lines_below_previous_blocks[0]: @@ -612,12 +783,19 @@ Please output the code to be inserted at the cursor in order to fulfill the user # original_lines_below_previous_blocks.pop(0) # current_block_start += 1 - await sdk.ide.showSuggestion(FileEdit( - filepath=rif.filepath, - range=Range.from_shorthand( - current_block_start, 0, current_block_start + len(original_lines_below_previous_blocks), 0), - replacement="\n".join(current_block_lines) - )) + await sdk.ide.showSuggestion( + FileEdit( + filepath=rif.filepath, + range=Range.from_shorthand( + current_block_start, + 0, + current_block_start + + len(original_lines_below_previous_blocks), + 0, + ), + replacement="\n".join(current_block_lines), + ) + ) # Record the completion completion = "\n".join(lines) @@ -629,14 +807,18 @@ Please output the code to be inserted at the cursor in order to fulfill the user await sdk.update_ui() rif_with_contents = [] - for range_in_file in map(lambda x: RangeInFile( - filepath=x.filepath, - # Only consider the range line-by-line. Maybe later don't if it's only a single line. 
- range=x.range.to_full_lines() - ), self.range_in_files): + for range_in_file in map( + lambda x: RangeInFile( + filepath=x.filepath, + # Only consider the range line-by-line. Maybe later don't if it's only a single line. + range=x.range.to_full_lines(), + ), + self.range_in_files, + ): file_contents = await sdk.ide.readRangeInFile(range_in_file) rif_with_contents.append( - RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) + RangeInFileWithContents.from_range_in_file(range_in_file, file_contents) + ) rif_dict = {} for rif in rif_with_contents: @@ -645,10 +827,10 @@ Please output the code to be inserted at the cursor in order to fulfill the user for rif in rif_with_contents: # If the file doesn't exist, ask them to save it first if not os.path.exists(rif.filepath): - message = f"The file {rif.filepath} does not exist. Please save it first." - raise ContinueCustomException( - title=message, message=message + message = ( + f"The file {rif.filepath} does not exist. Please save it first." ) + raise ContinueCustomException(title=message, message=message) await sdk.ide.setFileOpen(rif.filepath) await sdk.ide.setSuggestionsLocked(rif.filepath, True) @@ -666,11 +848,14 @@ class EditFileStep(Step): async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: file_contents = await sdk.ide.readFile(self.filepath) - await sdk.run_step(DefaultModelEditCodeStep( - range_in_files=[RangeInFile.from_entire_file( - self.filepath, file_contents)], - user_input=self.prompt - )) + await sdk.run_step( + DefaultModelEditCodeStep( + range_in_files=[ + RangeInFile.from_entire_file(self.filepath, file_contents) + ], + user_input=self.prompt, + ) + ) class ManualEditStep(ReversibleStep): @@ -698,8 +883,7 @@ class ManualEditStep(ReversibleStep): def from_sequence(cls, edits: List[FileEditWithFullContents]) -> "ManualEditStep": diffs = [] for edit in edits: - _, diff = FileSystem.apply_edit_to_str( - edit.fileContents, edit.fileEdit) + _, diff = FileSystem.apply_edit_to_str(edit.fileContents, edit.fileEdit) diffs.append(diff) return cls(edit_diff=EditDiff.from_sequence(diffs)) @@ -720,12 +904,12 @@ class UserInputStep(Step): async def describe(self, models: Models) -> Coroutine[str, None, None]: return self.user_input - async def run(self, sdk: ContinueSDK) -> Coroutine[UserInputObservation, None, None]: - self.chat_context.append(ChatMessage( - role="user", - content=self.user_input, - summary=self.user_input - )) + async def run( + self, sdk: ContinueSDK + ) -> Coroutine[UserInputObservation, None, None]: + self.chat_context.append( + ChatMessage(role="user", content=self.user_input, summary=self.user_input) + ) return UserInputObservation(user_input=self.user_input) diff --git a/continuedev/src/continuedev/plugins/steps/custom_command.py b/continuedev/src/continuedev/plugins/steps/custom_command.py index 419b3c3d..4128415b 100644 --- a/continuedev/src/continuedev/plugins/steps/custom_command.py +++ b/continuedev/src/continuedev/plugins/steps/custom_command.py @@ -1,6 +1,6 @@ -from ...libs.util.templating import render_templated_string from ...core.main import Step from ...core.sdk import ContinueSDK, Models +from ...libs.util.templating import render_templated_string from ..steps.chat import SimpleChatStep @@ -21,8 +21,9 @@ class CustomCommandStep(Step): messages = await sdk.get_chat_context() # Find the last chat message with this slash command and replace it with the user input for i in range(len(messages) - 1, -1, -1): - if messages[i].role == "user" and 
messages[i].content.startswith(self.slash_command): - messages[i] = messages[i].copy( - update={"content": prompt_user_input}) + if messages[i].role == "user" and messages[i].content.startswith( + self.slash_command + ): + messages[i] = messages[i].copy(update={"content": prompt_user_input}) break await sdk.run_step(SimpleChatStep(messages=messages)) diff --git a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py index f3131c4b..1d135b3e 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py +++ b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py @@ -1,5 +1,5 @@ -from ....core.sdk import ContinueSDK from ....core.main import Step +from ....core.sdk import ContinueSDK class ImplementAbstractMethodStep(Step): @@ -8,11 +8,9 @@ class ImplementAbstractMethodStep(Step): class_name: str async def run(self, sdk: ContinueSDK): - implementations = await sdk.lsp.go_to_implementations(self.class_name) for implementation in implementations: - await sdk.edit_file( range_in_files=[implementation.range_in_file], prompt=f"Implement method `{self.method_name}` for this subclass of `{self.class_name}`", diff --git a/continuedev/src/continuedev/plugins/steps/draft/redux.py b/continuedev/src/continuedev/plugins/steps/draft/redux.py index 30c8fdbb..5a351e6f 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/redux.py +++ b/continuedev/src/continuedev/plugins/steps/draft/redux.py @@ -4,7 +4,6 @@ from ..core.core import EditFileStep class EditReduxStateStep(Step): - description: str # e.g. "I want to load data from the weatherapi.com API" async def run(self, sdk: ContinueSDK): @@ -15,24 +14,28 @@ class EditReduxStateStep(Step): sdk.run_step( EditFileStep( filename=store_filename, - prompt=f"Edit the root store to add a new slice for {self.description}" + prompt=f"Edit the root store to add a new slice for {self.description}", ) ) store_file_contents = await sdk.ide.readFile(store_filename) # Selector selector_filename = "" - sdk.run_step(EditFileStep( - filepath=selector_filename, - prompt=f"Edit the selector to add a new property for {self.description}. The store looks like this: {store_file_contents}" - )) + sdk.run_step( + EditFileStep( + filepath=selector_filename, + prompt=f"Edit the selector to add a new property for {self.description}. The store looks like this: {store_file_contents}", + ) + ) # Reducer reducer_filename = "" - sdk.run_step(EditFileStep( - filepath=reducer_filename, - prompt=f"Edit the reducer to add a new property for {self.description}. The store looks like this: {store_file_contents}" - )) + sdk.run_step( + EditFileStep( + filepath=reducer_filename, + prompt=f"Edit the reducer to add a new property for {self.description}. The store looks like this: {store_file_contents}", + ) + ) """ Starts with implementing selector 1. 
RootStore diff --git a/continuedev/src/continuedev/plugins/steps/draft/typeorm.py b/continuedev/src/continuedev/plugins/steps/draft/typeorm.py index d06a6fb4..c79fa041 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/typeorm.py +++ b/continuedev/src/continuedev/plugins/steps/draft/typeorm.py @@ -1,4 +1,5 @@ from textwrap import dedent + from ....core.main import Step from ....core.sdk import ContinueSDK @@ -12,17 +13,23 @@ class CreateTableStep(Step): entity_name = self.sql_str.split(" ")[2].capitalize() await sdk.edit_file( f"src/entity/{entity_name}.ts", - dedent(f"""\ + dedent( + f"""\ {self.sql_str} - Write a TypeORM entity called {entity_name} for this table, importing as necessary:""") + Write a TypeORM entity called {entity_name} for this table, importing as necessary:""" + ), ) # Add entity to data-source.ts - await sdk.edit_file(filepath="src/data-source.ts", prompt=f"Add the {entity_name} entity:") + await sdk.edit_file( + filepath="src/data-source.ts", prompt=f"Add the {entity_name} entity:" + ) # Generate blank migration for the entity - out = await sdk.run(f"npx typeorm migration:create ./src/migration/Create{entity_name}Table") + out = await sdk.run( + f"npx typeorm migration:create ./src/migration/Create{entity_name}Table" + ) migration_filepath = out.text.split(" ")[1] # Wait for user input @@ -31,13 +38,17 @@ class CreateTableStep(Step): # Fill in the migration await sdk.edit_file( migration_filepath, - dedent(f"""\ + dedent( + f"""\ This is the table that was created: {self.sql_str} - Fill in the migration for the table:"""), + Fill in the migration for the table:""" + ), ) # Run the migration - await sdk.run("npx typeorm-ts-node-commonjs migration:run -d ./src/data-source.ts") + await sdk.run( + "npx typeorm-ts-node-commonjs migration:run -d ./src/data-source.ts" + ) diff --git a/continuedev/src/continuedev/plugins/steps/feedback.py b/continuedev/src/continuedev/plugins/steps/feedback.py index fa56a4d9..df1142a1 100644 --- a/continuedev/src/continuedev/plugins/steps/feedback.py +++ b/continuedev/src/continuedev/plugins/steps/feedback.py @@ -1,6 +1,4 @@ -from typing import Coroutine -from ...core.main import Models -from ...core.main import Step +from ...core.main import Models, Step from ...core.sdk import ContinueSDK from ...libs.util.telemetry import posthog_logger diff --git a/continuedev/src/continuedev/plugins/steps/find_and_replace.py b/continuedev/src/continuedev/plugins/steps/find_and_replace.py index a2c9c44e..287e286d 100644 --- a/continuedev/src/continuedev/plugins/steps/find_and_replace.py +++ b/continuedev/src/continuedev/plugins/steps/find_and_replace.py @@ -1,6 +1,6 @@ -from ...models.filesystem_edit import FileEdit, Range from ...core.main import Models, Step from ...core.sdk import ContinueSDK +from ...models.filesystem_edit import FileEdit, Range class FindAndReplaceStep(Step): @@ -17,12 +17,14 @@ class FindAndReplaceStep(Step): while self.pattern in file_content: start_index = file_content.index(self.pattern) end_index = start_index + len(self.pattern) - await sdk.ide.applyFileSystemEdit(FileEdit( - filepath=self.filepath, - range=Range.from_indices( - file_content, start_index, end_index - 1), - replacement=self.replacement - )) - file_content = file_content[:start_index] + \ - self.replacement + file_content[end_index:] + await sdk.ide.applyFileSystemEdit( + FileEdit( + filepath=self.filepath, + range=Range.from_indices(file_content, start_index, end_index - 1), + replacement=self.replacement, + ) + ) + file_content = ( + 
file_content[:start_index] + self.replacement + file_content[end_index:] + ) await sdk.ide.saveFile(self.filepath) diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 82f885d6..148dddb8 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -1,9 +1,11 @@ from textwrap import dedent + from ...core.main import ChatMessage, Step from ...core.sdk import ContinueSDK from ...libs.util.telemetry import posthog_logger -help = dedent("""\ +help = dedent( + """\ Continue is an open-source coding autopilot. It is a VS Code extension that brings the power of ChatGPT to your IDE. It gathers context for you and stores your interactions automatically, so that you can avoid copy/paste now and benefit from a customized Large Language Model (LLM) later. @@ -23,11 +25,11 @@ help = dedent("""\ If Continue is stuck loading, try using `cmd+shift+p` to open the command palette, search "Reload Window", and then select it. This will reload VS Code and Continue and often fixes issues. - If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""") + If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""" +) class HelpStep(Step): - name: str = "Help" user_input: str manage_own_chat_context: bool = True @@ -40,7 +42,8 @@ class HelpStep(Step): self.description = help else: self.description = "The following output is generated by a language model, which may hallucinate. Type just '/help' to see a fixed answer. You can also learn more by reading [the docs](https://continue.dev/docs).\n\n" - prompt = dedent(f""" + prompt = dedent( + f""" Information: {help} @@ -49,13 +52,12 @@ class HelpStep(Step): Please use the information below to provide a succinct answer to the following question: {question} - Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback. 
Never refer or link to any URL.""" + ) - self.chat_context.append(ChatMessage( - role="user", - content=prompt, - summary="Help" - )) + self.chat_context.append( + ChatMessage(role="user", content=prompt, summary="Help") + ) messages = await sdk.get_chat_context() generator = sdk.models.default.stream_chat(messages) async for chunk in generator: @@ -64,4 +66,5 @@ class HelpStep(Step): await sdk.update_ui() posthog_logger.capture_event( - "help", {"question": question, "answer": self.description}) + "help", {"question": question, "answer": self.description} + ) diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index 3d8d96fb..721f1306 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -1,7 +1,8 @@ from typing import List, Union -from ..core.core import WaitForUserInputStep + from ....core.main import Step from ....core.sdk import ContinueSDK +from ..core.core import WaitForUserInputStep class NLMultiselectStep(Step): @@ -11,7 +12,9 @@ class NLMultiselectStep(Step): options: List[str] async def run(self, sdk: ContinueSDK): - user_response = (await sdk.run_step(WaitForUserInputStep(prompt=self.prompt))).text + user_response = ( + await sdk.run_step(WaitForUserInputStep(prompt=self.prompt)) + ).text def extract_option(text: str) -> Union[str, None]: for option in self.options: @@ -24,5 +27,6 @@ class NLMultiselectStep(Step): return first_try gpt_parsed = await sdk.models.default.complete( - f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:") + f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. 
This is the exact string from the options array that they selected:" + ) return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index d2d6f4dd..da9cf5b2 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -1,19 +1,19 @@ import os -from typing import Coroutine, List, Union from textwrap import dedent +from typing import Coroutine, List, Union + from pydantic import BaseModel, Field -from ...models.main import Traceback, Range -from ...models.filesystem_edit import EditDiff, FileEdit -from ...models.filesystem import RangeInFile, RangeInFileWithContents -from ...core.observation import Observation -from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder from ...core.main import ContinueCustomException, Step -from ...core.sdk import ContinueSDK, Models from ...core.observation import Observation -from .core.core import DefaultModelEditCodeStep +from ...core.sdk import ContinueSDK, Models +from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder from ...libs.util.calculate_diff import calculate_diff2 from ...libs.util.logging import logger +from ...models.filesystem import RangeInFile, RangeInFileWithContents +from ...models.filesystem_edit import EditDiff, FileEdit +from ...models.main import Range, Traceback +from .core.core import DefaultModelEditCodeStep class Policy(BaseModel): @@ -36,7 +36,8 @@ class FasterEditHighlightedCodeStep(Step): hide = True _completion: str = "Edit Code" _edit_diffs: Union[List[EditDiff], None] = None - _prompt: str = dedent("""\ + _prompt: str = dedent( + """\ You will be given code to edit in order to perfectly satisfy the user request. 
All the changes you make must be described as replacements, which you should format in the following way: FILEPATH <FILE_TO_EDIT> @@ -75,7 +76,8 @@ class FasterEditHighlightedCodeStep(Step): This is the user request: "{user_input}" Here is the description of changes to make: -""") +""" + ) async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Editing highlighted code" @@ -89,13 +91,14 @@ class FasterEditHighlightedCodeStep(Step): for file in files: contents[file] = await sdk.ide.readFile(file) - range_in_files = [RangeInFileWithContents.from_entire_file( - filepath, content) for filepath, content in contents.items()] + range_in_files = [ + RangeInFileWithContents.from_entire_file(filepath, content) + for filepath, content in contents.items() + ] enc_dec = MarkdownStyleEncoderDecoder(range_in_files) code_string = enc_dec.encode() - prompt = self._prompt.format( - code=code_string, user_input=self.user_input) + prompt = self._prompt.format(code=code_string, user_input=self.user_input) rif_dict = {} for rif in range_in_files: @@ -145,7 +148,14 @@ class FasterEditHighlightedCodeStep(Step): replace_me = edit["replace_me"] replace_with = edit["replace_with"] file_edits.append( - FileEdit(filepath=filepath, range=Range.from_lines_snippet_in_file(content=rif_dict[filepath], snippet=replace_me), replacement=replace_with)) + FileEdit( + filepath=filepath, + range=Range.from_lines_snippet_in_file( + content=rif_dict[filepath], snippet=replace_me + ), + replacement=replace_with, + ) + ) # ------------------------------ self._edit_diffs = [] @@ -169,7 +179,9 @@ class StarCoderEditHighlightedCodeStep(Step): _prompt_and_completion: str = "" async def describe(self, models: Models) -> Coroutine[str, None, None]: - return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") + return await models.medium.complete( + f"{self._prompt_and_completion}\n\nPlease give a brief description of the changes made above using markdown bullet points:" + ) async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: range_in_files = await sdk.get_code_context(only_editing=True) @@ -181,8 +193,10 @@ class StarCoderEditHighlightedCodeStep(Step): for file in files: contents[file] = await sdk.ide.readFile(file) - range_in_files = [RangeInFileWithContents.from_entire_file( - filepath, content) for filepath, content in contents.items()] + range_in_files = [ + RangeInFileWithContents.from_entire_file(filepath, content) + for filepath, content in contents.items() + ] rif_dict = {} for rif in range_in_files: @@ -190,7 +204,8 @@ class StarCoderEditHighlightedCodeStep(Step): for rif in range_in_files: prompt = self._prompt.format( - code=rif.contents, user_request=self.user_input) + code=rif.contents, user_request=self.user_input + ) if found_highlighted_code: full_file_contents = await sdk.ide.readFile(rif.filepath) @@ -208,7 +223,8 @@ class StarCoderEditHighlightedCodeStep(Step): self._prompt_and_completion += prompt + completion edits = calculate_diff2( - rif.filepath, rif.contents, completion.removesuffix("\n")) + rif.filepath, rif.contents, completion.removesuffix("\n") + ) for edit in edits: await sdk.ide.applyFileSystemEdit(edit) @@ -220,7 +236,10 @@ class StarCoderEditHighlightedCodeStep(Step): class EditHighlightedCodeStep(Step): user_input: str = Field( - ..., title="User Input", description="The natural language request describing how to edit the code") + ..., + title="User Input", 
+ description="The natural language request describing how to edit the code", + ) hide = True description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user seems to be asking for a code change." @@ -235,28 +254,44 @@ class EditHighlightedCodeStep(Step): highlighted_code = await sdk.ide.getHighlightedCode() if highlighted_code is not None: for rif in highlighted_code: - if os.path.dirname(rif.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")): + if os.path.dirname(rif.filepath) == os.path.expanduser( + os.path.join("~", ".continue", "diffs") + ): raise ContinueCustomException( - message="Please accept or reject the change before making another edit in this file.", title="Accept/Reject First") + message="Please accept or reject the change before making another edit in this file.", + title="Accept/Reject First", + ) if rif.range.start == rif.range.end: range_in_files.append( - RangeInFileWithContents.from_range_in_file(rif, "")) + RangeInFileWithContents.from_range_in_file(rif, "") + ) # If still no highlighted code, raise error if len(range_in_files) == 0: raise ContinueCustomException( - message="Please highlight some code and try again.", title="No Code Selected") + message="Please highlight some code and try again.", + title="No Code Selected", + ) - range_in_files = list(map(lambda x: RangeInFile( - filepath=x.filepath, range=x.range - ), range_in_files)) + range_in_files = list( + map( + lambda x: RangeInFile(filepath=x.filepath, range=x.range), + range_in_files, + ) + ) for range_in_file in range_in_files: - if os.path.dirname(range_in_file.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")): + if os.path.dirname(range_in_file.filepath) == os.path.expanduser( + os.path.join("~", ".continue", "diffs") + ): self.description = "Please accept or reject the change before making another edit in this file." 
return - await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files)) + await sdk.run_step( + DefaultModelEditCodeStep( + user_input=self.user_input, range_in_files=range_in_files + ) + ) class UserInputStep(Step): @@ -270,7 +305,8 @@ class SolveTracebackStep(Step): return f"```\n{self.traceback.full_traceback}\n```" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - prompt = dedent("""I ran into this problem with my Python code: + prompt = dedent( + """I ran into this problem with my Python code: {traceback} @@ -279,15 +315,17 @@ class SolveTracebackStep(Step): {code} This is what the code should be in order to avoid the problem: - """).format(traceback=self.traceback.full_traceback, code="{code}") + """ + ).format(traceback=self.traceback.full_traceback, code="{code}") range_in_files = [] for frame in self.traceback.frames: content = await sdk.ide.readFile(frame.filepath) - range_in_files.append( - RangeInFile.from_entire_file(frame.filepath, content)) + range_in_files.append(RangeInFile.from_entire_file(frame.filepath, content)) - await sdk.run_step(DefaultModelEditCodeStep(range_in_files=range_in_files, user_input=prompt)) + await sdk.run_step( + DefaultModelEditCodeStep(range_in_files=range_in_files, user_input=prompt) + ) return None diff --git a/continuedev/src/continuedev/plugins/steps/open_config.py b/continuedev/src/continuedev/plugins/steps/open_config.py index 64ead547..d6283af2 100644 --- a/continuedev/src/continuedev/plugins/steps/open_config.py +++ b/continuedev/src/continuedev/plugins/steps/open_config.py @@ -1,15 +1,16 @@ from textwrap import dedent + from ...core.main import Step from ...core.sdk import ContinueSDK from ...libs.util.paths import getConfigFilePath -import os class OpenConfigStep(Step): name: str = "Open config" async def describe(self, models): - return dedent("""\ + return dedent( + """\ `\"config.py\"` is now open. You can add a custom slash command in the `\"custom_commands\"` section, like in this example: ```python config = ContinueConfig( @@ -23,7 +24,8 @@ class OpenConfigStep(Step): ``` `name` is the command you will type. `description` is the description displayed in the slash command menu. - `prompt` is the instruction given to the model. The overall prompt becomes "Task: {prompt}, Additional info: {user_input}". For example, if you entered "/test exactly 5 assertions", the overall prompt would become "Task: Write a comprehensive...and sophisticated, Additional info: exactly 5 assertions".""") + `prompt` is the instruction given to the model. The overall prompt becomes "Task: {prompt}, Additional info: {user_input}". 
For example, if you entered "/test exactly 5 assertions", the overall prompt would become "Task: Write a comprehensive...and sophisticated, Additional info: exactly 5 assertions".""" + ) async def run(self, sdk: ContinueSDK): await sdk.ide.setFileOpen(getConfigFilePath()) diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index da6acdbf..a2612731 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -1,5 +1,6 @@ from textwrap import dedent -from typing import List, Union, Tuple +from typing import List, Tuple, Union + from ...core.main import Step from ...core.sdk import ContinueSDK @@ -13,11 +14,11 @@ class NLDecisionStep(Step): name: str = "Deciding what to do next" async def run(self, sdk: ContinueSDK): - step_descriptions = "\n".join([ - f"- {step[0].name}: {step[1]}" - for step in self.steps - ]) - prompt = dedent(f"""\ + step_descriptions = "\n".join( + [f"- {step[0].name}: {step[1]}" for step in self.steps] + ) + prompt = dedent( + f"""\ The following steps are available, in the format "- [step name]: [step description]": {step_descriptions} @@ -25,7 +26,8 @@ class NLDecisionStep(Step): {self.user_input} - Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""") + Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""" + ) resp = (await sdk.models.medium.complete(prompt)).lower() diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 456dba84..04fb98b7 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -1,14 +1,14 @@ import asyncio +import os +import re from textwrap import dedent from typing import List, Union -from ...models.filesystem import RangeInFile -from ...models.main import Range from ...core.main import Step from ...core.sdk import ContinueSDK from ...libs.util.create_async_task import create_async_task -import os -import re +from ...models.filesystem import RangeInFile +from ...models.main import Range # Already have some code for this somewhere IGNORE_DIRS = ["env", "venv", ".venv"] @@ -29,8 +29,12 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]: file_content = f.read() results = re.finditer(pattern, file_content) range_in_files += [ - RangeInFile(filepath=os.path.join(root, file), range=Range.from_indices( - file_content, result.start(), result.end())) + RangeInFile( + filepath=os.path.join(root, file), + range=Range.from_indices( + file_content, result.start(), result.end() + ), + ) for result in results ] @@ -42,12 +46,16 @@ class WriteRegexPatternStep(Step): async def run(self, sdk: ContinueSDK): # Ask the user for a regex pattern - pattern = await sdk.models.medium.complete(dedent(f"""\ + pattern = await sdk.models.medium.complete( + dedent( + f"""\ This is the user request: {self.user_request} - Please write either a regex pattern or just a string that be used with python's re module to find all matches requested by the user. It will be used as `re.findall(<PATTERN_YOU_WILL_WRITE>, file_content)`. 
Your output should be only the regex or string, nothing else:""")) + Please write either a regex pattern or just a string that can be used with python's re module to find all matches requested by the user. It will be used as `re.findall(<PATTERN_YOU_WILL_WRITE>, file_content)`. Your output should be only the regex or string, nothing else:""" + ) + ) return pattern @@ -59,11 +67,18 @@ class EditAllMatchesStep(Step): async def run(self, sdk: ContinueSDK): # Search all files for a given string - range_in_files = find_all_matches_in_dir(self.pattern, self.directory or await sdk.ide.getWorkspaceDirectory()) + range_in_files = find_all_matches_in_dir( + self.pattern, self.directory or await sdk.ide.getWorkspaceDirectory() + ) - tasks = [create_async_task(sdk.edit_file( - range=range_in_file.range, - filename=range_in_file.filepath, - prompt=self.user_request - )) for range_in_file in range_in_files] + tasks = [ + create_async_task( + sdk.edit_file( + range=range_in_file.range, + filename=range_in_file.filepath, + prompt=self.user_request, + ) + ) + for range_in_file in range_in_files + ] await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/plugins/steps/share_session.py b/continuedev/src/continuedev/plugins/steps/share_session.py index de8659bd..1d68dc90 100644 --- a/continuedev/src/continuedev/plugins/steps/share_session.py +++ b/continuedev/src/continuedev/plugins/steps/share_session.py @@ -3,15 +3,13 @@ import os import time from typing import Optional - +from ...core.main import FullState, Step from ...core.sdk import ContinueSDK -from ...core.main import Step, FullState -from ...libs.util.paths import getSessionFilePath, getGlobalFolderPath +from ...libs.util.paths import getGlobalFolderPath, getSessionFilePath from ...server.session_manager import session_manager class ShareSessionStep(Step): - session_id: Optional[str] = None async def run(self, sdk: ContinueSDK): @@ -23,12 +21,14 @@ class ShareSessionStep(Step): # Load the session data and format as a markdown file session_filepath = getSessionFilePath(self.session_id) - with open(session_filepath, 'r') as f: + with open(session_filepath, "r") as f: session_state = FullState(**json.load(f)) import datetime + date_created = datetime.datetime.fromtimestamp( - float(session_state.session_info.date_created)).strftime('%Y-%m-%d %H:%M:%S') + float(session_state.session_info.date_created) + ).strftime("%Y-%m-%d %H:%M:%S") content = f"This is a session transcript from [Continue](https://continue.dev) on {date_created}.\n\n" for node in session_state.history.timeline[:-2]: @@ -40,9 +40,10 @@ class ShareSessionStep(Step): # Save to a markdown file save_filepath = os.path.join( - getGlobalFolderPath(), f"{session_state.session_info.title}.md") + getGlobalFolderPath(), f"{session_state.session_info.title}.md" + ) - with open(save_filepath, 'w') as f: + with open(save_filepath, "w") as f: f.write(content) # Open the file diff --git a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py index 489cada3..d0058ffc 100644 --- a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py +++ b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py @@ -1,5 +1,5 @@ from ...core.main import Step -from ...core.sdk import Models, ContinueSDK +from ...core.sdk import ContinueSDK, Models class StepsOnStartupStep(Step): diff --git a/continuedev/src/continuedev/plugins/steps/welcome.py b/continuedev/src/continuedev/plugins/steps/welcome.py index df3e9a8a..ef1acfc1 100644 
--- a/continuedev/src/continuedev/plugins/steps/welcome.py +++ b/continuedev/src/continuedev/plugins/steps/welcome.py @@ -1,9 +1,9 @@ -from textwrap import dedent import os +from textwrap import dedent -from ...models.filesystem_edit import AddFile from ...core.main import Step from ...core.sdk import ContinueSDK, Models +from ...models.filesystem_edit import AddFile class WelcomeStep(Step): @@ -21,13 +21,20 @@ class WelcomeStep(Step): if not os.path.exists(continue_dir): os.mkdir(continue_dir) - await sdk.ide.applyFileSystemEdit(AddFile(filepath=filepath, content=dedent("""\ + await sdk.ide.applyFileSystemEdit( + AddFile( + filepath=filepath, + content=dedent( + """\ \"\"\" Welcome to Continue! To learn how to use it, delete this comment and try to use Continue for the following: - "Write me a calculator class" - Ask for a new method (e.g. "exp", "mod", "sqrt") - Type /comment to write comments for the entire class - Ask about how the class works, how to write it in another language, etc. - \"\"\""""))) + \"\"\"""" + ), + ) + ) # await sdk.ide.setFileOpen(filepath=filepath) diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py index 990833be..d079475c 100644 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -1,7 +1,5 @@ -from typing import Any, Dict, List from abc import ABC, abstractmethod - -from ..core.context import ContextItem +from typing import Any class AbstractGUIProtocolServer(ABC): diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 72b410d4..f63fecf8 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -1,10 +1,10 @@ +from abc import ABC, abstractmethod from typing import Any, List, Union -from abc import ABC, abstractmethod, abstractproperty + from fastapi import WebSocket -from ..models.main import Traceback -from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff from ..models.filesystem import RangeInFile, RangeInFileWithContents +from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit class AbstractIdeProtocolServer(ABC): diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f0a3f094..00ded6f1 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,18 +1,16 @@ +import argparse import asyncio -import time -import psutil -import os -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware import atexit -import uvicorn -import argparse +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware -from .ide import router as ide_router -from .gui import router as gui_router -from .session_manager import session_manager, router as sessions_router from ..libs.util.logging import logger +from .gui import router as gui_router +from .ide import router as ide_router +from .session_manager import router as sessions_router +from .session_manager import session_manager app = FastAPI() @@ -39,8 +37,7 @@ def health(): try: # add cli arg for server port parser = argparse.ArgumentParser() - parser.add_argument("-p", "--port", help="server port", - type=int, default=65432) + parser.add_argument("-p", "--port", help="server port", type=int, default=65432) args = parser.parse_args() except Exception as e: logger.debug(f"Error 
parsing command line arguments: {e}") diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index f47c08ca..98b48685 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -5,8 +5,8 @@ import subprocess from meilisearch_python_async import Client -from ..libs.util.paths import getServerFolderPath from ..libs.util.logging import logger +from ..libs.util.paths import getServerFolderPath def ensure_meilisearch_installed() -> bool: @@ -43,7 +43,11 @@ def ensure_meilisearch_installed() -> bool: # Download MeiliSearch logger.debug("Downloading MeiliSearch...") subprocess.run( - f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath) + "curl -L https://install.meilisearch.com | sh", + shell=True, + check=True, + cwd=serverPath, + ) return False @@ -56,7 +60,7 @@ async def check_meilisearch_running() -> bool: """ try: - async with Client('http://localhost:7700') as client: + async with Client("http://localhost:7700") as client: try: resp = await client.health() if resp.status != "available": @@ -96,5 +100,11 @@ async def start_meilisearch(): # Check if MeiliSearch is running if not await check_meilisearch_running() or not was_already_installed: logger.debug("Starting MeiliSearch...") - subprocess.Popen(["./meilisearch", "--no-analytics"], cwd=serverPath, stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT, close_fds=True, start_new_session=True) + subprocess.Popen( + ["./meilisearch", "--no-analytics"], + cwd=serverPath, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + close_fds=True, + start_new_session=True, + ) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 88ac13f6..6f4e4a87 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,20 +1,23 @@ +import json import os import traceback -from fastapi import WebSocket, APIRouter from typing import Any, Coroutine, Dict, Optional, Union from uuid import uuid4 -import json +from fastapi import APIRouter, WebSocket from fastapi.websockets import WebSocketState -from ..plugins.steps.core.core import MessageStep -from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath, getSessionsListFilePath -from ..core.main import FullState, HistoryNode, SessionInfo from ..core.autopilot import Autopilot -from .ide_protocol import AbstractIdeProtocolServer +from ..core.main import FullState from ..libs.util.create_async_task import create_async_task from ..libs.util.errors import SessionNotFound from ..libs.util.logging import logger +from ..libs.util.paths import ( + getSessionFilePath, + getSessionsFolderPath, + getSessionsListFilePath, +) +from .ide_protocol import AbstractIdeProtocolServer router = APIRouter(prefix="/sessions", tags=["sessions"]) @@ -42,14 +45,21 @@ class SessionManager: # And only if the IDE is still alive sessions_folder = getSessionsFolderPath() session_files = os.listdir(sessions_folder) - if f"{session_id}.json" in session_files and session_id in self.registered_ides: + if ( + f"{session_id}.json" in session_files + and session_id in self.registered_ides + ): if self.registered_ides[session_id].session_id is not None: - return await self.new_session(self.registered_ides[session_id], session_id=session_id) + return await self.new_session( + self.registered_ides[session_id], 
session_id=session_id + ) raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None) -> Session: + async def new_session( + self, ide: AbstractIdeProtocolServer, session_id: Optional[str] = None + ) -> Session: logger.debug(f"New session: {session_id}") # Load the persisted state (not being used right now) @@ -68,9 +78,9 @@ class SessionManager: # Set up the autopilot to update the GUI async def on_update(state: FullState): - await session_manager.send_ws_data(session_id, "state_update", { - "state": state.dict() - }) + await session_manager.send_ws_data( + session_id, "state_update", {"state": state.dict()} + ) autopilot.on_update(on_update) @@ -81,7 +91,7 @@ class SessionManager: await ide.on_error(e) def on_error(e: Exception) -> Coroutine: - err_msg = '\n'.join(traceback.format_exception(e)) + err_msg = "\n".join(traceback.format_exception(e)) return ide.showMessage(f"Error in Continue server: {err_msg}") create_async_task(autopilot.run_policy(), on_error) @@ -90,9 +100,15 @@ class SessionManager: async def remove_session(self, session_id: str): logger.debug(f"Removing session: {session_id}") if session_id in self.sessions: - if session_id in self.registered_ides and self.registered_ides[session_id] is not None: + if ( + session_id in self.registered_ides + and self.registered_ides[session_id] is not None + ): ws_to_close = self.registered_ides[session_id].websocket - if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED: + if ( + ws_to_close is not None + and ws_to_close.client_state != WebSocketState.DISCONNECTED + ): await self.sessions[session_id].autopilot.ide.websocket.close() del self.sessions[session_id] @@ -117,7 +133,9 @@ class SessionManager: with open(getSessionsListFilePath(), "w") as f: json.dump(sessions_list, f) - async def load_session(self, old_session_id: str, new_session_id: Optional[str] = None) -> str: + async def load_session( + self, old_session_id: str, new_session_id: Optional[str] = None + ) -> str: """Load the session's FullState from a json file""" # First persist the current state @@ -142,10 +160,9 @@ class SessionManager: # logger.debug(f"Session {session_id} has no websocket") return - await self.sessions[session_id].ws.send_json({ - "messageType": message_type, - "data": data - }) + await self.sessions[session_id].ws.send_json( + {"messageType": message_type, "data": data} + ) session_manager = SessionManager()
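
For orientation, the pattern the session_manager.py and share_session.py hunks above keep reformatting is plain JSON-file persistence: one file per session id, written on save and re-hydrated into a FullState on load. Below is a minimal, self-contained sketch of that pattern; the folder path, the plain-dict state, and the function names are simplified stand-ins for illustration, not Continue's actual FullState / getSessionFilePath API.

import json
import os
from typing import Any, Dict

# Hypothetical location; the real code resolves this via getSessionsFolderPath().
SESSIONS_FOLDER = os.path.expanduser(os.path.join("~", ".continue", "sessions"))


def session_filepath(session_id: str) -> str:
    # One JSON file per session, keyed by its id (cf. getSessionFilePath).
    return os.path.join(SESSIONS_FOLDER, f"{session_id}.json")


def persist_session(session_id: str, state: Dict[str, Any]) -> None:
    # Ensure the sessions folder exists, then dump the full session state.
    os.makedirs(SESSIONS_FOLDER, exist_ok=True)
    with open(session_filepath(session_id), "w") as f:
        json.dump(state, f)


def load_session(session_id: str) -> Dict[str, Any]:
    # Unknown ids raise KeyError, mirroring SessionManager.get_session above.
    path = session_filepath(session_id)
    if not os.path.exists(path):
        raise KeyError("Session ID not recognized", session_id)
    with open(path, "r") as f:
        return json.load(f)

Under these assumptions, persist_session("abc123", {"history": {"timeline": []}}) writes ~/.continue/sessions/abc123.json, and a later load_session("abc123") returns the same dict, which the real server validates into a FullState model before resuming the session.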