author     Nate Sesti <sestinj@gmail.com>   2023-07-16 16:55:24 -0700
committer  Nate Sesti <sestinj@gmail.com>   2023-07-16 16:55:24 -0700
commit     a3f4a2a59d6785499f3ce0c4af80b57b02de1b1f (patch)
tree       bcf02a1382641ed6cbea226d7b1cbdeeeadb5bf9 /continuedev
parent     4c3a25a1c8938f8132233e021c74d98eb19d7ddd (diff)
better prompt for editing
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/config.py      |  2
-rw-r--r--  continuedev/src/continuedev/core/sdk.py         |  5
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py    | 33
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py  |  5
4 files changed, 17 insertions, 28 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 6e430c04..957609c5 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -76,7 +76,7 @@ class ContinueConfig(BaseModel):
server_url: Optional[str] = None
allow_anonymous_telemetry: Optional[bool] = True
default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4"] = 'gpt-4'
+ "gpt-4", "ggml"] = 'gpt-4'
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
description="This is an example custom command. Use /config to edit it and create more",
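
The config.py hunk widens the default_model literal so "ggml" is an accepted value. A minimal sketch of a user config selecting the local model under that assumption (only the default_model literal appears in this hunk; the import path and the other field shown are assumed or taken from the surrounding context, not prescribed by this diff):

    # Hypothetical user config sketch: "ggml" is now a valid default_model
    # literal. The import path is assumed for illustration.
    from continuedev.src.continuedev.core.config import ContinueConfig

    config = ContinueConfig(
        allow_anonymous_telemetry=True,
        default_model="ggml",  # previously limited to gpt-3.5-turbo / gpt-3.5-turbo-16k / gpt-4
    )
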
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 9389e1e9..eb60109c 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -82,7 +82,7 @@ class Models:
@cached_property
def ggml(self):
- return GGML("", "ggml")
+ return GGML()
def __model_from_name(self, model_name: str):
if model_name == "starcoder":
@@ -93,12 +93,13 @@ class Models:
return self.gpt3516k
elif model_name == "gpt-4":
return self.gpt4
+ elif model_name == "ggml":
+ return self.ggml
else:
raise Exception(f"Unknown model {model_name}")
@property
def default(self):
- return self.ggml
default_model = self.sdk.config.default_model
return self.__model_from_name(default_model) if default_model is not None else self.gpt4
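
The sdk.py hunk does two things: __model_from_name learns the "ggml" name, and the stray `return self.ggml` that short-circuited `default` is removed, so the configured default_model is honored again. A minimal usage sketch under those assumptions (the `sdk` object stands for the ContinueSDK instance a Step receives; that plumbing is not part of this diff):

    # Sketch: resolving the configured model after this patch. With
    # config.default_model == "ggml", sdk.models.default returns the cached
    # GGML() instance instead of unconditionally short-circuiting to it.
    model = sdk.models.default
    completion = await model.complete("Summarize the selected function.")
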
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index bef0d993..d3589b70 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -4,39 +4,27 @@ from typing import Any, Coroutine, Dict, Generator, List, Union
import aiohttp
from ...core.main import ChatMessage
-import openai
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
-import certifi
-import ssl
-
-ca_bundle_path = certifi.where()
-ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
SERVER_URL = "http://localhost:8000"
class GGML(LLM):
- api_key: str
- default_model: str
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
- self.api_key = api_key
- self.default_model = default_model
+ def __init__(self, system_message: str = None):
self.system_message = system_message
- openai.api_key = api_key
-
@cached_property
def name(self):
- return self.default_model
+ return "ggml"
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.name, text)
async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -45,9 +33,9 @@ class GGML(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+ self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/v1/completions", json={
"messages": messages,
**args
@@ -62,10 +50,10 @@ class GGML(LLM):
async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+ self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
args["stream"] = True
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
"messages": messages,
**args
@@ -77,7 +65,6 @@ class GGML(LLM):
json_chunk = line[0].decode("utf-8")
if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
continue
- json_chunk = "{}" if json_chunk == "" else json_chunk
chunks = json_chunk.split("\n")
for chunk in chunks:
if chunk.strip() != "":
@@ -88,7 +75,7 @@ class GGML(LLM):
async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with aiohttp.ClientSession() as session:
async with session.post(f"{SERVER_URL}/v1/completions", json={
"messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
**args
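
The ggml.py rewrite drops the OpenAI API key, the per-model constructor arguments, and the certifi/ssl plumbing; requests now go over plain HTTP to an OpenAI-compatible server on localhost:8000. A minimal sketch of the request shape the class sends after this patch (the payload mirrors default_args plus the compiled messages; the local server itself is assumed and not part of this diff):

    # Sketch of the POST the patched GGML class issues. Assumes an
    # OpenAI-compatible server (e.g. llama-cpp-python's server) listening
    # on localhost:8000; payload keys mirror default_args above.
    import asyncio
    import aiohttp

    SERVER_URL = "http://localhost:8000"

    async def demo():
        payload = {
            "model": "ggml",
            "max_tokens": 1024,
            "messages": [{"role": "user", "content": "Say hello"}],
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{SERVER_URL}/v1/chat/completions", json=payload) as resp:
                print(await resp.json())

    asyncio.run(demo())
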
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 2c9d8c01..d5a7cd9a 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -192,7 +192,8 @@ class DefaultModelEditCodeStep(Step):
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
model_to_use = sdk.models.default
- max_tokens = DEFAULT_MAX_TOKENS
+ max_tokens = MAX_TOKENS_FOR_MODEL.get(
+ model_to_use.name, DEFAULT_MAX_TOKENS) / 2
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
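
This first core.py hunk replaces the flat DEFAULT_MAX_TOKENS budget with a per-model limit, halved so the prompt and the generated edit can share the context window. A rough sketch of the idea (the contents of MAX_TOKENS_FOR_MODEL are not shown in this diff, so the values below are illustrative placeholders; integer division is used here for clarity, while the hunk itself uses /):

    # Sketch of the per-model token budgeting introduced above. The mapping
    # values are illustrative placeholders, not taken from this diff.
    DEFAULT_MAX_TOKENS = 2048
    MAX_TOKENS_FOR_MODEL = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "ggml": 2048}

    def edit_token_budget(model_name: str) -> int:
        # Reserve half of the model's window for the generated edit; the rest
        # stays available for the surrounding code and instructions.
        return MAX_TOKENS_FOR_MODEL.get(model_name, DEFAULT_MAX_TOKENS) // 2

    print(edit_token_budget("ggml"))   # 1024
    print(edit_token_budget("gpt-4"))  # 4096
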
@@ -498,7 +499,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user
if isinstance(model_to_use, GGML):
messages = [ChatMessage(
- role="user", content=f"```\n{rif.contents}\n```\n{self.user_input}\n```\n", summary=self.user_input)]
+ role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)]
generator = model_to_use.stream_chat(
messages, temperature=0, max_tokens=max_tokens)
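
The last hunk is the "better prompt for editing" from the commit message: when the model is GGML, the single user message now embeds the selected code, restates the request verbatim, and ends with an open ``` fence so the model continues directly with the rewritten code rather than conversational filler. A sketch of how that message body is assembled (rif_contents and user_input below are placeholders for the step's real rif.contents and self.user_input):

    # Sketch of the GGML edit prompt built above; rif_contents and user_input
    # stand in for rif.contents and self.user_input from the step.
    rif_contents = "def add(a, b):\n    return a - b"
    user_input = "fix the subtraction bug"

    content = (
        f"```\n{rif_contents}\n```\n\n"
        f"User request: \"{user_input}\"\n\n"
        "This is the code after changing to perfectly comply with the user request. "
        "It does not include any placeholder code, only real implementations:\n\n```\n"
    )
    # Sent as a single ChatMessage(role="user", ...) and streamed with
    # temperature=0 and the max_tokens budget computed earlier.
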