summaryrefslogtreecommitdiff
path: root/server/continuedev/core/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'server/continuedev/core/main.py')
-rw-r--r--server/continuedev/core/main.py437
1 files changed, 437 insertions, 0 deletions
diff --git a/server/continuedev/core/main.py b/server/continuedev/core/main.py
new file mode 100644
index 00000000..617a5aaa
--- /dev/null
+++ b/server/continuedev/core/main.py
@@ -0,0 +1,437 @@
+import json
+from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel, validator
+from pydantic.schema import schema
+
+from ..models.main import ContinueBaseModel
+from .observation import Observation
+
+ChatMessageRole = Literal["assistant", "user", "system", "function"]
+
+
+class FunctionCall(ContinueBaseModel):
+ name: str
+ arguments: str
+
+
+class ChatMessage(ContinueBaseModel):
+ role: ChatMessageRole
+ content: Union[str, None] = None
+ name: Union[str, None] = None
+ # A summary for pruning chat context to fit context window. Often the Step name.
+ summary: str
+ function_call: Union[FunctionCall, None] = None
+
+ def to_dict(self, with_functions: bool) -> Dict:
+ d = self.dict()
+ del d["summary"]
+ if d["function_call"] is not None:
+ d["function_call"]["name"] = d["function_call"]["name"].replace(" ", "")
+
+ if d["content"] is None:
+ d["content"] = ""
+ for key, value in list(d.items()):
+ if value is None:
+ del d[key]
+
+ if not with_functions:
+ if d["role"] == "function":
+ d["role"] = "assistant"
+ if "name" in d:
+ del d["name"]
+ if "function_call" in d:
+ del d["function_call"]
+ return d
+
+
+def resolve_refs(schema_data):
+ def traverse(obj):
+ if isinstance(obj, dict):
+ if "$ref" in obj:
+ ref = obj["$ref"]
+ parts = ref.split("/")
+ ref_obj = schema_data
+ for part in parts[1:]:
+ ref_obj = ref_obj[part]
+ return traverse(ref_obj)
+ else:
+ for key, value in obj.items():
+ obj[key] = traverse(value)
+ elif isinstance(obj, list):
+ for i in range(len(obj)):
+ obj[i] = traverse(obj[i])
+ return obj
+
+ return traverse(schema_data)
+
+
+unincluded_parameters = [
+ "system_message",
+ "chat_context",
+ "manage_own_chat_context",
+ "hide",
+ "name",
+ "description",
+]
+
+
+def step_to_json_schema(step) -> str:
+ pydantic_class = step.__class__
+ schema_data = schema([pydantic_class])
+ resolved_schema = resolve_refs(schema_data)
+ parameters = resolved_schema["definitions"][pydantic_class.__name__]
+ for parameter in unincluded_parameters:
+ if parameter in parameters["properties"]:
+ del parameters["properties"][parameter]
+ return {
+ "name": step.name.replace(" ", ""),
+ "description": step.description or "",
+ "parameters": parameters,
+ }
+
+
+def step_to_fn_call_arguments(step: "Step") -> str:
+ args = step.dict()
+ for parameter in unincluded_parameters:
+ if parameter in args:
+ del args[parameter]
+ return json.dumps(args)
+
+
+class HistoryNode(ContinueBaseModel):
+ """A point in history, a list of which make up History"""
+
+ step: "Step"
+ observation: Union[Observation, None]
+ depth: int
+ deleted: bool = False
+ active: bool = True
+ logs: List[str] = []
+ context_used: List["ContextItem"] = []
+
+ def to_chat_messages(self) -> List[ChatMessage]:
+ if self.step.description is None or self.step.manage_own_chat_context:
+ return self.step.chat_context
+ return self.step.chat_context + [
+ ChatMessage(
+ role="assistant",
+ name=self.step.__class__.__name__,
+ content=self.step.description or f"Ran function {self.step.name}",
+ summary=f"Called function {self.step.name}",
+ )
+ ]
+
+
+class History(ContinueBaseModel):
+ """A history of steps taken and their results"""
+
+ timeline: List[HistoryNode]
+ current_index: int
+
+ def to_chat_history(self) -> List[ChatMessage]:
+ msgs = []
+ for node in self.timeline:
+ if not node.step.hide:
+ msgs += node.to_chat_messages()
+ return msgs
+
+ def add_node(self, node: HistoryNode) -> int:
+ """Add node and return the index where it was added"""
+ self.timeline.insert(self.current_index + 1, node)
+ self.current_index += 1
+ return self.current_index
+
+ def get_current(self) -> Union[HistoryNode, None]:
+ if self.current_index < 0:
+ return None
+ return self.timeline[self.current_index]
+
+ def get_last_at_depth(
+ self, depth: int, include_current: bool = False
+ ) -> Union[HistoryNode, None]:
+ i = self.current_index if include_current else self.current_index - 1
+ while i >= 0:
+ if (
+ self.timeline[i].depth == depth
+ and type(self.timeline[i].step).__name__ != "ManualEditStep"
+ ):
+ return self.timeline[i]
+ i -= 1
+ return None
+
+ def get_last_at_same_depth(self) -> Union[HistoryNode, None]:
+ return self.get_last_at_depth(self.get_current().depth)
+
+ def remove_current_and_substeps(self):
+ self.timeline.pop(self.current_index)
+ while self.get_current() is not None and self.get_current().depth > 0:
+ self.timeline.pop(self.current_index)
+
+ def take_next_step(self) -> Union["Step", None]:
+ if self.has_future():
+ self.current_index += 1
+ current_state = self.get_current()
+ if current_state is None:
+ return None
+ return current_state.step
+ return None
+
+ def get_current_index(self) -> int:
+ return self.current_index
+
+ def has_future(self) -> bool:
+ return self.current_index < len(self.timeline) - 1
+
+ def step_back(self):
+ self.current_index -= 1
+
+ def last_observation(self) -> Union[Observation, None]:
+ state = self.get_last_at_same_depth()
+ if state is None:
+ return None
+ return state.observation
+
+ def pop_step(self, index: int = None) -> Union[HistoryNode, None]:
+ index = index if index is not None else self.current_index
+ if index < 0 or self.current_index < 0:
+ return None
+
+ node = self.timeline.pop(index)
+
+ if index <= self.current_index:
+ self.current_index -= 1
+
+ return node.step
+
+ @classmethod
+ def from_empty(cls):
+ return cls(timeline=[], current_index=-1)
+
+
+class SlashCommandDescription(ContinueBaseModel):
+ name: str
+ description: str
+
+
+class ContextItemId(BaseModel):
+ """
+ A ContextItemId is a unique identifier for a ContextItem.
+ """
+
+ provider_title: str
+ item_id: str
+
+ @validator("provider_title", "item_id")
+ def must_be_valid_id(cls, v):
+ import re
+
+ if not re.match(r"^[0-9a-zA-Z_-]*$", v):
+ raise ValueError(
+ "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _"
+ )
+ return v
+
+ def to_string(self) -> str:
+ return f"{self.provider_title}-{self.item_id}"
+
+ @staticmethod
+ def from_string(string: str) -> "ContextItemId":
+ provider_title, *rest = string.split("-")
+ item_id = "-".join(rest)
+ return ContextItemId(provider_title=provider_title, item_id=item_id)
+
+
+class ContextItemDescription(BaseModel):
+ """
+ A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'.
+
+ The id can be used to retrieve the ContextItem from the ContextManager.
+ """
+
+ name: str
+ description: str
+ id: ContextItemId
+
+
+class ContextItem(BaseModel):
+ """
+ A ContextItem is a single item that is stored in the ContextManager.
+ """
+
+ description: ContextItemDescription
+ content: str
+
+ @validator("content", pre=True)
+ def content_must_be_string(cls, v):
+ if v is None:
+ return ""
+ return v
+
+ editing: bool = False
+ editable: bool = False
+
+
+class SessionInfo(ContinueBaseModel):
+ session_id: str
+ title: str
+ date_created: str
+ workspace_directory: Optional[str] = None
+
+
+class ContinueConfig(ContinueBaseModel):
+ system_message: Optional[str]
+ temperature: Optional[float]
+
+ class Config:
+ extra = "allow"
+
+ def dict(self, **kwargs):
+ original_dict = super().dict(**kwargs)
+ original_dict.pop("policy", None)
+ return original_dict
+
+
+class ContextProviderDescription(BaseModel):
+ title: str
+ display_title: str
+ description: str
+ dynamic: bool
+ requires_query: bool
+
+
+class FullState(ContinueBaseModel):
+ """A full state of the program, including the history"""
+
+ history: History
+ active: bool
+ user_input_queue: List[str]
+ slash_commands: List[SlashCommandDescription]
+ adding_highlighted_code: bool
+ selected_context_items: List[ContextItem]
+ session_info: Optional[SessionInfo] = None
+ config: ContinueConfig
+ saved_context_groups: Dict[str, List[ContextItem]] = {}
+ context_providers: List[ContextProviderDescription] = []
+ meilisearch_url: Optional[str] = None
+
+
+class ContinueSDK:
+ ...
+
+
+class Models:
+ ...
+
+
+class Policy(ContinueBaseModel):
+ """A rule that determines which step to take next"""
+
+ # Note that history is mutable, kinda sus
+ def next(
+ self, config: ContinueConfig, history: History = History.from_empty()
+ ) -> "Step":
+ raise NotImplementedError
+
+
+class Step(ContinueBaseModel):
+ name: str = None
+ hide: bool = False
+ description: Union[str, None] = None
+
+ class_name: str = "Step"
+
+ @validator("class_name", pre=True, always=True)
+ def class_name_is_class_name(cls, class_name):
+ return cls.__name__
+
+ system_message: Union[str, None] = None
+ chat_context: List[ChatMessage] = []
+ manage_own_chat_context: bool = False
+
+ class Config:
+ copy_on_model_validation = False
+
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
+ if self.description is not None:
+ return self.description
+ return "Running step: " + self.name
+
+ def dict(self, *args, **kwargs):
+ d = super().dict(*args, **kwargs)
+ # Make sure description is always a string
+ d["description"] = self.description or ""
+ return d
+
+ @validator("name", pre=True, always=True)
+ def name_is_class_name(cls, name):
+ if name is None:
+ return cls.__name__
+ return name
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ raise NotImplementedError
+
+ async def __call__(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ return await self.run(sdk)
+
+ def __rshift__(self, other: "Step"):
+ steps = []
+ if isinstance(self, SequentialStep):
+ steps = self.steps
+ else:
+ steps.append(self)
+ if isinstance(other, SequentialStep):
+ steps += other.steps
+ else:
+ steps.append(other)
+ return SequentialStep(steps=steps)
+
+
+class SequentialStep(Step):
+ steps: List[Step]
+ hide: bool = True
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ for step in self.steps:
+ observation = await sdk.run_step(step)
+ return observation
+
+
+class ValidatorObservation(Observation):
+ passed: bool
+ observation: Observation
+
+
+class Validator(Step):
+ def run(self, sdk: ContinueSDK) -> ValidatorObservation:
+ raise NotImplementedError
+
+
+class Context:
+ key_value: Dict[str, Any] = {}
+
+ def set(self, key: str, value: Any):
+ self.key_value[key] = value
+
+ def get(self, key: str) -> Any:
+ return self.key_value.get(key, None)
+
+
+class ContinueCustomException(Exception):
+ title: str
+ message: str
+ with_step: Union[Step, None]
+
+ def __init__(
+ self,
+ message: str,
+ title: str = "Error while running step:",
+ with_step: Union[Step, None] = None,
+ ):
+ self.message = message
+ self.title = title
+ self.with_step = with_step
+
+
+HistoryNode.update_forward_refs()