diff options
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 78 | 
1 files changed, 37 insertions, 41 deletions
| diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 7fa51e34..a760f7fb 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -13,7 +13,7 @@ SERVER_URL = "http://localhost:8000"  class GGML(LLM): -    _client_session: aiohttp.ClientSession +    _client_session: aiohttp.ClientSession = None      def __init__(self, system_message: str = None):          self.system_message = system_message @@ -22,7 +22,7 @@ class GGML(LLM):          self._client_session = aiohttp.ClientSession()      async def stop(self): -        pass +        await self._client_session.close()      @property      def name(self): @@ -48,18 +48,16 @@ class GGML(LLM):          messages = compile_chat_messages(              self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) -        # TODO move to single self.session variable (proxy setting etc) -        async with self._client_session as session: -            async with session.post(f"{SERVER_URL}/v1/completions", json={ -                "messages": messages, -                **args -            }) as resp: -                async for line in resp.content.iter_any(): -                    if line: -                        try: -                            yield line.decode("utf-8") -                        except: -                            raise Exception(str(line)) +        async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ +            "messages": messages, +            **args +        }) as resp: +            async for line in resp.content.iter_any(): +                if line: +                    try: +                        yield line.decode("utf-8") +                    except: +                        raise Exception(str(line))      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs} @@ -67,34 +65,32 @@ class GGML(LLM):              self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          args["stream"] = True -        async with self._client_session as session: -            async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ -                "messages": messages, -                **args -            }) as resp: -                # This is streaming application/json instaed of text/event-stream -                async for line in resp.content.iter_chunks(): -                    if line[1]: -                        try: -                            json_chunk = line[0].decode("utf-8") -                            if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): -                                continue -                            chunks = json_chunk.split("\n") -                            for chunk in chunks: -                                if chunk.strip() != "": -                                    yield json.loads(chunk[6:])["choices"][0]["delta"] -                        except: -                            raise Exception(str(line[0])) +        async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={ +            "messages": messages, +            **args +        }) as resp: +            # This is streaming application/json instaed of text/event-stream +            async for line in resp.content.iter_chunks(): +                if line[1]: +                    try: +                        json_chunk = line[0].decode("utf-8") +                        if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): +                            continue +                        chunks = json_chunk.split("\n") +                        for chunk in chunks: +                            if chunk.strip() != "": +                                yield json.loads(chunk[6:])["choices"][0]["delta"] +                    except: +                        raise Exception(str(line[0]))      async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} -        async with self._client_session as session: -            async with session.post(f"{SERVER_URL}/v1/completions", json={ -                "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), -                **args -            }) as resp: -                try: -                    return await resp.text() -                except: -                    raise Exception(await resp.text()) +        async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ +            "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), +            **args +        }) as resp: +            try: +                return await resp.text() +            except: +                raise Exception(await resp.text()) | 
