summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/libs
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-16 16:16:41 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-16 16:16:41 -0700
commitd80119982e9b60ca0022533a0086eb526dc7d957 (patch)
tree5926683c67e68ee436cce950dd0c0553a68fe0b4 /continuedev/src/continuedev/libs
parent20f4d07eb1d584569752e67c754951b7892e3e6b (diff)
downloadsncontinue-d80119982e9b60ca0022533a0086eb526dc7d957.tar.gz
sncontinue-d80119982e9b60ca0022533a0086eb526dc7d957.tar.bz2
sncontinue-d80119982e9b60ca0022533a0086eb526dc7d957.zip
ggml
Diffstat (limited to 'continuedev/src/continuedev/libs')
-rw-r--r--continuedev/src/continuedev/libs/llm/ggml.py99
-rw-r--r--continuedev/src/continuedev/libs/util/count_tokens.py7
2 files changed, 104 insertions, 2 deletions
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
new file mode 100644
index 00000000..bef0d993
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -0,0 +1,99 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+import openai
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+import certifi
+import ssl
+
+ca_bundle_path = certifi.where()
+ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
+
+SERVER_URL = "http://localhost:8000"
+
+
+class GGML(LLM):
+ api_key: str
+ default_model: str
+
+ def __init__(self, api_key: str, default_model: str, system_message: str = None):
+ self.api_key = api_key
+ self.default_model = default_model
+ self.system_message = system_message
+
+ openai.api_key = api_key
+
+ @cached_property
+ def name(self):
+ return self.default_model
+
+ @property
+ def default_args(self):
+ return {**DEFAULT_ARGS, "model": self.default_model}
+
+ def count_tokens(self, text: str):
+ return count_tokens(self.default_model, text)
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = self.default_args.copy()
+ args.update(kwargs)
+ args["stream"] = True
+
+ args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))
+
+ async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
+
+ async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))
+ args["stream"] = True
+
+ async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ json_chunk = "{}" if json_chunk == "" else json_chunk
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except:
+ raise Exception(str(line[0]))
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ args = {**self.default_args, **kwargs}
+
+ async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),
+ **args
+ }) as resp:
+ try:
+ return await resp.text()
+ except:
+ raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 73be0717..e1baeca1 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -3,13 +3,16 @@ from typing import Dict, List, Union
from ...core.main import ChatMessage
import tiktoken
-aliases = {}
+aliases = {
+ "ggml": "gpt-3.5-turbo",
+}
DEFAULT_MAX_TOKENS = 2048
MAX_TOKENS_FOR_MODEL = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192
+ "gpt-4": 8192,
+ "ggml": 2048
}
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"