Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/sdk.py               9
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py      13
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  84
3 files changed, 98 insertions, 8 deletions
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 1da190ff..b806ef73 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -15,6 +15,7 @@ from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
 from .main import Context, ContinueCustomException, History, Step, ChatMessage, ChatMessageRole
 from ..steps.core.core import *
+from ..libs.llm.proxy_server import ProxyServer
 
 
 class Autopilot:
@@ -37,7 +38,9 @@ class Models:
     def gpt35(self):
         async def load_gpt35():
             api_key = await self.sdk.get_user_secret(
-                'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file')
+                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+            if api_key == "":
+                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo")
             return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo")
 
         return asyncio.get_event_loop().run_until_complete(load_gpt35())
@@ -45,7 +48,9 @@ class Models:
     def gpt4(self):
         async def load_gpt4():
             api_key = await self.sdk.get_user_secret(
-                'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file')
+                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+            if api_key == "":
+                return ProxyServer(self.sdk.ide.unique_id, "gpt-4")
             return OpenAI(api_key=api_key, default_model="gpt-4")
 
         return asyncio.get_event_loop().run_until_complete(load_gpt4())
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4889c01e..2986b2c4 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,24 +1,25 @@
-from typing import List, Union
+from abc import ABC
+from typing import Any, Dict, Generator, List, Union
 
 from ...core.main import ChatMessage
 from ...models.main import AbstractModel
 from pydantic import BaseModel
 
 
-class LLM(BaseModel):
+class LLM(ABC):
     system_message: Union[str, None] = None
 
     def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
         """Return the completion of the text with the given temperature."""
         raise
 
+    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        """Yield a stream of chat messages."""
+        raise NotImplementedError
+
     def __call__(self, prompt: str, **kwargs):
         return self.complete(prompt, **kwargs)
 
-    def fine_tune(self, pairs: list):
-        """Fine tune the model on the given prompt/completion pairs."""
-        raise NotImplementedError
-
     def with_system_message(self, system_message: Union[str, None]):
         """Return a new model with the given system message."""
         raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
new file mode 100644
index 00000000..f75788d2
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -0,0 +1,84 @@
+from functools import cached_property
+import json
+from typing import Any, Dict, Generator, List, Literal, Union
+import requests
+import tiktoken
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+
+MAX_TOKENS_FOR_MODEL = {
+    "gpt-3.5-turbo": 4097,
+    "gpt-4": 4097,
+}
+DEFAULT_MAX_TOKENS = 2048
+CHAT_MODELS = {
+    "gpt-3.5-turbo", "gpt-4"
+}
+
+SERVER_URL = "http://127.0.0.1:8002"
+
+
+class ProxyServer(LLM):
+    unique_id: str
+    default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+
+    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):
+        self.unique_id = unique_id
+        self.default_model = default_model
+        self.system_message = system_message
+
+    @cached_property
+    def __encoding_for_model(self):
+        aliases = {
+            "gpt-3.5-turbo": "gpt3"
+        }
+        return tiktoken.encoding_for_model(self.default_model)
+
+    def count_tokens(self, text: str):
+        return len(self.__encoding_for_model.encode(text, disallowed_special=()))
+
+    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        resp = requests.post(f"{SERVER_URL}/stream_complete", json={
+            "chat_history": self.compile_chat_messages(with_history, prompt),
+            "model": self.default_model,
+            "unique_id": self.unique_id,
+        }, stream=True)
+        for line in resp.iter_lines():
+            if line:
+                yield line.decode("utf-8")
+
+    def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+        tokens = tokens_for_completion
+        for i in range(len(chat_history) - 1, -1, -1):
+            message = chat_history[i]
+            tokens += self.count_tokens(message.content)
+            if tokens > max_tokens:
+                return chat_history[i + 1:]
+        return chat_history
+
+    def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]:
+        msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens(
+            prompt) + 1000 + self.count_tokens(self.system_message or ""))
+        history = []
+        if self.system_message:
+            history.append({
+                "role": "system",
+                "content": self.system_message
+            })
+        history += [msg.dict() for msg in msgs]
+        history.append({
+            "role": "user",
+            "content": prompt
+        })
+
+        return history
+
+    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
+
+        resp = requests.post(f"{SERVER_URL}/complete", json={
+            "chat_history": self.compile_chat_messages(with_history, prompt),
+            "model": self.default_model,
+            "unique_id": self.unique_id,
+        })
+        return json.loads(resp.text)
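
The most intricate piece of the new file is __prune_chat_history, which walks the history from newest to oldest, accumulating token counts on top of the tokens already reserved for the completion, and drops the oldest messages once the budget is exceeded. A minimal standalone sketch of the same idea, with an illustrative Msg type and a whitespace-based token counter standing in for the tiktoken encoder (both are stand-ins, not part of the diff):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Msg:
        role: str
        content: str


    def count_tokens(text: str) -> int:
        # Stand-in for the tiktoken-based counter in the diff; a plain
        # whitespace split is enough to illustrate the pruning logic.
        return len(text.split())


    def prune_chat_history(history: List[Msg], max_tokens: int,
                           tokens_for_completion: int) -> List[Msg]:
        # Walk from the newest message backwards, accumulating tokens on
        # top of the budget reserved for the completion; once the limit
        # is exceeded, keep only the newer messages that still fit.
        tokens = tokens_for_completion
        for i in range(len(history) - 1, -1, -1):
            tokens += count_tokens(history[i].content)
            if tokens > max_tokens:
                return history[i + 1:]
        return history


    history = [
        Msg("user", "one two three"),
        Msg("assistant", "four five"),
        Msg("user", "six"),
    ]
    # Budget of 5 tokens, 2 reserved for the completion: the oldest
    # message (3 tokens) no longer fits, so only the last two survive.
    print([m.content for m in prune_chat_history(history, 5, 2)])
    # -> ['four five', 'six']

This is why compile_chat_messages passes count_tokens(prompt) + 1000 + count_tokens(system_message) as tokens_for_completion: the prompt, a fixed 1000-token response headroom, and the system message are all charged against the model's context window before any history is admitted.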
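stream_chat simply re-yields whatever requests' iter_lines produces from the proxy's /stream_complete endpoint, decoded as UTF-8, so a caller can iterate it directly. A hedged usage sketch, assuming the package-relative import path below and that the endpoint emits plain-text chunks one per line (inferred from the use of iter_lines, not documented in the diff):

    # Assumes the proxy is reachable at SERVER_URL and that ProxyServer
    # is importable from this path; both are assumptions, not confirmed
    # by the diff itself.
    from continuedev.libs.llm.proxy_server import ProxyServer

    llm = ProxyServer(unique_id="example-id", default_model="gpt-3.5-turbo")

    # Each yielded chunk is one decoded line of the streamed response.
    for chunk in llm.stream_chat("Write a haiku about version control"):
        print(chunk, end="", flush=True)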
