Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py              19
-rw-r--r--  continuedev/src/continuedev/core/config.py                 11
-rw-r--r--  continuedev/src/continuedev/core/policy.py                  2
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                    15
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py               86
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py    3
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py             25
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py        6
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py      14
-rw-r--r--  continuedev/src/continuedev/libs/util/strings.py (renamed from continuedev/src/continuedev/libs/util/dedent.py)  24
-rw-r--r--  continuedev/src/continuedev/libs/util/templating.py        39
-rw-r--r--  continuedev/src/continuedev/server/ide.py                   2
-rw-r--r--  continuedev/src/continuedev/steps/chat.py                   8
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py             38
14 files changed, 244 insertions, 48 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 0696c360..4e177ac9 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -36,7 +36,7 @@ def get_error_title(e: Exception) -> str:
elif isinstance(e, openai_errors.APIConnectionError):
return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
elif isinstance(e, openai_errors.InvalidRequestError):
- return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'
+ return 'Invalid request sent to OpenAI. Please try again.'
elif e.__str__().startswith("Cannot connect to host"):
return "The request failed. Please check your internet connection and try again."
return e.__str__() or e.__repr__()
@@ -166,6 +166,22 @@ class Autopilot(ContinueBaseModel):
if not any(map(lambda x: x.editing, self._highlighted_ranges)):
self._highlighted_ranges[0].editing = True
+ def _disambiguate_highlighted_ranges(self):
+ """If any files have the same name, also display their folder name"""
+ name_counts = {}
+ for rif in self._highlighted_ranges:
+ if rif.display_name in name_counts:
+ name_counts[rif.display_name] += 1
+ else:
+ name_counts[rif.display_name] = 1
+
+ for rif in self._highlighted_ranges:
+ if name_counts[rif.display_name] > 1:
+ rif.display_name = os.path.join(
+ os.path.basename(os.path.dirname(rif.range.filepath)), rif.display_name)
+ else:
+ rif.display_name = os.path.basename(rif.range.filepath)
+
async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
# Filter out rifs from ~/.continue/diffs folder
range_in_files = [
@@ -211,6 +227,7 @@ class Autopilot(ContinueBaseModel):
) for rif in range_in_files]
self._make_sure_is_editing_range()
+ self._disambiguate_highlighted_ranges()
await self.update_subscribers()
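
A standalone sketch of the disambiguation logic added above (the dataclasses are stand-ins for RangeInFileWithContents and the harness is assumed; the path handling mirrors the diff):

```python
import os
from dataclasses import dataclass

@dataclass
class FileRange:
    filepath: str

@dataclass
class RIF:
    range: FileRange
    display_name: str = ""

def disambiguate(rifs):
    # Count collisions by basename, then prefix the parent folder only where needed
    counts = {}
    for rif in rifs:
        name = os.path.basename(rif.range.filepath)
        counts[name] = counts.get(name, 0) + 1
    for rif in rifs:
        name = os.path.basename(rif.range.filepath)
        if counts[name] > 1:
            parent = os.path.basename(os.path.dirname(rif.range.filepath))
            rif.display_name = os.path.join(parent, name)
        else:
            rif.display_name = name

rifs = [RIF(FileRange("/repo/core/config.py")),
        RIF(FileRange("/repo/server/config.py")),
        RIF(FileRange("/repo/core/sdk.py"))]
disambiguate(rifs)
print([r.display_name for r in rifs])
# ['core/config.py', 'server/config.py', 'sdk.py']
```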
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 05ba48c6..6af0878d 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -67,16 +67,21 @@ DEFAULT_SLASH_COMMANDS = [
]
+class AzureInfo(BaseModel):
+ endpoint: str
+ engine: str
+ api_version: str
+
+
class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
"""
steps_on_startup: Optional[Dict[str, Dict]] = {}
disallowed_steps: Optional[List[str]] = []
- server_url: Optional[str] = None
allow_anonymous_telemetry: Optional[bool] = True
default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "claude-2"] = 'gpt-4'
+ "gpt-4", "claude-2", "ggml"] = 'gpt-4'
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
description="This is an example custom command. Use /config to edit it and create more",
@@ -85,6 +90,8 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS
on_traceback: Optional[List[OnTracebackSteps]] = [
OnTracebackSteps(step_name="DefaultOnTracebackStep")]
+ system_message: Optional[str] = None
+ azure_openai_info: Optional[AzureInfo] = None
# Want to force these to be the slash commands for now
@validator('slash_commands', pre=True)
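
A config exercising the new fields might look like this (a sketch; the endpoint, engine, and api_version values are hypothetical):

```python
from continuedev.src.continuedev.core.config import ContinueConfig, AzureInfo

config = ContinueConfig(
    default_model="gpt-4",            # "ggml" is now also accepted here
    system_message="Always use type hints and keep functions small.",
    azure_openai_info=AzureInfo(      # optional: route OpenAI calls through Azure
        endpoint="https://my-resource.openai.azure.com/",
        engine="my-gpt4-deployment",
        api_version="2023-05-15",
    ),
)
```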
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index bc897357..d007c92b 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -58,7 +58,7 @@ class DemoPolicy(Policy):
if history.get_current() is None:
return (
MessageStep(name="Welcome to Continue", message=dedent("""\
- - Highlight code and ask a question or give instructions
+                - Highlight a code section and ask a question or give instructions
- Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
- Use `/help` to ask questions about how to use Continue""")) >>
WelcomeStep() >>
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 28487600..d3501f08 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -12,6 +12,7 @@ from ..models.filesystem import RangeInFile
from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
from ..libs.llm.openai import OpenAI
from ..libs.llm.anthropic import Anthropic
+from ..libs.llm.ggml import GGML
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import Context, ContinueCustomException, History, Step, ChatMessage
@@ -34,10 +35,12 @@ MODEL_PROVIDER_TO_ENV_VAR = {
class Models:
provider_keys: Dict[ModelProvider, str] = {}
model_providers: List[ModelProvider]
+ system_message: str
def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
self.sdk = sdk
self.model_providers = model_providers
+ self.system_message = sdk.config.system_message
@classmethod
async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
@@ -56,12 +59,12 @@ class Models:
def __load_openai_model(self, model: str) -> OpenAI:
api_key = self.provider_keys["openai"]
if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model)
- return OpenAI(api_key=api_key, default_model=model)
+ return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
+ return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info)
def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
api_key = self.provider_keys["hf_inference_api"]
- return HuggingFaceInferenceAPI(api_key=api_key, model=model)
+ return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
def __load_anthropic_model(self, model: str) -> Anthropic:
api_key = self.provider_keys["anthropic"]
@@ -91,6 +94,10 @@ class Models:
def gpt4(self):
return self.__load_openai_model("gpt-4")
+ @cached_property
+ def ggml(self):
+ return GGML(system_message=self.system_message)
+
def __model_from_name(self, model_name: str):
if model_name == "starcoder":
return self.starcoder
@@ -102,6 +109,8 @@ class Models:
return self.gpt4
elif model_name == "claude-2":
return self.claude2
+ elif model_name == "ggml":
+ return self.ggml
else:
raise Exception(f"Unknown model {model_name}")
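
A minimal sketch of the dispatch pattern used here: each provider is built lazily via cached_property and receives the config's system_message (class and variable names other than GGML are stand-ins):

```python
from functools import cached_property
from continuedev.src.continuedev.libs.llm.ggml import GGML

class ModelsSketch:
    def __init__(self, system_message: str = None):
        self.system_message = system_message

    @cached_property
    def ggml(self):
        # Built once on first access, like the other providers
        return GGML(system_message=self.system_message)

    def model_from_name(self, name: str):
        if name == "ggml":
            return self.ggml
        raise Exception(f"Unknown model {name}")

models = ModelsSketch(system_message="Prefer pure functions.")
assert models.model_from_name("ggml") is models.ggml  # cached: same instance
```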
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
new file mode 100644
index 00000000..6007fdb4
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -0,0 +1,86 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+SERVER_URL = "http://localhost:8000"
+
+
+class GGML(LLM):
+
+ def __init__(self, system_message: str = None):
+ self.system_message = system_message
+
+ @cached_property
+ def name(self):
+ return "ggml"
+
+ @property
+ def default_args(self):
+ return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+ def count_tokens(self, text: str):
+ return count_tokens(self.name, text)
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = {**self.default_args, **kwargs}
+ args["stream"] = True
+ messages = compile_chat_messages(
+ self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except Exception:
+ raise Exception(str(line))
+
+ async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["stream"] = True
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ # This is streaming application/json instead of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except Exception:
+ raise Exception(str(line[0]))
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ args = {**self.default_args, **kwargs}
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ **args
+ }) as resp:
+ return await resp.text()
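
Driving the new client could look like this (a sketch assuming an OpenAI-compatible server such as llama-cpp-python is listening on localhost:8000; the ChatMessage fields follow their use elsewhere in this diff):

```python
import asyncio
from continuedev.src.continuedev.core.main import ChatMessage
from continuedev.src.continuedev.libs.llm.ggml import GGML

async def main():
    llm = GGML(system_message="You are a concise assistant.")
    msgs = [ChatMessage(role="user", content="Hello!", summary="greeting")]
    # stream_chat yields OpenAI-style delta dicts, e.g. {"content": "..."}
    async for delta in llm.stream_chat(msgs):
        if "content" in delta:
            print(delta["content"], end="", flush=True)

asyncio.run(main())
```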
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 803ba122..7e11fbbe 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -11,9 +11,10 @@ class HuggingFaceInferenceAPI(LLM):
api_key: str
model: str
- def __init__(self, api_key: str, model: str):
+ def __init__(self, api_key: str, model: str, system_message: str = None):
self.api_key = api_key
self.model = model
+ self.system_message = system_message # TODO: Nothing being done with this
def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
"""Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index f0877d90..33d10985 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,30 +1,41 @@
from functools import cached_property
-import time
from typing import Any, Coroutine, Dict, Generator, List, Union
+
from ...core.main import ChatMessage
import openai
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ...core.config import AzureInfo
class OpenAI(LLM):
api_key: str
default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
+ def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):
self.api_key = api_key
self.default_model = default_model
self.system_message = system_message
+ self.azure_info = azure_info
openai.api_key = api_key
+ # Using an Azure OpenAI deployment
+ if azure_info is not None:
+ openai.api_type = "azure"
+ openai.api_base = azure_info.endpoint
+ openai.api_version = azure_info.api_version
+
@cached_property
def name(self):
return self.default_model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.default_model}
+ if self.azure_info is not None:
+ args["engine"] = self.azure_info.engine
+ return args
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)
@@ -37,7 +48,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -58,7 +69,7 @@ class OpenAI(LLM):
async for chunk in await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),
+ args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message),
**args,
):
yield chunk.choices[0].delta
@@ -69,7 +80,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
resp = (await openai.ChatCompletion.acreate(
messages=compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
**args,
)).choices[0].message.content
else:
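
The net effect of the Azure branch at request time, spelled out with the pre-1.0 openai SDK this code targets (all values hypothetical):

```python
import openai

openai.api_type = "azure"
openai.api_base = "https://my-resource.openai.azure.com/"  # azure_info.endpoint
openai.api_version = "2023-05-15"                          # azure_info.api_version
openai.api_key = "sk-..."

resp = openai.ChatCompletion.create(
    engine="my-gpt4-deployment",  # azure_info.engine, injected via default_args
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.choices[0].message.content)
```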
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index eab6e441..3ec492f3 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/complete", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
"unique_id": self.unique_id,
**args
}) as resp:
@@ -50,7 +50,7 @@ class ProxyServer(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+ self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -74,7 +74,7 @@ class ProxyServer(LLM):
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+ self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 73be0717..1ca98fe6 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,15 +1,19 @@
import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
+from .templating import render_system_message
import tiktoken
-aliases = {}
+aliases = {
+ "ggml": "gpt-3.5-turbo",
+}
DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192
+ "gpt-4": 8192,
+ "ggml": 2048
}
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
@@ -82,13 +86,15 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
for function in functions:
prompt_tokens += count_tokens(model, json.dumps(function))
+ rendered_system_message = render_system_message(system_message)
+
msgs = prune_chat_history(model,
- msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message))
+ msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message))
history = []
if system_message:
history.append({
"role": "system",
- "content": system_message
+ "content": rendered_system_message
})
history += [msg.to_dict(with_functions=functions is not None)
for msg in msgs]
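
Back-of-the-envelope arithmetic for the pruning budget with the new "ggml" entry, which aliases to the gpt-3.5-turbo tokenizer (a sketch; the message strings are made up):

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
system_message = "Project uses Python 3.10 and strict typing."
prompt = "Refactor this function to be pure."

context_window = 2048     # MAX_TOKENS_FOR_MODEL["ggml"]
completion_budget = 1024  # GGML's default max_tokens
used = len(enc.encode(system_message)) + len(enc.encode(prompt))
# prune_chat_history trims older messages until the rest fits in what's left:
history_budget = context_window - completion_budget - used
print(f"~{history_budget} tokens left for chat history")
```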
diff --git a/continuedev/src/continuedev/libs/util/dedent.py b/continuedev/src/continuedev/libs/util/strings.py
index e59c2e97..f1fb8d0b 100644
--- a/continuedev/src/continuedev/libs/util/dedent.py
+++ b/continuedev/src/continuedev/libs/util/strings.py
@@ -23,3 +23,27 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
break
return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp
+
+
+def remove_quotes_and_escapes(output: str) -> str:
+ """
+ Clean up the output of the completion API, removing unnecessary escapes and quotes
+ """
+ output = output.strip()
+
+ # Replace smart quotes
+ output = output.replace("“", '"')
+ output = output.replace("”", '"')
+ output = output.replace("‘", "'")
+ output = output.replace("’", "'")
+
+ # Remove escapes
+ output = output.replace('\\"', '"')
+ output = output.replace("\\'", "'")
+ output = output.replace("\\n", "\n")
+ output = output.replace("\\t", "\t")
+ output = output.replace("\\\\", "\\")
+ if (output.startswith('"') and output.endswith('"')) or (output.startswith("'") and output.endswith("'")):
+ output = output[1:-1]
+
+ return output
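
A quick check of the cleanup helper on the kinds of strings models tend to return:

```python
from continuedev.src.continuedev.libs.util.strings import remove_quotes_and_escapes

print(remove_quotes_and_escapes('\\"Refactor the parser\\"'))  # Refactor the parser
print(remove_quotes_and_escapes("“Add type hints”"))           # Add type hints
```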
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
new file mode 100644
index 00000000..ebfc2e31
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -0,0 +1,39 @@
+import os
+import chevron
+
+
+def get_vars_in_template(template):
+ """
+ Get the variables in a template
+ """
+ return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable']
+
+
+def escape_var(var: str) -> str:
+ """
+ Escape a variable so it can be used in a template
+ """
+ return var.replace(os.path.sep, '').replace('.', '')
+
+
+def render_system_message(system_message: str) -> str:
+ """
+ Render a system message with mustache syntax.
+ Currently only absolute file paths are supported as variables; each is replaced by that file's contents.
+ """
+ vars = get_vars_in_template(system_message)
+
+ args = {}
+ for var in vars:
+ if var.startswith(os.path.sep):
+ # Escape vars which are filenames, because mustache doesn't allow / in variable names
+ escaped_var = escape_var(var)
+ system_message = system_message.replace(
+ var, escaped_var)
+
+ if os.path.exists(var):
+ with open(var, 'r') as f:
+     args[escaped_var] = f.read()
+ else:
+ args[escaped_var] = ''
+
+ return chevron.render(system_message, args)
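
Hypothetical usage: a mustache variable holding an absolute path is replaced with that file's contents, or an empty string if the path does not exist:

```python
from continuedev.src.continuedev.libs.util.templating import render_system_message

msg = "Follow the conventions described in {{/home/me/project/STYLE.md}}."
print(render_system_message(msg))
```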
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index a91708ec..43538407 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -442,6 +442,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
if session_id is not None:
session_manager.registered_ides[session_id] = ideProtocolServer
other_msgs = await ideProtocolServer.initialize(session_id)
+ capture_event(ideProtocolServer.unique_id, "session_started", { "session_id": ideProtocolServer.session_id })
for other_msg in other_msgs:
handle_msg(other_msg)
@@ -462,4 +463,5 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
+ capture_event(ideProtocolServer.unique_id, "session_ended", { "session_id": ideProtocolServer.session_id })
session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 3751dec2..7c6b42db 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -3,6 +3,7 @@ from typing import Any, Coroutine, List
from pydantic import Field
+from ..libs.util.strings import remove_quotes_and_escapes
from .main import EditHighlightedCodeStep
from .core.core import MessageStep
from ..core.main import FunctionCall, Models
@@ -43,11 +44,8 @@ class SimpleChatStep(Step):
finally:
await generator.aclose()
- self.name = (await sdk.models.gpt35.complete(
- f"Write a short title for the following chat message: {self.description}")).strip()
-
- if self.name.startswith('"') and self.name.endswith('"'):
- self.name = self.name[1:-1]
+ self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+ f"Write a short title for the following chat message: {self.description}"))
self.chat_context.append(ChatMessage(
role="assistant",
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 90d64287..2b049ecc 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -1,9 +1,11 @@
# These steps are depended upon by ContinueSDK
import os
import subprocess
+import difflib
from textwrap import dedent
from typing import Coroutine, List, Literal, Union
+from ...libs.llm.ggml import GGML
from ...models.main import Range
from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
@@ -11,7 +13,7 @@ from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContent
from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
from ...core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
-from ...libs.util.dedent import dedent_and_get_common_whitespace
+from ...libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
-import difflib
@@ -156,42 +158,32 @@ class DefaultModelEditCodeStep(Step):
_new_contents: str = ""
_prompt_and_completion: str = ""
- def _cleanup_output(self, output: str) -> str:
- output = output.replace('\\"', '"')
- output = output.replace("\\'", "'")
- output = output.replace("\\n", "\n")
- output = output.replace("\\t", "\t")
- output = output.replace("\\\\", "\\")
- if output.startswith('"') and output.endswith('"'):
- output = output[1:-1]
-
- return output
-
async def describe(self, models: Models) -> Coroutine[str, None, None]:
if self._previous_contents.strip() == self._new_contents.strip():
description = "No edits were made"
else:
+ changes = '\n'.join(difflib.ndiff(
+ self._previous_contents.splitlines(), self._new_contents.splitlines()))
description = await models.gpt3516k.complete(dedent(f"""\
- ```original
- {self._previous_contents}
- ```
+ Diff summary: "{self.user_input}"
- ```new
- {self._new_contents}
+ ```diff
+ {changes}
```
                Please give a brief description of the changes made above using markdown bullet points. Be concise:"""))
name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
- self.name = self._cleanup_output(name)
+ self.name = remove_quotes_and_escapes(name)
- return f"{self._cleanup_output(description)}"
+ return f"{remove_quotes_and_escapes(description)}"
async def get_prompt_parts(self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str):
        # We don't know all of the functions being passed in here.
        # This matters because if the prompt itself goes over the limit, the entire message must be cut from the completion:
        # overflow itself won't happen, but prune_chat_messages in count_tokens.py would cut out this whole message instead of only as many lines as we need.
- model_to_use = sdk.models.gpt4
- max_tokens = DEFAULT_MAX_TOKENS
+ model_to_use = sdk.models.default
+ max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
+ model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -495,6 +487,10 @@ Please output the code to be inserted at the cursor in order to fulfill the user
repeating_file_suffix = False
line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]
+ if isinstance(model_to_use, GGML):
+ messages = [ChatMessage(
+ role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)]
+
generator = model_to_use.stream_chat(
messages, temperature=0, max_tokens=max_tokens)
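
For reference, the ndiff-based summary that describe() now feeds to the model (instead of two full file bodies) looks like this:

```python
import difflib

previous = "def add(a, b):\n    return a + b\n"
new = "def add(a: int, b: int) -> int:\n    return a + b\n"
changes = "\n".join(difflib.ndiff(previous.splitlines(), new.splitlines()))
print(changes)
# - def add(a, b):
# + def add(a: int, b: int) -> int:
#       return a + b
# (ndiff may also emit "? " hint lines marking intraline changes)
```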