summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-04 11:58:02 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-04 11:58:02 -0700
commit142e2043b82fbf46af3841f9b07613ab8484dd65 (patch)
tree418ba409f34a01a7374e62287d9fe2c34886fd93 /continuedev/src
parent2192f332a7f40ad07dd039d512d68eb89ed3fb38 (diff)
downloadsncontinue-142e2043b82fbf46af3841f9b07613ab8484dd65.tar.gz
sncontinue-142e2043b82fbf46af3841f9b07613ab8484dd65.tar.bz2
sncontinue-142e2043b82fbf46af3841f9b07613ab8484dd65.zip
3.8 compatibility and deleting context all at once
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py8
-rw-r--r--continuedev/src/continuedev/libs/llm/hugging_face.py11
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py10
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py8
-rw-r--r--continuedev/src/continuedev/server/gui.py10
-rw-r--r--continuedev/src/continuedev/server/ide.py6
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py4
-rw-r--r--continuedev/src/continuedev/steps/search_directory.py4
8 files changed, 40 insertions, 21 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 5193a02b..05e48f40 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -173,8 +173,12 @@ class Autopilot(ContinueBaseModel):
self.history.timeline[index].deleted = True
await self.update_subscribers()
- async def delete_context_item_at_index(self, index: int):
- self._highlighted_ranges.pop(index)
+ async def delete_context_at_indices(self, indices: List[int]):
+ kept_ranges = []
+ for i, rif in enumerate(self._highlighted_ranges):
+ if i not in indices:
+ kept_ranges.append(rif)
+ self._highlighted_ranges = kept_ranges
await self.update_subscribers()
async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:
diff --git a/continuedev/src/continuedev/libs/llm/hugging_face.py b/continuedev/src/continuedev/libs/llm/hugging_face.py
index 868cb560..b0db585b 100644
--- a/continuedev/src/continuedev/libs/llm/hugging_face.py
+++ b/continuedev/src/continuedev/libs/llm/hugging_face.py
@@ -1,14 +1,17 @@
from .llm import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM
+
class HuggingFace(LLM):
def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"):
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(model_path)
-
+
def complete(self, prompt: str, **kwargs):
- args = { "max_tokens": 100 } | kwargs
+ args = {"max_tokens": 100}
+ args.update(kwargs)
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
- generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"])
- return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) \ No newline at end of file
+ generated_ids = self.model.generate(
+ input_ids, max_length=args["max_tokens"])
+ return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a3ca5c80..c4e4139f 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -24,13 +24,14 @@ class OpenAI(LLM):
@property
def default_args(self):
- return DEFAULT_ARGS | {"model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.default_model}
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args | kwargs
+ args = self.default_args.copy()
+ args.update(kwargs)
args["stream"] = True
if args["model"] in CHAT_MODELS:
@@ -48,7 +49,8 @@ class OpenAI(LLM):
yield chunk.choices[0].text
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args | kwargs
+ args = self.default_args.copy()
+ args.update(kwargs)
args["stream"] = True
args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"
if not args["model"].endswith("0613") and "functions" in args:
@@ -62,7 +64,7 @@ class OpenAI(LLM):
yield chunk.choices[0].delta
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
- args = self.default_args | kwargs
+ args = {**self.default_args, **kwargs}
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 69c96ee8..05ece394 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -28,13 +28,13 @@ class ProxyServer(LLM):
@property
def default_args(self):
- return DEFAULT_ARGS | {"model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.default_model}
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
- args = self.default_args | kwargs
+ args = {**self.default_args, **kwargs}
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/complete", json={
@@ -48,7 +48,7 @@ class ProxyServer(LLM):
raise Exception(await resp.text())
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
- args = self.default_args | kwargs
+ args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.default_model, messages, None, functions=args.get("functions", None))
@@ -72,7 +72,7 @@ class ProxyServer(LLM):
raise Exception(str(line[0]))
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args | kwargs
+ args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.default_model, with_history, prompt, functions=args.get("functions", None))
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index b2f23bac..4e960f7c 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,6 +1,6 @@
import json
from fastapi import Depends, Header, WebSocket, APIRouter
-from typing import Any, Type, TypeVar, Union
+from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel
from uvicorn.main import Server
@@ -83,8 +83,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_clear_history()
elif message_type == "delete_at_index":
self.on_delete_at_index(data["index"])
- elif message_type == "delete_context_item_at_index":
- self.on_delete_context_item_at_index(data["index"])
+ elif message_type == "delete_context_at_indices":
+ self.on_delete_context_at_indices(data["indices"])
except Exception as e:
print(e)
@@ -123,9 +123,9 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
def on_delete_at_index(self, index: int):
asyncio.create_task(self.session.autopilot.delete_at_index(index))
- def on_delete_context_item_at_index(self, index: int):
+ def on_delete_context_at_indices(self, indices: List[int]):
asyncio.create_task(
- self.session.autopilot.delete_context_item_at_index(index)
+ self.session.autopilot.delete_context_at_indices(indices)
)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index e2685493..ea355d3c 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -160,6 +160,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"edit": file_edit.dict()
})
+ async def showDiff(self, filepath: str, replacement: str):
+ await self._send_json("showDiff", {
+ "filepath": filepath,
+ "replacement": replacement
+ })
+
async def setFileOpen(self, filepath: str, open: bool = True):
# Autopilot needs access to this.
await self._send_json("setFileOpen", {
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index de2eea27..2e1f78d7 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -95,6 +95,10 @@ class AbstractIdeProtocolServer(ABC):
def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
"""Called when highlighted code is updated"""
+ @abstractmethod
+ async def showDiff(self, filepath: str, replacement: str):
+ """Show a diff"""
+
@abstractproperty
def workspace_directory(self) -> str:
"""Get the workspace directory"""
diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/steps/search_directory.py
index d2966f46..2eecc99c 100644
--- a/continuedev/src/continuedev/steps/search_directory.py
+++ b/continuedev/src/continuedev/steps/search_directory.py
@@ -1,6 +1,6 @@
import asyncio
from textwrap import dedent
-from typing import List
+from typing import List, Union
from ..models.filesystem import RangeInFile
from ..models.main import Range
@@ -54,7 +54,7 @@ class WriteRegexPatternStep(Step):
class EditAllMatchesStep(Step):
pattern: str
user_request: str
- directory: str | None = None
+ directory: Union[str, None] = None
async def run(self, sdk: ContinueSDK):
# Search all files for a given string