diff options
author | Luna <git@l4.pm> | 2023-07-30 23:23:50 -0300 |
---|---|---|
committer | Luna <git@l4.pm> | 2023-07-30 23:23:50 -0300 |
commit | 374058e07ca699b5a345b270067636f6785df3af (patch) | |
tree | 8e1fdb9f9b86d0953a4acf462e2daba9e6abd75c | |
parent | 39076efbd74106ad59ad65e31d52b8d591c1d485 (diff) | |
download | sncontinue-374058e07ca699b5a345b270067636f6785df3af.tar.gz sncontinue-374058e07ca699b5a345b270067636f6785df3af.tar.bz2 sncontinue-374058e07ca699b5a345b270067636f6785df3af.zip |
fix GGML client session usage
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 78 |
1 files changed, 37 insertions, 41 deletions
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 7fa51e34..a760f7fb 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -13,7 +13,7 @@ SERVER_URL = "http://localhost:8000" class GGML(LLM): - _client_session: aiohttp.ClientSession + _client_session: aiohttp.ClientSession = None def __init__(self, system_message: str = None): self.system_message = system_message @@ -22,7 +22,7 @@ class GGML(LLM): self._client_session = aiohttp.ClientSession() async def stop(self): - pass + await self._client_session.close() @property def name(self): @@ -48,18 +48,16 @@ class GGML(LLM): messages = compile_chat_messages( self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) - # TODO move to single self.session variable (proxy setting etc) - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": messages, - **args - }) as resp: - async for line in resp.content.iter_any(): - if line: - try: - yield line.decode("utf-8") - except: - raise Exception(str(line)) + async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": messages, + **args + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + yield line.decode("utf-8") + except: + raise Exception(str(line)) async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} @@ -67,34 +65,32 @@ class GGML(LLM): self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) args["stream"] = True - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ - "messages": messages, - **args - }) as resp: - # This is streaming application/json instaed of text/event-stream - async for line in resp.content.iter_chunks(): - if line[1]: - try: - json_chunk = line[0].decode("utf-8") - if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): - continue - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - yield json.loads(chunk[6:])["choices"][0]["delta"] - except: - raise Exception(str(line[0])) + async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={ + "messages": messages, + **args + }) as resp: + # This is streaming application/json instaed of text/event-stream + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): + continue + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + yield json.loads(chunk[6:])["choices"][0]["delta"] + except: + raise Exception(str(line[0])) async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), - **args - }) as resp: - try: - return await resp.text() - except: - raise Exception(await resp.text()) + async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), + **args + }) as resp: + try: + return await resp.text() + except: + raise Exception(await resp.text()) |