| author | Nate Sesti <sestinj@gmail.com> | 2023-07-16 16:55:24 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-16 16:55:24 -0700 | 
| commit | a3f4a2a59d6785499f3ce0c4af80b57b02de1b1f (patch) | |
| tree | bcf02a1382641ed6cbea226d7b1cbdeeeadb5bf9 /continuedev | |
| parent | 4c3a25a1c8938f8132233e021c74d98eb19d7ddd (diff) | |
| download | sncontinue-a3f4a2a59d6785499f3ce0c4af80b57b02de1b1f.tar.gz sncontinue-a3f4a2a59d6785499f3ce0c4af80b57b02de1b1f.tar.bz2 sncontinue-a3f4a2a59d6785499f3ce0c4af80b57b02de1b1f.zip | |
better prompt for editing
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 2 |
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 5 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 33 |
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 5 |
4 files changed, 17 insertions, 28 deletions
```diff
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 6e430c04..957609c5 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -76,7 +76,7 @@ class ContinueConfig(BaseModel):
     server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
     default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4"] = 'gpt-4'
+                           "gpt-4", "ggml"] = 'gpt-4'
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
         description="This is an example custom command. Use /config to edit it and create more",
```

```diff
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 9389e1e9..eb60109c 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -82,7 +82,7 @@ class Models:
     @cached_property
     def ggml(self):
-        return GGML("", "ggml")
+        return GGML()
     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
@@ -93,12 +93,13 @@ class Models:
             return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
+        elif model_name == "ggml":
+            return self.ggml
         else:
             raise Exception(f"Unknown model {model_name}")
     @property
     def default(self):
-        return self.ggml
         default_model = self.sdk.config.default_model
         return self.__model_from_name(default_model) if default_model is not None else self.gpt4
```
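With "ggml" added to the `default_model` literal and wired into `__model_from_name`, a user-level config can route the default model to the local ggml server. A minimal sketch of what that selection could look like; the import path and field names follow the diff, but the snippet itself is illustrative rather than part of this commit:

```python
# Illustrative config snippet (not from this commit): selects the new "ggml"
# option so Models.default resolves to the GGML client via __model_from_name.
# Assumes the repository root is importable under this package path.
from continuedev.src.continuedev.core.config import ContinueConfig

config = ContinueConfig(
    allow_anonymous_telemetry=True,
    default_model="ggml",  # a valid Literal value as of this change
)
```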
```diff
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index bef0d993..d3589b70 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -4,39 +4,27 @@ from typing import Any, Coroutine, Dict, Generator, List, Union
 import aiohttp
 from ...core.main import ChatMessage
-import openai
 from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
-import certifi
-import ssl
-
-ca_bundle_path = certifi.where()
-ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
 SERVER_URL = "http://localhost:8000"
 class GGML(LLM):
-    api_key: str
-    default_model: str
-    def __init__(self, api_key: str, default_model: str, system_message: str = None):
-        self.api_key = api_key
-        self.default_model = default_model
+    def __init__(self, system_message: str = None):
         self.system_message = system_message
-        openai.api_key = api_key
-
     @cached_property
     def name(self):
-        return self.default_model
+        return "ggml"
     @property
     def default_args(self):
-        return {**DEFAULT_ARGS, "model": self.default_model}
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
     def count_tokens(self, text: str):
-        return count_tokens(self.default_model, text)
+        return count_tokens(self.name, text)
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = self.default_args.copy()
@@ -45,9 +33,9 @@ class GGML(LLM):
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+            self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
                 "messages": messages,
                 **args
@@ -62,10 +50,10 @@ class GGML(LLM):
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+            self.name, messages, args["max_tokens"], None, functions=args.get("functions", None))
         args["stream"] = True
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
                 "messages": messages,
                 **args
@@ -77,7 +65,6 @@ class GGML(LLM):
                             json_chunk = line[0].decode("utf-8")
                             if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
                                 continue
-                            json_chunk = "{}" if json_chunk == "" else json_chunk
                             chunks = json_chunk.split("\n")
                             for chunk in chunks:
                                 if chunk.strip() != "":
@@ -88,7 +75,7 @@ class GGML(LLM):
     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with aiohttp.ClientSession() as session:
             async with session.post(f"{SERVER_URL}/v1/completions", json={
                 "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
                 **args
```
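The slimmed-down `GGML` class drops the OpenAI key and SSL plumbing and simply posts OpenAI-style payloads to a local server. A minimal sketch of that request shape outside the class, assuming the server exposes a non-streaming `/v1/chat/completions` route at the same `SERVER_URL` (error handling omitted):

```python
import asyncio
import aiohttp

SERVER_URL = "http://localhost:8000"  # same default the GGML class targets

async def chat_once(messages):
    # Plain ClientSession with no custom SSL context, mirroring the change above,
    # since the endpoint is assumed to be a local, non-TLS ggml server.
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
            "messages": messages,
            "model": "ggml",
            "max_tokens": 1024,  # the new default_args cap from the diff
        }) as resp:
            return await resp.json()

if __name__ == "__main__":
    reply = asyncio.run(chat_once([{"role": "user", "content": "Hello"}]))
    print(reply)
```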
```diff
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 2c9d8c01..d5a7cd9a 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -192,7 +192,8 @@ class DefaultModelEditCodeStep(Step):
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
         model_to_use = sdk.models.default
-        max_tokens = DEFAULT_MAX_TOKENS
+        max_tokens = MAX_TOKENS_FOR_MODEL.get(
+            model_to_use.name, DEFAULT_MAX_TOKENS) / 2
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
         if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -498,7 +499,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user
         if isinstance(model_to_use, GGML):
             messages = [ChatMessage(
-                role="user", content=f"```\n{rif.contents}\n```\n{self.user_input}\n```\n", summary=self.user_input)]
+                role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)]
         generator = model_to_use.stream_chat(
             messages, temperature=0, max_tokens=max_tokens)
```
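The reworded GGML edit prompt wraps the selected region and the user request, then leaves an open code fence so the model continues with the rewritten code. A small sketch of how the f-string above renders, with made-up values standing in for `rif.contents` and `self.user_input`:

```python
# Hypothetical inputs, used only to show the rendered prompt text.
rif_contents = "def add(a, b):\n    return a - b"
user_input = "fix the bug in add"

prompt = (
    f"```\n{rif_contents}\n```\n\n"
    f"User request: \"{user_input}\"\n\n"
    "This is the code after changing to perfectly comply with the user request. "
    "It does not include any placeholder code, only real implementations:\n\n"
    "```\n"
)
print(prompt)
```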
