From 3ded151331933c9a1352cc46c3cc67c5733d1c86 Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Sun, 16 Jul 2023 16:16:41 -0700
Subject: ggml

---
 continuedev/src/continuedev/core/sdk.py        |  6 ++
 continuedev/src/continuedev/libs/llm/ggml.py   | 99 ++++++++++++++++++++++
 .../src/continuedev/libs/util/count_tokens.py  |  7 +-
 continuedev/src/continuedev/steps/chat.py      |  2 +-
 continuedev/src/continuedev/steps/core/core.py |  8 +-
 5 files changed, 118 insertions(+), 4 deletions(-)
 create mode 100644 continuedev/src/continuedev/libs/llm/ggml.py

diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 8649cd58..22393746 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
 from ..models.filesystem import RangeInFile
 from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
+from ..libs.llm.ggml import GGML
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
 from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole
@@ -59,6 +60,10 @@ class Models:
     def gpt4(self):
         return self.__load_openai_model("gpt-4")
 
+    @cached_property
+    def ggml(self):
+        return GGML("", "ggml")
+
     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
             return self.starcoder
@@ -73,6 +78,7 @@ class Models:
 
     @property
     def default(self):
+        return self.ggml
         default_model = self.sdk.config.default_model
         return self.__model_from_name(default_model) if default_model is not None else self.gpt35
 
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
new file mode 100644
index 00000000..bef0d993
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -0,0 +1,99 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+import openai
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+import certifi
+import ssl
+
+ca_bundle_path = certifi.where()
+ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
+
+SERVER_URL = "http://localhost:8000"
+
+
+class GGML(LLM):
+    api_key: str
+    default_model: str
+
+    def __init__(self, api_key: str, default_model: str, system_message: str = None):
+        self.api_key = api_key
+        self.default_model = default_model
+        self.system_message = system_message
+
+        openai.api_key = api_key
+
+    @cached_property
+    def name(self):
+        return self.default_model
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.default_model}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.default_model, text)
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+
+        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+            async with session.post(f"{SERVER_URL}/v1/completions", json={
+                "messages": messages,
+                **args
+            }) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        try:
+                            yield line.decode("utf-8")
+                        except:
+                            raise Exception(str(line))
+
+    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+        args["stream"] = True
+
+        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+            async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
+                "messages": messages,
+                **args
+            }) as resp:
+                # This is streaming application/json instaed of text/event-stream
+                async for line in resp.content.iter_chunks():
+                    if line[1]:
+                        try:
+                            json_chunk = line[0].decode("utf-8")
+                            if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+                                continue
+                            json_chunk = "{}" if json_chunk == "" else json_chunk
+                            chunks = json_chunk.split("\n")
+                            for chunk in chunks:
+                                if chunk.strip() != "":
+                                    yield json.loads(chunk[6:])["choices"][0]["delta"]
+                        except:
+                            raise Exception(str(line[0]))
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+            async with session.post(f"{SERVER_URL}/v1/completions", json={
+                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                **args
+            }) as resp:
+                try:
+                    return await resp.text()
+                except:
+                    raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 73be0717..e1baeca1 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -3,13 +3,16 @@ from typing import Dict, List, Union
 from ...core.main import ChatMessage
 import tiktoken
 
-aliases = {}
+aliases = {
+    "ggml": "gpt-3.5-turbo",
+}
 DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-0613": 4096,
     "gpt-3.5-turbo-16k": 16384,
-    "gpt-4": 8192
+    "gpt-4": 8192,
+    "ggml": 2048
 }
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index a10319d8..1df1e0bf 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -27,7 +27,7 @@ class SimpleChatStep(Step):
     async def run(self, sdk: ContinueSDK):
         completion = ""
         messages = self.messages or await sdk.get_chat_context()
-        async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5):
+        async for chunk in sdk.models.default.stream_chat(messages, temperature=0.5):
             if sdk.current_step_was_deleted():
                 return
 
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 4b35a758..0b067d7d 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -4,6 +4,7 @@ import subprocess
 from textwrap import dedent
 from typing import Coroutine, List, Literal, Union
 
+from ...libs.llm.ggml import GGML
 from ...models.main import Range
 from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
 from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
@@ -180,7 +181,7 @@ class DefaultModelEditCodeStep(Step):
         # We don't know here all of the functions being passed in.
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
-        model_to_use = sdk.models.gpt4
+        model_to_use = sdk.models.default
         max_tokens = DEFAULT_MAX_TOKENS
 
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
@@ -442,6 +443,11 @@ class DefaultModelEditCodeStep(Step):
         completion_lines_covered = 0
         repeating_file_suffix = False
         line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]
+
+        if isinstance(model_to_use, GGML):
+            messages = [ChatMessage(
+                role="user", content=f"```\n{rif.contents}\n```\n{self.user_input}\n```\n", summary=self.user_input)]
+
         async for chunk in model_to_use.stream_chat(messages, temperature=0, max_tokens=max_tokens):
             # Stop early if it is repeating the file_suffix or the step was deleted
             if repeating_file_suffix:
-- 
cgit v1.2.3-70-g09d2


From a3f4a2a59d6785499f3ce0c4af80b57b02de1b1f Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Sun, 16 Jul 2023 16:55:24 -0700
Subject: better prompt for editing

---
 continuedev/src/continuedev/core/config.py     |  2 +-
 continuedev/src/continuedev/core/sdk.py        |  5 ++--
 continuedev/src/continuedev/libs/llm/ggml.py   | 33 ++++++++------------------
 continuedev/src/continuedev/steps/core/core.py |  5 ++--
 4 files changed, 17 insertions(+), 28 deletions(-)

diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 6e430c04..957609c5 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -76,7 +76,7 @@ class ContinueConfig(BaseModel):
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
     default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4"] = 'gpt-4'
+                           "gpt-4", "ggml"] = 'gpt-4'
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
         description="This is an example custom command. Use /config to edit it and create more",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 9389e1e9..eb60109c 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -82,7 +82,7 @@ class Models:
 
     @cached_property
     def ggml(self):
-        return GGML("", "ggml")
+        return GGML()
 
     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
@@ -93,12 +93,13 @@ class Models:
             return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
+        elif model_name == "ggml":
+            return self.ggml
         else:
             raise Exception(f"Unknown model {model_name}")
 
     @property
     def default(self):
-        return self.ggml
         default_model = self.sdk.config.default_model
         return self.__model_from_name(default_model) if default_model is not None else self.gpt4
 
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index bef0d993..d3589b70 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -4,39 +4,27 @@ from typing import Any, Coroutine, Dict, Generator, List, Union
 
 import aiohttp
 from ...core.main import ChatMessage
-import openai
 from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
-import certifi
-import ssl
-
-ca_bundle_path = certifi.where()
-ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
 
 SERVER_URL = "http://localhost:8000"
 
 
 class GGML(LLM):
-    api_key: str
-    default_model: str
 
-    def __init__(self, api_key: str, default_model: str, system_message: str = None):
-        self.api_key = api_key
-        self.default_model = default_model
+    def __init__(self, system_message: str = None):
         self.system_message = system_message
 
-        openai.api_key = api_key
-
     @cached_property
     def name(self):
-        return self.default_model
+        return "ggml"
 
     @property
     def default_args(self):
-        return {**DEFAULT_ARGS, "model": self.default_model}
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
 
     def count_tokens(self, text: str):
-        return count_tokens(self.default_model, text)
+        return count_tokens(self.name, text)
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
@@ -45,9 +33,9 @@ class GGML(LLM):
 
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
 
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
                 "messages": messages,
                 **args
@@ -62,10 +50,10 @@ class GGML(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
         args["stream"] = True
 
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
                 "messages": messages,
                 **args
@@ -77,7 +65,6 @@ class GGML(LLM):
                             json_chunk = line[0].decode("utf-8")
                             if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
                                 continue
-                            json_chunk = "{}" if json_chunk == "" else json_chunk
                             chunks = json_chunk.split("\n")
                             for chunk in chunks:
                                 if chunk.strip() != "":
@@ -88,7 +75,7 @@ class GGML(LLM):
     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
 
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
                 "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
                 **args
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 2c9d8c01..d5a7cd9a 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -192,7 +192,8 @@ class DefaultModelEditCodeStep(Step):
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
         model_to_use = sdk.models.default
-        max_tokens = DEFAULT_MAX_TOKENS
+        max_tokens = MAX_TOKENS_FOR_MODEL.get(
+            model_to_use.name, DEFAULT_MAX_TOKENS) / 2
 
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
         if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -498,7 +499,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user
 
         if isinstance(model_to_use, GGML):
             messages = [ChatMessage(
-                role="user", content=f"```\n{rif.contents}\n```\n{self.user_input}\n```\n", summary=self.user_input)]
+                role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)]
 
         generator = model_to_use.stream_chat(
             messages, temperature=0, max_tokens=max_tokens)
-- 
cgit v1.2.3-70-g09d2


From 9687c05a5c8d6aeb15e7386129cdb16c0255b56e Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Sun, 16 Jul 2023 21:08:13 -0700
Subject: quick fix for quick fix

---
 extension/src/bridge.ts                  | 6 +-----
 extension/src/lang-server/codeActions.ts | 6 ++++--
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/extension/src/bridge.ts b/extension/src/bridge.ts
index 7e6398be..d614ace4 100644
--- a/extension/src/bridge.ts
+++ b/extension/src/bridge.ts
@@ -1,11 +1,7 @@
 import fetch from "node-fetch";
 import * as path from "path";
 import * as vscode from "vscode";
-import {
-  Configuration,
-  DebugApi,
-  UnittestApi,
-} from "./client";
+import { Configuration, DebugApi, UnittestApi } from "./client";
 import { convertSingleToDoubleQuoteJSON } from "./util/util";
 import { getExtensionUri } from "./util/vscode";
 import { extensionContext } from "./activation/activate";
diff --git a/extension/src/lang-server/codeActions.ts b/extension/src/lang-server/codeActions.ts
index 07cf5f4e..f0d61ace 100644
--- a/extension/src/lang-server/codeActions.ts
+++ b/extension/src/lang-server/codeActions.ts
@@ -23,8 +23,10 @@ class ContinueQuickFixProvider implements vscode.CodeActionProvider {
     );
     quickFix.isPreferred = false;
     const surroundingRange = new vscode.Range(
-      range.start.translate(-3, 0),
-      range.end.translate(3, 0)
+      Math.max(0, range.start.line - 3),
+      0,
+      Math.min(document.lineCount, range.end.line + 3),
+      0
    );
     quickFix.command = {
       command: "continue.quickFix",
-- 
cgit v1.2.3-70-g09d2
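
For readers trying these patches against a local ggml server, the sketch below illustrates the stream-parsing convention that the patched GGML.stream_chat relies on: the server replies with "data: {...}" lines (plus ": ping" keep-alives and a final "data: [DONE]") whose JSON chunks carry the next choices[0]["delta"]. This is a standalone, illustrative client, not code from the patches; the localhost:8000 URL and an OpenAI-compatible /v1/chat/completions endpoint are assumptions taken from the diff, and the helper names are hypothetical.

# Minimal sketch (assumed server behavior, not part of the commits above).
import asyncio
import json

import aiohttp

SERVER_URL = "http://localhost:8000"  # assumed local OpenAI-compatible server


async def stream_chat_deltas(messages, **extra_args):
    """Yield the `delta` dict of each streamed chat-completion chunk."""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{SERVER_URL}/v1/chat/completions",
            json={"messages": messages, "stream": True, **extra_args},
        ) as resp:
            # Each network chunk may hold several newline-separated "data: ..." lines.
            async for data, _ in resp.content.iter_chunks():
                for line in data.decode("utf-8").split("\n"):
                    line = line.strip()
                    # Skip blanks, keep-alive pings, and the end-of-stream marker.
                    if not line or line.startswith(": ping") or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        chunk = json.loads(line[len("data: "):])
                        yield chunk["choices"][0]["delta"]


async def main():
    messages = [{"role": "user", "content": "Write a haiku about diffs."}]
    async for delta in stream_chat_deltas(messages, max_tokens=256, temperature=0):
        print(delta.get("content", ""), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())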