summary refs log tree commit diff
diff options
context:
space:
mode:
author Luna <git@l4.pm> 2023-07-30 23:23:50 -0300
committer Luna <git@l4.pm> 2023-07-30 23:23:50 -0300
commit 374058e07ca699b5a345b270067636f6785df3af (patch)
tree 8e1fdb9f9b86d0953a4acf462e2daba9e6abd75c
parent 39076efbd74106ad59ad65e31d52b8d591c1d485 (diff)
download sncontinue-374058e07ca699b5a345b270067636f6785df3af.tar.gz
sncontinue-374058e07ca699b5a345b270067636f6785df3af.tar.bz2
sncontinue-374058e07ca699b5a345b270067636f6785df3af.zip
fix GGML client session usage
-rw-r--r-- continuedev/src/continuedev/libs/llm/ggml.py 78
1 files changed, 37 insertions, 41 deletions
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 7fa51e34..a760f7fb 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -13,7 +13,7 @@ SERVER_URL = "http://localhost:8000"
class GGML(LLM):
- _client_session: aiohttp.ClientSession
+ _client_session: aiohttp.ClientSession = None
def __init__(self, system_message: str = None):
self.system_message = system_message
@@ -22,7 +22,7 @@ class GGML(LLM):
self._client_session = aiohttp.ClientSession()
async def stop(self):
- pass
+ await self._client_session.close()
@property
def name(self):
@@ -48,18 +48,16 @@ class GGML(LLM):
messages = compile_chat_messages(
self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
- # TODO move to single self.session variable (proxy setting etc)
- async with self._client_session as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": messages,
- **args
- }) as resp:
- async for line in resp.content.iter_any():
- if line:
- try:
- yield line.decode("utf-8")
- except:
- raise Exception(str(line))
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
@@ -67,34 +65,32 @@ class GGML(LLM):
self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
args["stream"] = True
- async with self._client_session as session:
- async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
- "messages": messages,
- **args
- }) as resp:
- # This is streaming application/json instaed of text/event-stream
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
- continue
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- yield json.loads(chunk[6:])["choices"][0]["delta"]
- except:
- raise Exception(str(line[0]))
+ async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except:
+ raise Exception(str(line[0]))
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- async with self._client_session as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
- **args
- }) as resp:
- try:
- return await resp.text()
- except:
- raise Exception(await resp.text())
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ **args
+ }) as resp:
+ try:
+ return await resp.text()
+ except:
+ raise Exception(await resp.text())