From 0940d756dec3b98071ae5e5a12966e02420b3cd2 Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Wed, 13 Sep 2023 11:07:21 -0700
Subject: fix: :bug: numerous small fixes

---
 continuedev/src/continuedev/libs/llm/ggml.py | 7 +++++++
 continuedev/src/continuedev/libs/llm/hf_tgi.py | 19 +++++++++++--------
 continuedev/src/continuedev/libs/util/commonregex.py | 8 --------
 continuedev/src/continuedev/libs/util/count_tokens.py | 7 ++++++-
 docs/sidebars.js | 6 +++---
 5 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 4f942bd6..0ab43703 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -42,6 +42,10 @@ class GGML(LLM):
         None,
         description="Whether SSL certificates should be verified when making the HTTP request",
     )
+    proxy: Optional[str] = Field(
+        None,
+        description="Proxy URL to use when making the HTTP request",
+    )
     ca_bundle_path: str = Field(
         None,
         description="Path to a custom CA bundle to use when making the HTTP request",
@@ -96,6 +100,7 @@ class GGML(LLM):
                     **args,
                 },
                 headers=self.get_headers(),
+                proxy=self.proxy,
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
@@ -129,6 +134,7 @@ class GGML(LLM):
                 f"{self.server_url}/v1/chat/completions",
                 json={"messages": messages, "stream": True, **args},
                 headers=self.get_headers(),
+                proxy=self.proxy,
             ) as resp:
                 async for line, end in resp.content.iter_chunks():
                     json_chunk = line.decode("utf-8")
@@ -165,6 +171,7 @@ class GGML(LLM):
                     **args,
                 },
                 headers=self.get_headers(),
+                proxy=self.proxy,
             ) as resp:
                 text = await resp.text()
                 try:
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index df10cb9f..7cd699fa 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -31,11 +31,10 @@ class HuggingFaceTGI(LLM):
 
     def collect_args(self, options: CompletionOptions) -> Any:
         args = super().collect_args(options)
-        args = {
-            **args,
-            "max_new_tokens": args.get("max_tokens", 1024),
-        }
-        args.pop("max_tokens", None)
+        args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
+        args.pop("max_tokens")
+        args.pop("model")
+        args.pop("functions")
         return args
 
     async def _stream_complete(self, prompt, options):
@@ -46,13 +45,17 @@ class HuggingFaceTGI(LLM):
             timeout=aiohttp.ClientTimeout(total=self.timeout),
         ) as client_session:
             async with client_session.post(
-                f"{self.server_url}",
-                json={"inputs": prompt, "stream": True, **args},
+                f"{self.server_url}/generate_stream",
+                json={"inputs": prompt, "parameters": args},
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
                         chunk = line.decode("utf-8")
-                        json_chunk = json.loads(chunk)
+                        try:
+                            json_chunk = json.loads(chunk)
+                        except Exception as e:
+                            print(f"Error parsing JSON: {e}")
+                            continue
                         text = json_chunk["details"]["best_of_sequences"][0][
                             "generated_text"
                         ]
diff --git a/continuedev/src/continuedev/libs/util/commonregex.py b/continuedev/src/continuedev/libs/util/commonregex.py
index 9f119122..c2f6bb82 100644
--- a/continuedev/src/continuedev/libs/util/commonregex.py
+++ b/continuedev/src/continuedev/libs/util/commonregex.py
@@ -45,14 +45,8 @@ po_box = re.compile(r"P\.? ?O\.? Box \d+", re.IGNORECASE)
 ssn = re.compile(
     "(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}"
 )
-win_absolute_filepath = re.compile(
-    r"^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+", re.IGNORECASE
-)
-unix_absolute_filepath = re.compile(r"^\/(?:[\/\w]+\/)*\w([\w.])+", re.IGNORECASE)
 
 regexes = {
-    "win_absolute_filepath": win_absolute_filepath,
-    "unix_absolute_filepath": unix_absolute_filepath,
     "dates": date,
     "times": time,
     "phones": phone,
@@ -71,8 +65,6 @@ regexes = {
 }
 
 placeholders = {
-    "win_absolute_filepath": "",
-    "unix_absolute_filepath": "",
     "dates": "",
    "times": "