diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-18 15:17:04 -0700 |
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-18 15:17:04 -0700 |
| commit | 70370bf0d033c2575c84ffe10c9e5c484bbad54f (patch) | |
| tree | be55005a3ef2867594a432ea12786098ae6a6c94 /continuedev/src/continuedev/core | |
| parent | ab7a90a0972188dcc7b8c28b1263c918776ca19d (diff) | |
| download | sncontinue-70370bf0d033c2575c84ffe10c9e5c484bbad54f.tar.gz sncontinue-70370bf0d033c2575c84ffe10c9e5c484bbad54f.tar.bz2 sncontinue-70370bf0d033c2575c84ffe10c9e5c484bbad54f.zip | |
style: :art: autoformat with black on all python files
Diffstat (limited to 'continuedev/src/continuedev/core')
| -rw-r--r-- | continuedev/src/continuedev/core/abstract_sdk.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/env.py | 17 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 78 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/models.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/observation.py | 1 |
5 files changed, 70 insertions, 40 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index e048f877..98730d38 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -1,11 +1,10 @@ -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from typing import Coroutine, List, Union -from .config import ContinueConfig from ..models.filesystem_edit import FileSystemEdit +from .config import ContinueConfig +from .main import ChatMessage, History, Step from .observation import Observation -from .main import ChatMessage, History, Step, ChatMessageRole - """ [[Generate]] diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py index 2692c348..60b86538 100644 --- a/continuedev/src/continuedev/core/env.py +++ b/continuedev/src/continuedev/core/env.py @@ -1,6 +1,7 @@ -from dotenv import load_dotenv import os +from dotenv import load_dotenv + def get_env_var(var_name: str): load_dotenv() @@ -8,21 +9,21 @@ def get_env_var(var_name: str): def make_sure_env_exists(): - if not os.path.exists('.env'): - with open('.env', 'w') as f: - f.write('') + if not os.path.exists(".env"): + with open(".env", "w") as f: + f.write("") def save_env_var(var_name: str, var_value: str): make_sure_env_exists() - with open('.env', 'r') as f: + with open(".env", "r") as f: lines = f.readlines() - with open('.env', 'w') as f: + with open(".env", "w") as f: values = {} for line in lines: - key, value = line.split('=') - value = value.replace('"', '') + key, value = line.split("=") + value = value.replace('"', "") values[key] = value values[var_name] = var_value diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index a33d777e..53440dae 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,10 +1,10 @@ import json from typing import Coroutine, Dict, List, Literal, Optional, Union -from pydantic.schema import schema +from pydantic import BaseModel, validator +from pydantic.schema import schema from ..models.main import ContinueBaseModel -from pydantic import BaseModel, validator from .observation import Observation ChatMessageRole = Literal["assistant", "user", "system", "function"] @@ -27,8 +27,7 @@ class ChatMessage(ContinueBaseModel): d = self.dict() del d["summary"] if d["function_call"] is not None: - d["function_call"]["name"] = d["function_call"]["name"].replace( - " ", "") + d["function_call"]["name"] = d["function_call"]["name"].replace(" ", "") if d["content"] is None: d["content"] = "" @@ -49,9 +48,9 @@ class ChatMessage(ContinueBaseModel): def resolve_refs(schema_data): def traverse(obj): if isinstance(obj, dict): - if '$ref' in obj: - ref = obj['$ref'] - parts = ref.split('/') + if "$ref" in obj: + ref = obj["$ref"] + parts = ref.split("/") ref_obj = schema_data for part in parts[1:]: ref_obj = ref_obj[part] @@ -67,8 +66,14 @@ def resolve_refs(schema_data): return traverse(schema_data) -unincluded_parameters = ["system_message", "chat_context", - "manage_own_chat_context", "hide", "name", "description"] +unincluded_parameters = [ + "system_message", + "chat_context", + "manage_own_chat_context", + "hide", + "name", + "description", +] def step_to_json_schema(step) -> str: @@ -82,7 +87,7 @@ def step_to_json_schema(step) -> str: return { "name": step.name.replace(" ", ""), "description": step.description or "", - "parameters": parameters + "parameters": parameters, } @@ -96,6 +101,7 @@ def step_to_fn_call_arguments(step: "Step") -> str: class HistoryNode(ContinueBaseModel): """A point in history, a list of which make up History""" + step: "Step" observation: Union[Observation, None] depth: int @@ -111,12 +117,14 @@ class HistoryNode(ContinueBaseModel): role="assistant", name=self.step.__class__.__name__, content=self.step.description or f"Ran function {self.step.name}", - summary=f"Called function {self.step.name}" - )] + summary=f"Called function {self.step.name}", + ) + ] class History(ContinueBaseModel): """A history of steps taken and their results""" + timeline: List[HistoryNode] current_index: int @@ -128,7 +136,7 @@ class History(ContinueBaseModel): return msgs def add_node(self, node: HistoryNode) -> int: - """ Add node and return the index where it was added """ + """Add node and return the index where it was added""" self.timeline.insert(self.current_index + 1, node) self.current_index += 1 return self.current_index @@ -138,10 +146,15 @@ class History(ContinueBaseModel): return None return self.timeline[self.current_index] - def get_last_at_depth(self, depth: int, include_current: bool = False) -> Union[HistoryNode, None]: + def get_last_at_depth( + self, depth: int, include_current: bool = False + ) -> Union[HistoryNode, None]: i = self.current_index if include_current else self.current_index - 1 while i >= 0: - if self.timeline[i].depth == depth and type(self.timeline[i].step).__name__ != "ManualEditStep": + if ( + self.timeline[i].depth == depth + and type(self.timeline[i].step).__name__ != "ManualEditStep" + ): return self.timeline[i] i -= 1 return None @@ -204,24 +217,27 @@ class ContextItemId(BaseModel): """ A ContextItemId is a unique identifier for a ContextItem. """ + provider_title: str item_id: str - @validator('provider_title', 'item_id') + @validator("provider_title", "item_id") def must_be_valid_id(cls, v): import re - if not re.match(r'^[0-9a-zA-Z_-]*$', v): + + if not re.match(r"^[0-9a-zA-Z_-]*$", v): raise ValueError( - "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _") + "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _" + ) return v def to_string(self) -> str: return f"{self.provider_title}-{self.item_id}" @staticmethod - def from_string(string: str) -> 'ContextItemId': - provider_title, *rest = string.split('-') - item_id = '-'.join(rest) + def from_string(string: str) -> "ContextItemId": + provider_title, *rest = string.split("-") + item_id = "-".join(rest) return ContextItemId(provider_title=provider_title, item_id=item_id) @@ -231,22 +247,24 @@ class ContextItemDescription(BaseModel): The id can be used to retrieve the ContextItem from the ContextManager. """ + name: str description: str - id: ContextItemId + id: ContextItemId class ContextItem(BaseModel): """ A ContextItem is a single item that is stored in the ContextManager. """ + description: ContextItemDescription content: str - @validator('content', pre=True) + @validator("content", pre=True) def content_must_be_string(cls, v): if v is None: - return '' + return "" return v editing: bool = False @@ -261,6 +279,7 @@ class SessionInfo(ContinueBaseModel): class FullState(ContinueBaseModel): """A full state of the program, including the history""" + history: History active: bool user_input_queue: List[str] @@ -286,7 +305,9 @@ class Policy(ContinueBaseModel): """A rule that determines which step to take next""" # Note that history is mutable, kinda sus - def next(self, config: ContinueConfig, history: History = History.from_empty()) -> "Step": + def next( + self, config: ContinueConfig, history: History = History.from_empty() + ) -> "Step": raise NotImplementedError @@ -373,7 +394,12 @@ class ContinueCustomException(Exception): message: str with_step: Union[Step, None] - def __init__(self, message: str, title: str = "Error while running step:", with_step: Union[Step, None] = None): + def __init__( + self, + message: str, + title: str = "Error while running step:", + with_step: Union[Step, None] = None, + ): self.message = message self.title = title self.with_step = with_step diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index 900762b6..52a52b1d 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -1,10 +1,13 @@ -from typing import Optional, Any -from pydantic import BaseModel, validator +from typing import Any, Optional + +from pydantic import BaseModel + from ..libs.llm import LLM class Models(BaseModel): """Main class that holds the current model configuration""" + default: LLM small: Optional[LLM] = None medium: Optional[LLM] = None diff --git a/continuedev/src/continuedev/core/observation.py b/continuedev/src/continuedev/core/observation.py index 126cf19e..8a5e454e 100644 --- a/continuedev/src/continuedev/core/observation.py +++ b/continuedev/src/continuedev/core/observation.py @@ -1,4 +1,5 @@ from pydantic import BaseModel, validator + from ..models.main import Traceback |
