diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-04 11:58:02 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-04 11:58:02 -0700 |
commit | 142e2043b82fbf46af3841f9b07613ab8484dd65 (patch) | |
tree | 418ba409f34a01a7374e62287d9fe2c34886fd93 /continuedev/src | |
parent | 2192f332a7f40ad07dd039d512d68eb89ed3fb38 (diff) | |
download | sncontinue-142e2043b82fbf46af3841f9b07613ab8484dd65.tar.gz sncontinue-142e2043b82fbf46af3841f9b07613ab8484dd65.tar.bz2 sncontinue-142e2043b82fbf46af3841f9b07613ab8484dd65.zip |
3.8 compatibility and deleting context all at once
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hugging_face.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 6 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/search_directory.py | 4 |
8 files changed, 40 insertions, 21 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 5193a02b..05e48f40 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -173,8 +173,12 @@ class Autopilot(ContinueBaseModel): self.history.timeline[index].deleted = True await self.update_subscribers() - async def delete_context_item_at_index(self, index: int): - self._highlighted_ranges.pop(index) + async def delete_context_at_indices(self, indices: List[int]): + kept_ranges = [] + for i, rif in enumerate(self._highlighted_ranges): + if i not in indices: + kept_ranges.append(rif) + self._highlighted_ranges = kept_ranges await self.update_subscribers() async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: diff --git a/continuedev/src/continuedev/libs/llm/hugging_face.py b/continuedev/src/continuedev/libs/llm/hugging_face.py index 868cb560..b0db585b 100644 --- a/continuedev/src/continuedev/libs/llm/hugging_face.py +++ b/continuedev/src/continuedev/libs/llm/hugging_face.py @@ -1,14 +1,17 @@ from .llm import LLM from transformers import AutoTokenizer, AutoModelForCausalLM + class HuggingFace(LLM): def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"): self.model_path = model_path self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.model = AutoModelForCausalLM.from_pretrained(model_path) - + def complete(self, prompt: str, **kwargs): - args = { "max_tokens": 100 } | kwargs + args = {"max_tokens": 100} + args.update(kwargs) input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids - generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"]) - return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
\ No newline at end of file + generated_ids = self.model.generate( + input_ids, max_length=args["max_tokens"]) + return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a3ca5c80..c4e4139f 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -24,13 +24,14 @@ class OpenAI(LLM): @property def default_args(self): - return DEFAULT_ARGS | {"model": self.default_model} + return {**DEFAULT_ARGS, "model": self.default_model} def count_tokens(self, text: str): return count_tokens(self.default_model, text) async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args | kwargs + args = self.default_args.copy() + args.update(kwargs) args["stream"] = True if args["model"] in CHAT_MODELS: @@ -48,7 +49,8 @@ class OpenAI(LLM): yield chunk.choices[0].text async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args | kwargs + args = self.default_args.copy() + args.update(kwargs) args["stream"] = True args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613" if not args["model"].endswith("0613") and "functions" in args: @@ -62,7 +64,7 @@ class OpenAI(LLM): yield chunk.choices[0].delta async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: - args = self.default_args | kwargs + args = {**self.default_args, **kwargs} if args["model"] in CHAT_MODELS: resp = (await openai.ChatCompletion.acreate( diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 69c96ee8..05ece394 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -28,13 +28,13 @@ class ProxyServer(LLM): @property def default_args(self): - return DEFAULT_ARGS | {"model": self.default_model} + return {**DEFAULT_ARGS, "model": self.default_model} def count_tokens(self, text: str): return count_tokens(self.default_model, text) async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: - args = self.default_args | kwargs + args = {**self.default_args, **kwargs} async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/complete", json={ @@ -48,7 +48,7 @@ class ProxyServer(LLM): raise Exception(await resp.text()) async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: - args = self.default_args | kwargs + args = {**self.default_args, **kwargs} messages = compile_chat_messages( self.default_model, messages, None, functions=args.get("functions", None)) @@ -72,7 +72,7 @@ class ProxyServer(LLM): raise Exception(str(line[0])) async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args | kwargs + args = {**self.default_args, **kwargs} messages = compile_chat_messages( self.default_model, with_history, prompt, functions=args.get("functions", None)) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index b2f23bac..4e960f7c 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,6 +1,6 @@ import json from fastapi import Depends, Header, WebSocket, APIRouter -from typing import Any, Type, TypeVar, Union +from typing import Any, List, Type, TypeVar, Union from pydantic import BaseModel from uvicorn.main import Server @@ -83,8 +83,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_clear_history() elif message_type == "delete_at_index": self.on_delete_at_index(data["index"]) - elif message_type == "delete_context_item_at_index": - self.on_delete_context_item_at_index(data["index"]) + elif message_type == "delete_context_at_indices": + self.on_delete_context_at_indices(data["indices"]) except Exception as e: print(e) @@ -123,9 +123,9 @@ class GUIProtocolServer(AbstractGUIProtocolServer): def on_delete_at_index(self, index: int): asyncio.create_task(self.session.autopilot.delete_at_index(index)) - def on_delete_context_item_at_index(self, index: int): + def on_delete_context_at_indices(self, indices: List[int]): asyncio.create_task( - self.session.autopilot.delete_context_item_at_index(index) + self.session.autopilot.delete_context_at_indices(indices) ) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e2685493..ea355d3c 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -160,6 +160,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): "edit": file_edit.dict() }) + async def showDiff(self, filepath: str, replacement: str): + await self._send_json("showDiff", { + "filepath": filepath, + "replacement": replacement + }) + async def setFileOpen(self, filepath: str, open: bool = True): # Autopilot needs access to this. await self._send_json("setFileOpen", { diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index de2eea27..2e1f78d7 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -95,6 +95,10 @@ class AbstractIdeProtocolServer(ABC): def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): """Called when highlighted code is updated""" + @abstractmethod + async def showDiff(self, filepath: str, replacement: str): + """Show a diff""" + @abstractproperty def workspace_directory(self) -> str: """Get the workspace directory""" diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/steps/search_directory.py index d2966f46..2eecc99c 100644 --- a/continuedev/src/continuedev/steps/search_directory.py +++ b/continuedev/src/continuedev/steps/search_directory.py @@ -1,6 +1,6 @@ import asyncio from textwrap import dedent -from typing import List +from typing import List, Union from ..models.filesystem import RangeInFile from ..models.main import Range @@ -54,7 +54,7 @@ class WriteRegexPatternStep(Step): class EditAllMatchesStep(Step): pattern: str user_request: str - directory: str | None = None + directory: Union[str, None] = None async def run(self, sdk: ContinueSDK): # Search all files for a given string |