diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-09-13 11:07:21 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-09-13 11:07:21 -0700 | 
| commit | 0940d756dec3b98071ae5e5a12966e02420b3cd2 (patch) | |
| tree | 8e9a6165d73ebb9584eab65e6cfd88b4b2c9a7a5 /continuedev/src | |
| parent | 8cce7eb46b325b7d44a6ed66ce77c860142fa97d (diff) | |
| download | sncontinue-0940d756dec3b98071ae5e5a12966e02420b3cd2.tar.gz sncontinue-0940d756dec3b98071ae5e5a12966e02420b3cd2.tar.bz2 sncontinue-0940d756dec3b98071ae5e5a12966e02420b3cd2.zip | |
fix: :bug: numerous small fixes
Diffstat (limited to 'continuedev/src')
4 files changed, 24 insertions, 17 deletions
```diff
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 4f942bd6..0ab43703 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -42,6 +42,10 @@ class GGML(LLM):
         None,
         description="Whether SSL certificates should be verified when making the HTTP request",
     )
+    proxy: Optional[str] = Field(
+        None,
+        description="Proxy URL to use when making the HTTP request",
+    )
     ca_bundle_path: str = Field(
         None,
         description="Path to a custom CA bundle to use when making the HTTP request",
@@ -96,6 +100,7 @@ class GGML(LLM):
                     **args,
                 },
                 headers=self.get_headers(),
+                proxy=self.proxy,
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
@@ -129,6 +134,7 @@ class GGML(LLM):
                     f"{self.server_url}/v1/chat/completions",
                     json={"messages": messages, "stream": True, **args},
                     headers=self.get_headers(),
+                    proxy=self.proxy,
                 ) as resp:
                     async for line, end in resp.content.iter_chunks():
                         json_chunk = line.decode("utf-8")
@@ -165,6 +171,7 @@ class GGML(LLM):
                     **args,
                 },
                 headers=self.get_headers(),
+                proxy=self.proxy,
             ) as resp:
                 text = await resp.text()
                 try:
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index df10cb9f..7cd699fa 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -31,11 +31,10 @@ class HuggingFaceTGI(LLM):

     def collect_args(self, options: CompletionOptions) -> Any:
         args = super().collect_args(options)
-        args = {
-            **args,
-            "max_new_tokens": args.get("max_tokens", 1024),
-        }
-        args.pop("max_tokens", None)
+        args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1}
+        args.pop("max_tokens")
+        args.pop("model")
+        args.pop("functions")
         return args

     async def _stream_complete(self, prompt, options):
@@ -46,13 +45,17 @@ class HuggingFaceTGI(LLM):
             timeout=aiohttp.ClientTimeout(total=self.timeout),
         ) as client_session:
             async with client_session.post(
-                f"{self.server_url}",
-                json={"inputs": prompt, "stream": True, **args},
+                f"{self.server_url}/generate_stream",
+                json={"inputs": prompt, "parameters": args},
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
                         chunk = line.decode("utf-8")
-                        json_chunk = json.loads(chunk)
+                        try:
+                            json_chunk = json.loads(chunk)
+                        except Exception as e:
+                            print(f"Error parsing JSON: {e}")
+                            continue
                         text = json_chunk["details"]["best_of_sequences"][0][
                             "generated_text"
                         ]
diff --git a/continuedev/src/continuedev/libs/util/commonregex.py b/continuedev/src/continuedev/libs/util/commonregex.py
index 9f119122..c2f6bb82 100644
--- a/continuedev/src/continuedev/libs/util/commonregex.py
+++ b/continuedev/src/continuedev/libs/util/commonregex.py
@@ -45,14 +45,8 @@ po_box = re.compile(r"P\.? ?O\.? Box \d+", re.IGNORECASE)
 ssn = re.compile(
     "(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}"
 )
-win_absolute_filepath = re.compile(
-    r"^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+", re.IGNORECASE
-)
-unix_absolute_filepath = re.compile(r"^\/(?:[\/\w]+\/)*\w([\w.])+", re.IGNORECASE)

 regexes = {
-    "win_absolute_filepath": win_absolute_filepath,
-    "unix_absolute_filepath": unix_absolute_filepath,
     "dates": date,
     "times": time,
     "phones": phone,
@@ -71,8 +65,6 @@ regexes = {
 }

 placeholders = {
-    "win_absolute_filepath": "<FILEPATH>",
-    "unix_absolute_filepath": "<FILEPATH>",
     "dates": "<DATE>",
     "times": "<TIME>",
     "phones": "<PHONE>",
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index aaa32907..1c1e020e 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -18,8 +18,11 @@ DEFAULT_ARGS = {
     "temperature": 0.5,
 }

+already_saw_import_err = False
+

 def encoding_for_model(model_name: str):
+    global already_saw_import_err
     try:
         import tiktoken
         from tiktoken_ext import openai_public  # noqa: F401
@@ -29,7 +32,9 @@ def encoding_for_model(model_name: str):
         except:
             return tiktoken.encoding_for_model("gpt-3.5-turbo")
     except Exception as e:
-        print("Error importing tiktoken", e)
+        if not already_saw_import_err:
+            print("Error importing tiktoken", e)
+            already_saw_import_err = True
         return None
```
