author     Nate Sesti <33237525+sestinj@users.noreply.github.com>   2023-07-16 21:09:48 -0700
committer  GitHub <noreply@github.com>                              2023-07-16 21:09:48 -0700
commit     d4319f09a3b8c1b0d9d1a7178910f09eac01fce9 (patch)
tree       be5ef32725ab4ad0cba85bcc5415a79388fb6da0
parent     a4a815628f702af806603015ec6805edd151328b (diff)
parent     9687c05a5c8d6aeb15e7386129cdb16c0255b56e (diff)
Merge pull request #278 from continuedev/ggml-server
ggml server
-rw-r--r--  continuedev/src/continuedev/core/config.py             |  2
-rw-r--r--  continuedev/src/continuedev/core/sdk.py                |  7
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py           | 86
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py  |  7
-rw-r--r--  continuedev/src/continuedev/steps/chat.py              |  2
-rw-r--r--  continuedev/src/continuedev/steps/core/core.py         | 10
-rw-r--r--  extension/src/bridge.ts                                |  6
-rw-r--r--  extension/src/lang-server/codeActions.ts               |  6
8 files changed, 113 insertions(+), 13 deletions(-)
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 6e430c04..957609c5 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -76,7 +76,7 @@ class ContinueConfig(BaseModel):
server_url: Optional[str] = None
allow_anonymous_telemetry: Optional[bool] = True
default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4"] = 'gpt-4'
+ "gpt-4", "ggml"] = 'gpt-4'
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
description="This is an example custom command. Use /config to edit it and create more",
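With "ggml" added to the default_model Literal, a user config can select the local ggml backend instead of an OpenAI model. A minimal sketch, assuming a user-level config module that instantiates ContinueConfig and leaves every other field at its default:

    # Hypothetical user config; only the default_model value comes from this diff.
    from continuedev.core.config import ContinueConfig

    config = ContinueConfig(default_model="ggml")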
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d73561d2..eb60109c 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
from ..models.filesystem import RangeInFile
from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
from ..libs.llm.openai import OpenAI
+from ..libs.llm.ggml import GGML
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import Context, ContinueCustomException, History, Step, ChatMessage
@@ -79,6 +80,10 @@ class Models:
def gpt4(self):
return self.__load_openai_model("gpt-4")
+ @cached_property
+ def ggml(self):
+ return GGML()
+
def __model_from_name(self, model_name: str):
if model_name == "starcoder":
return self.starcoder
@@ -88,6 +93,8 @@ class Models:
return self.gpt3516k
elif model_name == "gpt-4":
return self.gpt4
+ elif model_name == "ggml":
+ return self.ggml
else:
raise Exception(f"Unknown model {model_name}")
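The sdk change registers GGML alongside the OpenAI models: a cached_property constructs the client once, and __model_from_name maps the config string onto it. A standalone sketch of just that dispatch path, with the OpenAI and HuggingFace branches omitted:

    # Reduced sketch of the Models dispatch added above; the real class also
    # handles starcoder, gpt-3.5-turbo, gpt-3.5-turbo-16k and gpt-4.
    from functools import cached_property
    from continuedev.libs.llm.ggml import GGML

    class ModelsSketch:
        @cached_property
        def ggml(self):
            return GGML()  # no API key needed; it talks to a local server

        def model_from_name(self, model_name: str):
            if model_name == "ggml":
                return self.ggml
            raise Exception(f"Unknown model {model_name}")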
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
new file mode 100644
index 00000000..d3589b70
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -0,0 +1,86 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+SERVER_URL = "http://localhost:8000"
+
+
+class GGML(LLM):
+
+ def __init__(self, system_message: str = None):
+ self.system_message = system_message
+
+ @cached_property
+ def name(self):
+ return "ggml"
+
+ @property
+ def default_args(self):
+ return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+ def count_tokens(self, text: str):
+ return count_tokens(self.name, text)
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = self.default_args.copy()
+ args.update(kwargs)
+ args["stream"] = True
+
+ args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
+
+ async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
+ args["stream"] = True
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+                # This is streaming application/json instead of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except:
+ raise Exception(str(line[0]))
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ args = {**self.default_args, **kwargs}
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ **args
+ }) as resp:
+ try:
+ return await resp.text()
+ except:
+ raise Exception(await resp.text())
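GGML does not load a model itself; it forwards every request to an OpenAI-compatible HTTP server hard-coded at http://localhost:8000. A usage sketch, assuming such a server (for example llama-cpp-python's "python -m llama_cpp.server") is already listening on that port:

    # Usage sketch: streams a chat completion through the local ggml server.
    # Assumes an OpenAI-compatible server is already running on localhost:8000.
    import asyncio
    from continuedev.core.main import ChatMessage
    from continuedev.libs.llm.ggml import GGML

    async def main():
        llm = GGML()
        messages = [ChatMessage(role="user",
                                content="Write a haiku about compilers",
                                summary="haiku request")]
        async for delta in llm.stream_chat(messages, temperature=0.5):
            # stream_chat yields the parsed "delta" dicts from each streamed chunk
            print(delta.get("content", ""), end="", flush=True)

    asyncio.run(main())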
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 73be0717..e1baeca1 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -3,13 +3,16 @@ from typing import Dict, List, Union
from ...core.main import ChatMessage
import tiktoken
-aliases = {}
+aliases = {
+ "ggml": "gpt-3.5-turbo",
+}
DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192
+ "gpt-4": 8192,
+ "ggml": 2048
}
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 14a1cd41..3751dec2 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -28,7 +28,7 @@ class SimpleChatStep(Step):
completion = ""
messages = self.messages or await sdk.get_chat_context()
- generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5)
+ generator = sdk.models.default.stream_chat(messages, temperature=0.5)
try:
async for chunk in generator:
if sdk.current_step_was_deleted():
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 90d64287..d5a7cd9a 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -4,6 +4,7 @@ import subprocess
from textwrap import dedent
from typing import Coroutine, List, Literal, Union
+from ...libs.llm.ggml import GGML
from ...models.main import Range
from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
@@ -190,8 +191,9 @@ class DefaultModelEditCodeStep(Step):
# We don't know here all of the functions being passed in.
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
- model_to_use = sdk.models.gpt4
- max_tokens = DEFAULT_MAX_TOKENS
+ model_to_use = sdk.models.default
+ max_tokens = MAX_TOKENS_FOR_MODEL.get(
+ model_to_use.name, DEFAULT_MAX_TOKENS) / 2
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -495,6 +497,10 @@ Please output the code to be inserted at the cursor in order to fulfill the user
repeating_file_suffix = False
line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]
+ if isinstance(model_to_use, GGML):
+ messages = [ChatMessage(
+ role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)]
+
generator = model_to_use.stream_chat(
messages, temperature=0, max_tokens=max_tokens)
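DefaultModelEditCodeStep now sizes its completion budget from the active model's context window instead of the fixed DEFAULT_MAX_TOKENS, reserving half the window for the prompt, and hands GGML a single plain-text prompt rather than the usual chat scaffolding. A worked example of the new budget, using the table from count_tokens.py (plain division, so the result is a float):

    # Values copied from count_tokens.py above; illustration only.
    MAX_TOKENS_FOR_MODEL = {"gpt-4": 8192, "ggml": 2048}
    DEFAULT_MAX_TOKENS = 2048

    print(MAX_TOKENS_FOR_MODEL.get("ggml", DEFAULT_MAX_TOKENS) / 2)   # 1024.0
    print(MAX_TOKENS_FOR_MODEL.get("gpt-4", DEFAULT_MAX_TOKENS) / 2)  # 4096.0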
diff --git a/extension/src/bridge.ts b/extension/src/bridge.ts
index 7e6398be..d614ace4 100644
--- a/extension/src/bridge.ts
+++ b/extension/src/bridge.ts
@@ -1,11 +1,7 @@
import fetch from "node-fetch";
import * as path from "path";
import * as vscode from "vscode";
-import {
- Configuration,
- DebugApi,
- UnittestApi,
-} from "./client";
+import { Configuration, DebugApi, UnittestApi } from "./client";
import { convertSingleToDoubleQuoteJSON } from "./util/util";
import { getExtensionUri } from "./util/vscode";
import { extensionContext } from "./activation/activate";
diff --git a/extension/src/lang-server/codeActions.ts b/extension/src/lang-server/codeActions.ts
index 07cf5f4e..f0d61ace 100644
--- a/extension/src/lang-server/codeActions.ts
+++ b/extension/src/lang-server/codeActions.ts
@@ -23,8 +23,10 @@ class ContinueQuickFixProvider implements vscode.CodeActionProvider {
);
quickFix.isPreferred = false;
const surroundingRange = new vscode.Range(
- range.start.translate(-3, 0),
- range.end.translate(3, 0)
+ Math.max(0, range.start.line - 3),
+ 0,
+ Math.min(document.lineCount, range.end.line + 3),
+ 0
);
quickFix.command = {
command: "continue.quickFix",