author    | Nate Sesti <sestinj@gmail.com> | 2023-09-06 20:50:15 -0700
committer | Nate Sesti <sestinj@gmail.com> | 2023-09-06 20:50:15 -0700
commit    | db19f6bc98285d8ea45b4db16f619dffbec7c3db (patch)
tree      | bf02cde62b76fd62aad4852fd5ee42a93a6df746 /continuedev
parent    | d8e821e422678fd4248b472c7f3e67a32ecfefb5 (diff)
download  | sncontinue-db19f6bc98285d8ea45b4db16f619dffbec7c3db.tar.gz
          | sncontinue-db19f6bc98285d8ea45b4db16f619dffbec7c3db.tar.bz2
          | sncontinue-db19f6bc98285d8ea45b4db16f619dffbec7c3db.zip
fix: :adhesive_bandage: allow GGML to use api.openai.com
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py  | 77
-rw-r--r-- | continuedev/src/continuedev/tests/llm_test.py | 23
2 files changed, 65 insertions, 35 deletions
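In short, the GGML class gains an Authorization header, an optional CA bundle, and routing between the chat and raw completions endpoints, so it can talk to any OpenAI-compatible server, including api.openai.com itself. A minimal sketch of the resulting configuration, mirroring the updated test below (the OPENAI_API_KEY environment variable comes from the test; everything else follows the new ggml.py):

    import os

    from continuedev.libs.llm.ggml import GGML

    # GGML now accepts an API key and sends it as a Bearer token via
    # get_headers(); with a chat model name (one of CHAT_MODELS), requests
    # are routed to /v1/chat/completions instead of raw /v1/completions.
    llm = GGML(
        model="gpt-3.5-turbo",
        server_url="https://api.openai.com",
        api_key=os.environ["OPENAI_API_KEY"],
    )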
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index e4971867..a183e643 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,4 +1,5 @@
 import json
+import ssl
 from typing import Any, Coroutine, List, Optional
 
 import aiohttp
@@ -6,12 +7,15 @@ import aiohttp
 from ...core.main import ChatMessage
 from ..llm import LLM
 from ..util.logging import logger
+from . import CompletionOptions
+from .openai import CHAT_MODELS
 from .prompts.edit import simplified_edit_prompt
 
 
 class GGML(LLM):
     server_url: str = "http://localhost:8000"
     verify_ssl: Optional[bool] = None
+    ca_bundle_path: Optional[str] = None
     model: str = "ggml"
 
     prompt_templates = {
@@ -21,13 +25,33 @@ class GGML(LLM):
     class Config:
         arbitrary_types_allowed = True
 
-    async def _stream_complete(self, prompt, options):
-        args = self.collect_args(options)
+    def create_client_session(self):
+        if self.ca_bundle_path is not None:
+            ssl_context = ssl.create_default_context(cafile=self.ca_bundle_path)
+            tcp_connector = aiohttp.TCPConnector(
+                verify_ssl=self.verify_ssl, ssl=ssl_context
+            )
+        else:
+            tcp_connector = aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
 
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
+        return aiohttp.ClientSession(
+            connector=tcp_connector,
             timeout=aiohttp.ClientTimeout(total=self.timeout),
-        ) as client_session:
+        )
+
+    def get_headers(self):
+        headers = {
+            "Content-Type": "application/json",
+        }
+        if self.api_key is not None:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
+        return headers
+
+    async def _raw_stream_complete(self, prompt, options):
+        args = self.collect_args(options)
+
+        async with self.create_client_session() as client_session:
             async with client_session.post(
                 f"{self.server_url}/v1/completions",
                 json={
@@ -35,6 +59,7 @@ class GGML(LLM):
                     "stream": True,
                     **args,
                 },
+                headers=self.get_headers(),
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
@@ -54,14 +79,11 @@ class GGML(LLM):
         args = self.collect_args(options)
 
         async def generator():
-            async with aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-                timeout=aiohttp.ClientTimeout(total=self.timeout),
-            ) as client_session:
+            async with self.create_client_session() as client_session:
                 async with client_session.post(
                     f"{self.server_url}/v1/chat/completions",
                     json={"messages": messages, "stream": True, **args},
-                    headers={"Content-Type": "application/json"},
+                    headers=self.get_headers(),
                 ) as resp:
                     async for line, end in resp.content.iter_chunks():
                         json_chunk = line.decode("utf-8")
@@ -87,19 +109,17 @@ class GGML(LLM):
         async for chunk in generator():
             yield chunk
 
-    async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
+    async def _raw_complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
         args = self.collect_args(options)
 
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-            timeout=aiohttp.ClientTimeout(total=self.timeout),
-        ) as client_session:
+        async with self.create_client_session() as client_session:
             async with client_session.post(
                 f"{self.server_url}/v1/completions",
                 json={
                     "prompt": prompt,
                     **args,
                 },
+                headers=self.get_headers(),
             ) as resp:
                 text = await resp.text()
                 try:
@@ -109,3 +129,30 @@ class GGML(LLM):
                     raise Exception(
                         f"Error calling /completion endpoint: {e}\n\nResponse text: {text}"
                     )
+
+    async def _complete(self, prompt: str, options: CompletionOptions):
+        completion = ""
+        if self.model in CHAT_MODELS:
+            async for chunk in self._stream_chat(
+                [{"role": "user", "content": prompt}], options
+            ):
+                if "content" in chunk:
+                    completion += chunk["content"]
+
+        else:
+            async for chunk in self._raw_stream_complete(prompt, options):
+                completion += chunk
+
+        return completion
+
+    async def _stream_complete(self, prompt, options: CompletionOptions):
+        if self.model in CHAT_MODELS:
+            async for chunk in self._stream_chat(
+                [{"role": "user", "content": prompt}], options
+            ):
+                if "content" in chunk:
+                    yield chunk["content"]
+
+        else:
+            async for chunk in self._raw_stream_complete(prompt, options):
+                yield chunk
diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py
index f4aea1fb..8c4fe0c6 100644
--- a/continuedev/src/continuedev/tests/llm_test.py
+++ b/continuedev/src/continuedev/tests/llm_test.py
@@ -12,7 +12,6 @@ from continuedev.libs.llm.ggml import GGML
 from continuedev.libs.llm.openai import OpenAI
 from continuedev.libs.llm.together import TogetherLLM
 from continuedev.libs.util.count_tokens import DEFAULT_ARGS
-from continuedev.tests.util.openai_mock import start_openai
 from continuedev.tests.util.prompts import tokyo_test_pair
 
 load_dotenv()
@@ -140,30 +139,14 @@ class TestOpenAI(TestBaseLLM):
 class TestGGML(TestBaseLLM):
     def setup_class(cls):
         super().setup_class(cls)
-        port = 8000
         cls.llm = GGML(
-            model=cls.model,
+            model="gpt-3.5-turbo",
             context_length=cls.context_length,
             system_message=cls.system_message,
-            api_base=f"http://localhost:{port}",
+            server_url="https://api.openai.com",
+            api_key=os.environ["OPENAI_API_KEY"],
         )
         start_model(cls.llm)
-        cls.server = start_openai(port=port)
-
-    def teardown_class(cls):
-        cls.server.terminate()
-
-    @pytest.mark.asyncio
-    async def test_stream_chat(self):
-        pytest.skip(reason="GGML is not working")
-
-    @pytest.mark.asyncio
-    async def test_stream_complete(self):
-        pytest.skip(reason="GGML is not working")
-
-    @pytest.mark.asyncio
-    async def test_completion(self):
-        pytest.skip(reason="GGML is not working")
 
 
 @pytest.mark.skipif(True, reason="Together is not working")