| | | |
|---|---|---|
| author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-07-16 21:09:48 -0700 |
| committer | GitHub <noreply@github.com> | 2023-07-16 21:09:48 -0700 |
| commit | d4319f09a3b8c1b0d9d1a7178910f09eac01fce9 (patch) | |
| tree | be5ef32725ab4ad0cba85bcc5415a79388fb6da0 /continuedev/src | |
| parent | a4a815628f702af806603015ec6805edd151328b (diff) | |
| parent | 9687c05a5c8d6aeb15e7386129cdb16c0255b56e (diff) | |
Merge pull request #278 from continuedev/ggml-server
ggml server
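
This merge adds a `GGML` class under `libs/llm` that talks to an OpenAI-compatible server hard-coded at `http://localhost:8000`, registers `"ggml"` as a valid `default_model` value, and switches the chat and edit steps from `sdk.models.gpt4` to `sdk.models.default` so the configured model is actually used. Below is a minimal sketch of opting in from a user config; the `config = ContinueConfig(...)` export pattern and the import path are assumptions about the surrounding setup, not part of this diff.

```python
# Sketch only: assumes Continue loads a user config module that exports a
# `config = ContinueConfig(...)` object; the import path mirrors the repository
# layout and may differ in an installed package.
from continuedev.src.continuedev.core.config import ContinueConfig

config = ContinueConfig(
    # "ggml" is the new Literal option added to default_model in core/config.py.
    # With it selected, Models.__model_from_name() returns the GGML client.
    default_model="ggml",
)
```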
Diffstat (limited to 'continuedev/src')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 2 |
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 7 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 86 |
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 7 |
| -rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 2 |
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 10 |

6 files changed, 108 insertions, 6 deletions
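
Note that the new `ggml.py` client in the diff below does not bundle a model server: it assumes something OpenAI-compatible is already listening at `SERVER_URL = "http://localhost:8000"`, exposing `/v1/completions` and `/v1/chat/completions`, with the chat endpoint streaming `data: {...}` lines. The following is a rough, hedged probe of that contract, mirroring how `GGML.stream_chat()` parses the stream; exact server behavior (ping lines, chunk boundaries) will vary by implementation.

```python
# Rough sketch of the wire format GGML.stream_chat() expects (see ggml.py below).
# Assumption: an OpenAI-compatible ggml server is already running on
# localhost:8000; this PR does not start one.
import asyncio
import json

import aiohttp


async def probe():
    payload = {
        "model": "ggml",
        "messages": [{"role": "user", "content": "Say hello"}],
        "max_tokens": 64,
        "stream": True,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:8000/v1/chat/completions", json=payload
        ) as resp:
            async for data, _ in resp.content.iter_chunks():
                # Chunks arrive as "data: {...}" lines; GGML.stream_chat() strips
                # the "data: " prefix with chunk[6:] before json.loads().
                for line in data.decode("utf-8").split("\n"):
                    if line.startswith("data: ") and not line.startswith("data: [DONE]"):
                        delta = json.loads(line[6:])["choices"][0]["delta"]
                        print(delta.get("content", ""), end="", flush=True)


asyncio.run(probe())
```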
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 6e430c04..957609c5 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -76,7 +76,7 @@ class ContinueConfig(BaseModel):
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
     default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4"] = 'gpt-4'
+                           "gpt-4", "ggml"] = 'gpt-4'
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
         description="This is an example custom command. Use /config to edit it and create more",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index d73561d2..eb60109c 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -11,6 +11,7 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
 from ..models.filesystem import RangeInFile
 from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
+from ..libs.llm.ggml import GGML
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
 from .main import Context, ContinueCustomException, History, Step, ChatMessage
@@ -79,6 +80,10 @@ class Models:
     def gpt4(self):
         return self.__load_openai_model("gpt-4")

+    @cached_property
+    def ggml(self):
+        return GGML()
+
     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
             return self.starcoder
@@ -88,6 +93,8 @@ class Models:
             return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
+        elif model_name == "ggml":
+            return self.ggml
         else:
             raise Exception(f"Unknown model {model_name}")

diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
new file mode 100644
index 00000000..d3589b70
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -0,0 +1,86 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+SERVER_URL = "http://localhost:8000"
+
+
+class GGML(LLM):
+
+    def __init__(self, system_message: str = None):
+        self.system_message = system_message
+
+    @cached_property
+    def name(self):
+        return "ggml"
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{SERVER_URL}/v1/completions", json={
+                "messages": messages,
+                **args
+            }) as resp:
+                async for line in resp.content.iter_any():
+                    if line:
+                        try:
+                            yield line.decode("utf-8")
+                        except:
+                            raise Exception(str(line))
+
+    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
+        args["stream"] = True
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
+                "messages": messages,
+                **args
+            }) as resp:
+                # This is streaming application/json instaed of text/event-stream
+                async for line in resp.content.iter_chunks():
+                    if line[1]:
+                        try:
+                            json_chunk = line[0].decode("utf-8")
+                            if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+                                continue
+                            chunks = json_chunk.split("\n")
+                            for chunk in chunks:
+                                if chunk.strip() != "":
+                                    yield json.loads(chunk[6:])["choices"][0]["delta"]
+                        except:
+                            raise Exception(str(line[0]))
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{SERVER_URL}/v1/completions", json={
+                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+                **args
+            }) as resp:
+                try:
+                    return await resp.text()
+                except:
+                    raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 73be0717..e1baeca1 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -3,13 +3,16 @@ from typing import Dict, List, Union
 from ...core.main import ChatMessage
 import tiktoken

-aliases = {}
+aliases = {
+    "ggml": "gpt-3.5-turbo",
+}
 DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-0613": 4096,
     "gpt-3.5-turbo-16k": 16384,
-    "gpt-4": 8192
+    "gpt-4": 8192,
+    "ggml": 2048
 }
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 14a1cd41..3751dec2 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -28,7 +28,7 @@ class SimpleChatStep(Step):
         completion = ""
         messages = self.messages or await sdk.get_chat_context()

-        generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5)
+        generator = sdk.models.default.stream_chat(messages, temperature=0.5)
         try:
             async for chunk in generator:
                 if sdk.current_step_was_deleted():
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 90d64287..d5a7cd9a 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -4,6 +4,7 @@ import subprocess
 from textwrap import dedent
 from typing import Coroutine, List, Literal, Union

+from ...libs.llm.ggml import GGML
 from ...models.main import Range
 from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
 from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
@@ -190,8 +191,9 @@ class DefaultModelEditCodeStep(Step):
         # We don't know here all of the functions being passed in.
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
-        model_to_use = sdk.models.gpt4
-        max_tokens = DEFAULT_MAX_TOKENS
+        model_to_use = sdk.models.default
+        max_tokens = MAX_TOKENS_FOR_MODEL.get(
+            model_to_use.name, DEFAULT_MAX_TOKENS) / 2

         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
         if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -495,6 +497,10 @@ Please output the code to be inserted at the cursor in order to fulfill the user
         repeating_file_suffix = False
         line_below_highlighted_range = file_suffix.lstrip().split("\n")[0]

+        if isinstance(model_to_use, GGML):
+            messages = [ChatMessage(
+                role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)]
+
         generator = model_to_use.stream_chat(
             messages, temperature=0, max_tokens=max_tokens)
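
For a quick end-to-end check of the new class when running this branch locally, something like the sketch below should work. The chat history and token budget are arbitrary illustration values; the only real requirement, again, is a compatible server on localhost:8000, and the import paths assume the package is importable from the repository root.

```python
# Minimal sketch: drives the GGML client added in libs/llm/ggml.py directly.
# Assumptions: the continuedev package is importable, and an OpenAI-compatible
# server is serving on http://localhost:8000 (the hard-coded SERVER_URL).
import asyncio

from continuedev.src.continuedev.core.main import ChatMessage
from continuedev.src.continuedev.libs.llm.ggml import GGML


async def main():
    llm = GGML()
    history = [ChatMessage(role="user",
                           content="Write a haiku about local models",
                           summary="haiku request")]
    # stream_chat() yields the parsed "delta" dicts from each streamed chunk,
    # e.g. {"role": "assistant"} or {"content": "..."}.
    async for delta in llm.stream_chat(history, max_tokens=128):
        print(delta.get("content", ""), end="", flush=True)


asyncio.run(main())
```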
