author    Nate Sesti <sestinj@gmail.com>  2023-09-06 20:50:15 -0700
committer Nate Sesti <sestinj@gmail.com>  2023-09-06 20:50:15 -0700
commit    db19f6bc98285d8ea45b4db16f619dffbec7c3db (patch)
tree      bf02cde62b76fd62aad4852fd5ee42a93a6df746 /continuedev
parent    d8e821e422678fd4248b472c7f3e67a32ecfefb5 (diff)
download  sncontinue-db19f6bc98285d8ea45b4db16f619dffbec7c3db.tar.gz
          sncontinue-db19f6bc98285d8ea45b4db16f619dffbec7c3db.tar.bz2
          sncontinue-db19f6bc98285d8ea45b4db16f619dffbec7c3db.zip
fix: :adhesive_bandage: allow GGML to use api.openai.com
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py   | 77
-rw-r--r--  continuedev/src/continuedev/tests/llm_test.py  | 23
2 files changed, 65 insertions(+), 35 deletions(-)
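The change adds bearer-token auth headers, an optional CA bundle for TLS, and
chat-endpoint routing for models listed in CHAT_MODELS, which together let the
GGML class target api.openai.com as well as a local server. A minimal usage
sketch (not part of the commit; it mirrors the updated test below and assumes
OPENAI_API_KEY is set in the environment):

    import os

    from continuedev.libs.llm.ggml import GGML

    llm = GGML(
        model="gpt-3.5-turbo",  # listed in CHAT_MODELS, so the chat endpoint is used
        server_url="https://api.openai.com",
        api_key=os.environ["OPENAI_API_KEY"],
    )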
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index e4971867..a183e643 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,4 +1,5 @@
import json
+import ssl
from typing import Any, Coroutine, List, Optional
import aiohttp
@@ -6,12 +7,15 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
from ..util.logging import logger
+from . import CompletionOptions
+from .openai import CHAT_MODELS
from .prompts.edit import simplified_edit_prompt
class GGML(LLM):
server_url: str = "http://localhost:8000"
verify_ssl: Optional[bool] = None
+    ca_bundle_path: Optional[str] = None
model: str = "ggml"
prompt_templates = {
@@ -21,13 +25,33 @@ class GGML(LLM):
class Config:
arbitrary_types_allowed = True
- async def _stream_complete(self, prompt, options):
- args = self.collect_args(options)
+ def create_client_session(self):
+        if self.ca_bundle_path is not None:
+ ssl_context = ssl.create_default_context(cafile=self.ca_bundle_path)
+ tcp_connector = aiohttp.TCPConnector(
+ verify_ssl=self.verify_ssl, ssl=ssl_context
+ )
+ else:
+ tcp_connector = aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
- async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
+ return aiohttp.ClientSession(
+ connector=tcp_connector,
timeout=aiohttp.ClientTimeout(total=self.timeout),
- ) as client_session:
+ )
+
+ def get_headers(self):
+ headers = {
+ "Content-Type": "application/json",
+ }
+ if self.api_key is not None:
+ headers["Authorization"] = f"Bearer {self.api_key}"
+
+ return headers
+
+ async def _raw_stream_complete(self, prompt, options):
+ args = self.collect_args(options)
+
+ async with self.create_client_session() as client_session:
async with client_session.post(
f"{self.server_url}/v1/completions",
json={
@@ -35,6 +59,7 @@ class GGML(LLM):
"stream": True,
**args,
},
+ headers=self.get_headers(),
) as resp:
async for line in resp.content.iter_any():
if line:
@@ -54,14 +79,11 @@ class GGML(LLM):
args = self.collect_args(options)
async def generator():
- async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
- timeout=aiohttp.ClientTimeout(total=self.timeout),
- ) as client_session:
+ async with self.create_client_session() as client_session:
async with client_session.post(
f"{self.server_url}/v1/chat/completions",
json={"messages": messages, "stream": True, **args},
- headers={"Content-Type": "application/json"},
+ headers=self.get_headers(),
) as resp:
async for line, end in resp.content.iter_chunks():
json_chunk = line.decode("utf-8")
@@ -87,19 +109,17 @@ class GGML(LLM):
async for chunk in generator():
yield chunk
- async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
+ async def _raw_complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
args = self.collect_args(options)
- async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
- timeout=aiohttp.ClientTimeout(total=self.timeout),
- ) as client_session:
+ async with self.create_client_session() as client_session:
async with client_session.post(
f"{self.server_url}/v1/completions",
json={
"prompt": prompt,
**args,
},
+ headers=self.get_headers(),
) as resp:
text = await resp.text()
try:
@@ -109,3 +129,30 @@ class GGML(LLM):
raise Exception(
f"Error calling /completion endpoint: {e}\n\nResponse text: {text}"
)
+
+ async def _complete(self, prompt: str, options: CompletionOptions):
+ completion = ""
+ if self.model in CHAT_MODELS:
+ async for chunk in self._stream_chat(
+ [{"role": "user", "content": prompt}], options
+ ):
+ if "content" in chunk:
+ completion += chunk["content"]
+
+ else:
+ async for chunk in self._raw_stream_complete(prompt, options):
+ completion += chunk
+
+ return completion
+
+ async def _stream_complete(self, prompt, options: CompletionOptions):
+ if self.model in CHAT_MODELS:
+ async for chunk in self._stream_chat(
+ [{"role": "user", "content": prompt}], options
+ ):
+ if "content" in chunk:
+ yield chunk["content"]
+
+ else:
+ async for chunk in self._raw_stream_complete(prompt, options):
+ yield chunk
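The new ca_bundle_path field covers self-hosted servers presenting certificates
from a private CA. A hedged sketch (host and bundle path are hypothetical; note
that newer aiohttp deprecates TCPConnector's verify_ssl in favor of ssl, so
prefer setting only one of verify_ssl / ca_bundle_path):

    # Hypothetical: self-hosted server behind a corporate proxy with a private CA.
    llm = GGML(
        server_url="https://llm.internal.example:8000",  # hypothetical host
        ca_bundle_path="/etc/ssl/certs/private-ca.pem",  # hypothetical bundle path
    )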
diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py
index f4aea1fb..8c4fe0c6 100644
--- a/continuedev/src/continuedev/tests/llm_test.py
+++ b/continuedev/src/continuedev/tests/llm_test.py
@@ -12,7 +12,6 @@ from continuedev.libs.llm.ggml import GGML
from continuedev.libs.llm.openai import OpenAI
from continuedev.libs.llm.together import TogetherLLM
from continuedev.libs.util.count_tokens import DEFAULT_ARGS
-from continuedev.tests.util.openai_mock import start_openai
from continuedev.tests.util.prompts import tokyo_test_pair
load_dotenv()
@@ -140,30 +139,14 @@ class TestOpenAI(TestBaseLLM):
class TestGGML(TestBaseLLM):
def setup_class(cls):
super().setup_class(cls)
- port = 8000
cls.llm = GGML(
- model=cls.model,
+ model="gpt-3.5-turbo",
context_length=cls.context_length,
system_message=cls.system_message,
- api_base=f"http://localhost:{port}",
+ server_url="https://api.openai.com",
+ api_key=os.environ["OPENAI_API_KEY"],
)
start_model(cls.llm)
- cls.server = start_openai(port=port)
-
- def teardown_class(cls):
- cls.server.terminate()
-
- @pytest.mark.asyncio
- async def test_stream_chat(self):
- pytest.skip(reason="GGML is not working")
-
- @pytest.mark.asyncio
- async def test_stream_complete(self):
- pytest.skip(reason="GGML is not working")
-
- @pytest.mark.asyncio
- async def test_completion(self):
- pytest.skip(reason="GGML is not working")
@pytest.mark.skipif(True, reason="Together is not working")
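With the local mock server removed, TestGGML talks to api.openai.com directly,
so running it requires a real key exported in the environment, e.g.:

    OPENAI_API_KEY=sk-... pytest continuedev/src/continuedev/tests/llm_test.py -k TestGGML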