Diffstat (limited to 'continuedev')
 continuedev/src/continuedev/core/autopilot.py                   | 19
 continuedev/src/continuedev/core/config.py                      |  8
 continuedev/src/continuedev/core/sdk.py                         |  2
 continuedev/src/continuedev/libs/llm/openai.py                  | 19
 continuedev/src/continuedev/libs/util/{dedent.py => strings.py} | 24
 continuedev/src/continuedev/steps/chat.py                       |  8
 continuedev/src/continuedev/steps/core/core.py                  | 24
 7 files changed, 75 insertions(+), 29 deletions(-)
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 0696c360..4e177ac9 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -36,7 +36,7 @@ def get_error_title(e: Exception) -> str:
     elif isinstance(e, openai_errors.APIConnectionError):
         return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
     elif isinstance(e, openai_errors.InvalidRequestError):
-        return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'
+        return 'Invalid request sent to OpenAI. Please try again.'
     elif e.__str__().startswith("Cannot connect to host"):
         return "The request failed. Please check your internet connection and try again."
     return e.__str__() or e.__repr__()
@@ -166,6 +166,22 @@ class Autopilot(ContinueBaseModel):
         if not any(map(lambda x: x.editing, self._highlighted_ranges)):
             self._highlighted_ranges[0].editing = True
 
+    def _disambiguate_highlighted_ranges(self):
+        """If any files have the same name, also display their folder name"""
+        name_counts = {}
+        for rif in self._highlighted_ranges:
+            if rif.display_name in name_counts:
+                name_counts[rif.display_name] += 1
+            else:
+                name_counts[rif.display_name] = 1
+
+        for rif in self._highlighted_ranges:
+            if name_counts[rif.display_name] > 1:
+                rif.display_name = os.path.join(
+                    os.path.basename(os.path.dirname(rif.range.filepath)), rif.display_name)
+            else:
+                rif.display_name = os.path.basename(rif.range.filepath)
+
     async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
         # Filter out rifs from ~/.continue/diffs folder
         range_in_files = [
@@ -211,6 +227,7 @@ class Autopilot(ContinueBaseModel):
         ) for rif in range_in_files]
 
         self._make_sure_is_editing_range()
+        self._disambiguate_highlighted_ranges()
 
         await self.update_subscribers()
 
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 91a47c8e..98615c64 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -67,13 +67,18 @@ DEFAULT_SLASH_COMMANDS = [
 ]
 
 
+class AzureInfo(BaseModel):
+    endpoint: str
+    engine: str
+    api_version: str
+
+
 class ContinueConfig(BaseModel):
     """
     A pydantic class for the continue config file.
     """
     steps_on_startup: Optional[Dict[str, Dict]] = {}
     disallowed_steps: Optional[List[str]] = []
-    server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
     default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
                            "gpt-4", "ggml"] = 'gpt-4'
@@ -86,6 +91,7 @@ class ContinueConfig(BaseModel):
     on_traceback: Optional[List[OnTracebackSteps]] = [
         OnTracebackSteps(step_name="DefaultOnTracebackStep")]
     system_message: Optional[str] = None
+    azure_openai_info: Optional[AzureInfo] = None
 
     # Want to force these to be the slash commands for now
     @validator('slash_commands', pre=True)
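With the new AzureInfo model and azure_openai_info field above, Azure OpenAI can be enabled entirely from the user's config. A minimal sketch of what that might look like; the ~/.continue/config.py location and all endpoint/engine/api_version values are illustrative assumptions, not part of this commit:

    # Hypothetical ~/.continue/config.py; all values below are placeholders.
    from continuedev.src.continuedev.core.config import ContinueConfig, AzureInfo

    config = ContinueConfig(
        default_model="gpt-4",
        azure_openai_info=AzureInfo(
            endpoint="https://my-resource.openai.azure.com",  # assumed resource URL
            engine="my-gpt4-deployment",                      # Azure deployment name
            api_version="2023-05-15",                         # assumed API version
        ),
    )
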
"""      steps_on_startup: Optional[Dict[str, Dict]] = {}      disallowed_steps: Optional[List[str]] = [] -    server_url: Optional[str] = None      allow_anonymous_telemetry: Optional[bool] = True      default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",                             "gpt-4", "ggml"] = 'gpt-4' @@ -86,6 +91,7 @@ class ContinueConfig(BaseModel):      on_traceback: Optional[List[OnTracebackSteps]] = [          OnTracebackSteps(step_name="DefaultOnTracebackStep")]      system_message: Optional[str] = None +    azure_openai_info: Optional[AzureInfo] = None      # Want to force these to be the slash commands for now      @validator('slash_commands', pre=True) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index ac57c122..7e612d3b 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -56,7 +56,7 @@ class Models:          api_key = self.provider_keys["openai"]          if api_key == "":              return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message) -        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message) +        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info)      def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:          api_key = self.provider_keys["hf_inference_api"] diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index d973f19e..33d10985 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,30 +1,41 @@  from functools import cached_property -import time  from typing import Any, Coroutine, Dict, Generator, List, Union +  from ...core.main import ChatMessage  import openai  from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ...core.config import AzureInfo  class OpenAI(LLM):      api_key: str      default_model: str -    def __init__(self, api_key: str, default_model: str, system_message: str = None): +    def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None):          self.api_key = api_key          self.default_model = default_model          self.system_message = system_message +        self.azure_info = azure_info          openai.api_key = api_key +        # Using an Azure OpenAI deployment +        if azure_info is not None: +            openai.api_type = "azure" +            openai.api_base = azure_info.endpoint +            openai.api_version = azure_info.api_version +      @cached_property      def name(self):          return self.default_model      @property      def default_args(self): -        return {**DEFAULT_ARGS, "model": self.default_model} +        args = {**DEFAULT_ARGS, "model": self.default_model} +        if self.azure_info is not None: +            args["engine"] = self.azure_info.engine +        return args      def count_tokens(self, text: str):          return count_tokens(self.default_model, text) diff --git a/continuedev/src/continuedev/libs/util/dedent.py b/continuedev/src/continuedev/libs/util/strings.py index e59c2e97..f1fb8d0b 100644 --- 
diff --git a/continuedev/src/continuedev/libs/util/dedent.py b/continuedev/src/continuedev/libs/util/strings.py
index e59c2e97..f1fb8d0b 100644
--- a/continuedev/src/continuedev/libs/util/dedent.py
+++ b/continuedev/src/continuedev/libs/util/strings.py
@@ -23,3 +23,27 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
                 break
 
     return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp
+
+
+def remove_quotes_and_escapes(output: str) -> str:
+    """
+    Clean up the output of the completion API, removing unnecessary escapes and quotes
+    """
+    output = output.strip()
+
+    # Replace smart quotes
+    output = output.replace("“", '"')
+    output = output.replace("”", '"')
+    output = output.replace("‘", "'")
+    output = output.replace("’", "'")
+
+    # Remove escapes
+    output = output.replace('\\"', '"')
+    output = output.replace("\\'", "'")
+    output = output.replace("\\n", "\n")
+    output = output.replace("\\t", "\t")
+    output = output.replace("\\\\", "\\")
+    if (output.startswith('"') and output.endswith('"')) or (output.startswith("'") and output.endswith("'")):
+        output = output[1:-1]
+
+    return output
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 3751dec2..7c6b42db 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -3,6 +3,7 @@ from typing import Any, Coroutine, List
 
 from pydantic import Field
 
+from ..libs.util.strings import remove_quotes_and_escapes
 from .main import EditHighlightedCodeStep
 from .core.core import MessageStep
 from ..core.main import FunctionCall, Models
@@ -43,11 +44,8 @@ class SimpleChatStep(Step):
         finally:
             await generator.aclose()
 
-        self.name = (await sdk.models.gpt35.complete(
-            f"Write a short title for the following chat message: {self.description}")).strip()
-
-        if self.name.startswith('"') and self.name.endswith('"'):
-            self.name = self.name[1:-1]
+        self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+            f"Write a short title for the following chat message: {self.description}"))
 
         self.chat_context.append(ChatMessage(
             role="assistant",
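The chat.py change above collapses the old strip-and-unquote logic into the new helper. A quick behavior sketch of remove_quotes_and_escapes; the inputs are made up, and the import path assumes the repository root is on sys.path:

    from continuedev.src.continuedev.libs.util.strings import remove_quotes_and_escapes

    print(remove_quotes_and_escapes('"Add a \\"retry\\" flag"'))  # Add a "retry" flag
    print(remove_quotes_and_escapes("“Fix the bug”"))             # Fix the bug
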
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 41988000..2b049ecc 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -13,7 +13,7 @@ from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContent
 from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
 from ...core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
 from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
-from ...libs.util.dedent import dedent_and_get_common_whitespace
+from ...libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
 
 import difflib
 
@@ -158,22 +158,12 @@ class DefaultModelEditCodeStep(Step):
     _new_contents: str = ""
     _prompt_and_completion: str = ""
 
-    def _cleanup_output(self, output: str) -> str:
-        output = output.replace('\\"', '"')
-        output = output.replace("\\'", "'")
-        output = output.replace("\\n", "\n")
-        output = output.replace("\\t", "\t")
-        output = output.replace("\\\\", "\\")
-        if output.startswith('"') and output.endswith('"'):
-            output = output[1:-1]
-
-        return output
-
     async def describe(self, models: Models) -> Coroutine[str, None, None]:
         if self._previous_contents.strip() == self._new_contents.strip():
             description = "No edits were made"
         else:
-            changes = '\n'.join(difflib.ndiff(self._previous_contents.splitlines(), self._new_contents.splitlines()))
+            changes = '\n'.join(difflib.ndiff(
+                self._previous_contents.splitlines(), self._new_contents.splitlines()))
             description = await models.gpt3516k.complete(dedent(f"""\
                 Diff summary: "{self.user_input}"
@@ -183,17 +173,17 @@ class DefaultModelEditCodeStep(Step):
                 Please give brief a description of the changes made above using markdown bullet points. Be concise:"""))
 
         name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
-        self.name = self._cleanup_output(name)
+        self.name = remove_quotes_and_escapes(name)
 
-        return f"{self._cleanup_output(description)}"
+        return f"{remove_quotes_and_escapes(description)}"
 
     async def get_prompt_parts(self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str):
         # We don't know here all of the functions being passed in.
         # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
         # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
         model_to_use = sdk.models.default
-        max_tokens = MAX_TOKENS_FOR_MODEL.get(
-            model_to_use.name, DEFAULT_MAX_TOKENS) / 2
+        max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
+            model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
 
         TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
         if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
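Finally, a worked example of the _disambiguate_highlighted_ranges helper introduced in autopilot.py at the top of this diff; the file paths here are hypothetical:

    import os

    # Duplicate basenames gain their parent folder; unique ones stay plain.
    paths = ["/repo/core/utils.py", "/repo/libs/utils.py", "/repo/main.py"]
    counts = {}
    for p in paths:
        name = os.path.basename(p)
        counts[name] = counts.get(name, 0) + 1

    display = [
        os.path.join(os.path.basename(os.path.dirname(p)), os.path.basename(p))
        if counts[os.path.basename(p)] > 1
        else os.path.basename(p)
        for p in paths
    ]
    print(display)  # ['core/utils.py', 'libs/utils.py', 'main.py'] (on POSIX)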
