Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py  | 77
-rw-r--r--  continuedev/src/continuedev/tests/llm_test.py | 23

2 files changed, 65 insertions(+), 35 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index e4971867..a183e643 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,4 +1,5 @@
 import json
+import ssl
 from typing import Any, Coroutine, List, Optional
 
 import aiohttp
@@ -6,12 +7,15 @@ import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
 from ..util.logging import logger
+from . import CompletionOptions
+from .openai import CHAT_MODELS
 from .prompts.edit import simplified_edit_prompt
 
 
 class GGML(LLM):
     server_url: str = "http://localhost:8000"
     verify_ssl: Optional[bool] = None
+    ca_bundle_path: Optional[str] = None
     model: str = "ggml"
 
     prompt_templates = {
@@ -21,13 +25,33 @@ class GGML(LLM):
     class Config:
         arbitrary_types_allowed = True
 
-    async def _stream_complete(self, prompt, options):
-        args = self.collect_args(options)
+    def create_client_session(self):
+        if self.ca_bundle_path is not None:
+            ssl_context = ssl.create_default_context(cafile=self.ca_bundle_path)
+            tcp_connector = aiohttp.TCPConnector(
+                verify_ssl=self.verify_ssl, ssl=ssl_context
+            )
+        else:
+            tcp_connector = aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
 
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
+        return aiohttp.ClientSession(
+            connector=tcp_connector,
             timeout=aiohttp.ClientTimeout(total=self.timeout),
-        ) as client_session:
+        )
+
+    def get_headers(self):
+        headers = {
+            "Content-Type": "application/json",
+        }
+        if self.api_key is not None:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
+        return headers
+
+    async def _raw_stream_complete(self, prompt, options):
+        args = self.collect_args(options)
+
+        async with self.create_client_session() as client_session:
             async with client_session.post(
                 f"{self.server_url}/v1/completions",
                 json={
@@ -35,6 +59,7 @@ class GGML(LLM):
                     "stream": True,
                     **args,
                 },
+                headers=self.get_headers(),
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
@@ -54,14 +79,11 @@ class GGML(LLM):
         args = self.collect_args(options)
 
         async def generator():
-            async with aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-                timeout=aiohttp.ClientTimeout(total=self.timeout),
-            ) as client_session:
+            async with self.create_client_session() as client_session:
                 async with client_session.post(
                     f"{self.server_url}/v1/chat/completions",
                     json={"messages": messages, "stream": True, **args},
-                    headers={"Content-Type": "application/json"},
+                    headers=self.get_headers(),
                 ) as resp:
                     async for line, end in resp.content.iter_chunks():
                         json_chunk = line.decode("utf-8")
@@ -87,19 +109,17 @@ class GGML(LLM):
             async for chunk in generator():
                 yield chunk
 
-    async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
+    async def _raw_complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
         args = self.collect_args(options)
 
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-            timeout=aiohttp.ClientTimeout(total=self.timeout),
-        ) as client_session:
+        async with self.create_client_session() as client_session:
             async with client_session.post(
                 f"{self.server_url}/v1/completions",
                 json={
                     "prompt": prompt,
                     **args,
                 },
+                headers=self.get_headers(),
             ) as resp:
                 text = await resp.text()
                 try:
@@ -109,3 +129,30 @@ class GGML(LLM):
                     raise Exception(
                         f"Error calling /completion endpoint: {e}\n\nResponse text: {text}"
                     )
+
+    async def _complete(self, prompt: str, options: CompletionOptions):
+        completion = ""
+        if self.model in CHAT_MODELS:
+            async for chunk in self._stream_chat(
+                [{"role": "user", "content": prompt}], options
+            ):
+                if "content" in chunk:
+                    completion += chunk["content"]
+
+        else:
+            async for chunk in self._raw_stream_complete(prompt, options):
+                completion += chunk
+
+        return completion
+
+    async def _stream_complete(self, prompt, options: CompletionOptions):
+        if self.model in CHAT_MODELS:
+            async for chunk in self._stream_chat(
+                [{"role": "user", "content": prompt}], options
+            ):
+                if "content" in chunk:
+                    yield chunk["content"]
+
+        else:
+            async for chunk in self._raw_stream_complete(prompt, options):
+                yield chunk
diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py
index f4aea1fb..8c4fe0c6 100644
--- a/continuedev/src/continuedev/tests/llm_test.py
+++ b/continuedev/src/continuedev/tests/llm_test.py
@@ -12,7 +12,6 @@ from continuedev.libs.llm.ggml import GGML
 from continuedev.libs.llm.openai import OpenAI
 from continuedev.libs.llm.together import TogetherLLM
 from continuedev.libs.util.count_tokens import DEFAULT_ARGS
-from continuedev.tests.util.openai_mock import start_openai
 from continuedev.tests.util.prompts import tokyo_test_pair
 
 load_dotenv()
@@ -140,30 +139,14 @@ class TestOpenAI(TestBaseLLM):
 class TestGGML(TestBaseLLM):
     def setup_class(cls):
         super().setup_class(cls)
-        port = 8000
         cls.llm = GGML(
-            model=cls.model,
+            model="gpt-3.5-turbo",
             context_length=cls.context_length,
             system_message=cls.system_message,
-            api_base=f"http://localhost:{port}",
+            server_url="https://api.openai.com",
+            api_key=os.environ["OPENAI_API_KEY"],
        )
         start_model(cls.llm)
-        cls.server = start_openai(port=port)
-
-    def teardown_class(cls):
-        cls.server.terminate()
-
-    @pytest.mark.asyncio
-    async def test_stream_chat(self):
-        pytest.skip(reason="GGML is not working")
-
-    @pytest.mark.asyncio
-    async def test_stream_complete(self):
-        pytest.skip(reason="GGML is not working")
-
-    @pytest.mark.asyncio
-    async def test_completion(self):
-        pytest.skip(reason="GGML is not working")
 
 
 @pytest.mark.skipif(True, reason="Together is not working")

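A usage sketch of the patched class, mirroring the updated `TestGGML.setup_class` above; the CA bundle path is hypothetical, and `gpt-3.5-turbo` is assumed to appear in `CHAT_MODELS`, so prompts are wrapped as a single user message and routed to `/v1/chat/completions` instead of `/v1/completions`.

```python
import os

from continuedev.libs.llm.ggml import GGML

llm = GGML(
    model="gpt-3.5-turbo",                 # assumed to be in CHAT_MODELS -> chat endpoint
    server_url="https://api.openai.com",   # any OpenAI-compatible server
    api_key=os.environ["OPENAI_API_KEY"],  # sent as "Authorization: Bearer <key>" by get_headers()
    ca_bundle_path="/etc/ssl/certs/internal-ca.pem",  # hypothetical custom CA bundle
)
```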