diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 36 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 1 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 23 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/main.py | 49 | 
8 files changed, 73 insertions, 59 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 02fd61de..5c3baafd 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -151,30 +151,26 @@ class Autopilot(ContinueBaseModel):              self._highlighted_ranges[0].editing = True      async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): - -        # If un-highlighting, then remove the range -        if len(self._highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end) and not self._adding_highlighted_code: -            self._highlighted_ranges = [] -            await self.update_subscribers() -            return - -        # If not toggled to be adding context, only edit or add the first range -        if not self._adding_highlighted_code and len(self._highlighted_ranges) > 0: -            if len(range_in_files) == 0: -                return -            if range_in_files[0].range.overlaps_with(self._highlighted_ranges[0].range) and range_in_files[0].filepath == self._highlighted_ranges[0].range.filepath: -                self._highlighted_ranges = [HighlightedRangeContext( -                    range=range_in_files[0].range, editing=True, pinned=False)] -                await self.update_subscribers() -                return -          # Filter out rifs from ~/.continue/diffs folder          range_in_files = [              rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")] +        # Make sure all filepaths are relative to workspace          workspace_path = self.continue_sdk.ide.workspace_directory -        for rif in range_in_files: -            rif.filepath = os.path.basename(rif.filepath) + +        # If not adding highlighted code +        if not self._adding_highlighted_code: +            if len(self._highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end): +                # If un-highlighting the range to edit, then remove the range +                self._highlighted_ranges = [] +                await self.update_subscribers() +            elif len(range_in_files) > 0: +                # Otherwise, replace the current range with the new one +                # This is the first range to be highlighted +                self._highlighted_ranges = [HighlightedRangeContext( +                    range=range_in_files[0], editing=True, pinned=False, display_name=os.path.basename(range_in_files[0].filepath))] +                await self.update_subscribers() +            return          # If current range overlaps with any others, delete them and only keep the new range          new_ranges = [] @@ -195,7 +191,7 @@ class Autopilot(ContinueBaseModel):                  new_ranges.append(rif)          self._highlighted_ranges = new_ranges + [HighlightedRangeContext( -            range=rif, editing=False, pinned=False +            range=rif, editing=False, pinned=False, display_name=os.path.basename(rif.filepath)          ) for rif in range_in_files]          self._make_sure_is_editing_range() diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 4ea17f20..88690c83 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -205,6 +205,7 @@ class HighlightedRangeContext(ContinueBaseModel):      range: RangeInFileWithContents      editing: bool      pinned: bool +    display_name: str  class FullState(ContinueBaseModel): diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index ed670799..8649cd58 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -13,7 +13,7 @@ from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI  from ..libs.llm.openai import OpenAI  from .observation import Observation  from ..server.ide_protocol import AbstractIdeProtocolServer -from .main import Context, ContinueCustomException, History, Step, ChatMessage, ChatMessageRole +from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole  from ..steps.core.core import *  from ..libs.llm.proxy_server import ProxyServer @@ -178,6 +178,11 @@ class ContinueSDK(AbstractContinueSDK):          else:              return load_global_config() +    def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: +        context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges) +                       ) if only_editing else self.__autopilot._highlighted_ranges +        return [c.range for c in context] +      def update_default_model(self, model: str):          config = self.config          config.default_model = model diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index c4e4139f..f0877d90 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -37,7 +37,7 @@ class OpenAI(LLM):          if args["model"] in CHAT_MODELS:              async for chunk in await openai.ChatCompletion.acreate(                  messages=compile_chat_messages( -                    args["model"], with_history, prompt, functions=None), +                    args["model"], with_history, args["max_tokens"], prompt, functions=None),                  **args,              ):                  if "content" in chunk.choices[0].delta: @@ -58,7 +58,7 @@ class OpenAI(LLM):          async for chunk in await openai.ChatCompletion.acreate(              messages=compile_chat_messages( -                args["model"], messages, functions=args.get("functions", None)), +                args["model"], messages, args["max_tokens"], functions=args.get("functions", None)),              **args,          ):              yield chunk.choices[0].delta @@ -69,7 +69,7 @@ class OpenAI(LLM):          if args["model"] in CHAT_MODELS:              resp = (await openai.ChatCompletion.acreate(                  messages=compile_chat_messages( -                    args["model"], with_history, prompt, functions=None), +                    args["model"], with_history, args["max_tokens"], prompt, functions=None),                  **args,              )).choices[0].message.content          else: diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 05ece394..eab6e441 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -38,7 +38,7 @@ class ProxyServer(LLM):          async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:              async with session.post(f"{SERVER_URL}/complete", json={ -                "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None), +                "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None),                  "unique_id": self.unique_id,                  **args              }) as resp: @@ -50,7 +50,7 @@ class ProxyServer(LLM):      async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.default_model, messages, None, functions=args.get("functions", None)) +            self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None))          async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:              async with session.post(f"{SERVER_URL}/stream_chat", json={ @@ -74,7 +74,7 @@ class ProxyServer(LLM):      async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs}          messages = compile_chat_messages( -            self.default_model, with_history, prompt, functions=args.get("functions", None)) +            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None))          async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:              async with session.post(f"{SERVER_URL}/stream_complete", json={ diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 8b06fef9..73be0717 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -76,14 +76,14 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:      return chat_history -def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:      prompt_tokens = count_tokens(model, prompt)      if functions is not None:          for function in functions:              prompt_tokens += count_tokens(model, json.dumps(function))      msgs = prune_chat_history(model, -                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + DEFAULT_MAX_TOKENS + count_tokens(model, system_message)) +                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message))      history = []      if system_message:          history.append({ diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 10853828..4b35a758 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -181,15 +181,22 @@ class DefaultModelEditCodeStep(Step):          # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.          # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.          model_to_use = sdk.models.gpt4 +        max_tokens = DEFAULT_MAX_TOKENS -        BUFFER_FOR_FUNCTIONS = 400 -        total_tokens = model_to_use.count_tokens( -            full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + DEFAULT_MAX_TOKENS - -        TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1000 +        TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200          if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:              self.description += "\n\n**It looks like you've selected a large range to edit, which may take a while to complete. If you'd like to cancel, click the 'X' button above. If you highlight a more specific range, Continue will only edit within it.**" +            # At this point, we also increase the max_tokens parameter so it doesn't stop in the middle of generation +            # Increase max_tokens to be double the size of the range +            # But don't exceed twice default max tokens +            max_tokens = int(min(model_to_use.count_tokens( +                rif.contents), DEFAULT_MAX_TOKENS) * 2.5) + +        BUFFER_FOR_FUNCTIONS = 400 +        total_tokens = model_to_use.count_tokens( +            full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + max_tokens +          # If using 3.5 and overflows, upgrade to 3.5.16k          if model_to_use.name == "gpt-3.5-turbo":              if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]: @@ -252,7 +259,7 @@ class DefaultModelEditCodeStep(Step):                  file_suffix = "\n" + file_suffix                  rif.contents = rif.contents[:-1] -        return file_prefix, rif.contents, file_suffix, model_to_use +        return file_prefix, rif.contents, file_suffix, model_to_use, max_tokens      def compile_prompt(self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK) -> str:          prompt = self._prompt @@ -289,7 +296,7 @@ class DefaultModelEditCodeStep(Step):          await sdk.ide.saveFile(rif.filepath)          full_file_contents = await sdk.ide.readFile(rif.filepath) -        file_prefix, contents, file_suffix, model_to_use = await self.get_prompt_parts( +        file_prefix, contents, file_suffix, model_to_use, max_tokens = await self.get_prompt_parts(              rif, sdk, full_file_contents)          contents, common_whitespace = dedent_and_get_common_whitespace(              contents) @@ -435,7 +442,7 @@ class DefaultModelEditCodeStep(Step):          completion_lines_covered = 0          repeating_file_suffix = False          line_below_highlighted_range = file_suffix.lstrip().split("\n")[0] -        async for chunk in model_to_use.stream_chat(messages, temperature=0): +        async for chunk in model_to_use.stream_chat(messages, temperature=0, max_tokens=max_tokens):              # Stop early if it is repeating the file_suffix or the step was deleted              if repeating_file_suffix:                  break diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py index 5ccffbfe..4f543022 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/steps/main.py @@ -97,7 +97,7 @@ class FasterEditHighlightedCodeStep(Step):          return "Editing highlighted code"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        range_in_files = await sdk.ide.getHighlightedCode() +        range_in_files = await sdk.get_code_context(only_editing=True)          if len(range_in_files) == 0:              # Get the full contents of all open files              files = await sdk.ide.getOpenFiles() @@ -105,21 +105,16 @@ class FasterEditHighlightedCodeStep(Step):              for file in files:                  contents[file] = await sdk.ide.readFile(file) -            range_in_files = [RangeInFile.from_entire_file( +            range_in_files = [RangeInFileWithContents.from_entire_file(                  filepath, content) for filepath, content in contents.items()] -        rif_with_contents = [] -        for range_in_file in range_in_files: -            file_contents = await sdk.ide.readRangeInFile(range_in_file) -            rif_with_contents.append( -                RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) -        enc_dec = MarkdownStyleEncoderDecoder(rif_with_contents) +        enc_dec = MarkdownStyleEncoderDecoder(range_in_files)          code_string = enc_dec.encode()          prompt = self._prompt.format(              code=code_string, user_input=self.user_input)          rif_dict = {} -        for rif in rif_with_contents: +        for rif in range_in_files:              rif_dict[rif.filepath] = rif.contents          completion = await sdk.models.gpt35.complete(prompt) @@ -193,7 +188,7 @@ class StarCoderEditHighlightedCodeStep(Step):          return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        range_in_files = await sdk.ide.getHighlightedCode() +        range_in_files = await sdk.get_code_context(only_editing=True)          found_highlighted_code = len(range_in_files) > 0          if not found_highlighted_code:              # Get the full contents of all open files @@ -202,20 +197,14 @@ class StarCoderEditHighlightedCodeStep(Step):              for file in files:                  contents[file] = await sdk.ide.readFile(file) -            range_in_files = [RangeInFile.from_entire_file( +            range_in_files = [RangeInFileWithContents.from_entire_file(                  filepath, content) for filepath, content in contents.items()] -        rif_with_contents = [] -        for range_in_file in range_in_files: -            file_contents = await sdk.ide.readRangeInFile(range_in_file) -            rif_with_contents.append( -                RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) -          rif_dict = {} -        for rif in rif_with_contents: +        for rif in range_in_files:              rif_dict[rif.filepath] = rif.contents -        for rif in rif_with_contents: +        for rif in range_in_files:              prompt = self._prompt.format(                  code=rif.contents, user_request=self.user_input) @@ -255,7 +244,18 @@ class EditHighlightedCodeStep(Step):          return "Editing code"      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        range_in_files = await sdk.ide.getHighlightedCode() +        range_in_files = sdk.get_code_context(only_editing=True) + +        # If nothing highlighted, insert at the cursor if possible +        if len(range_in_files) == 0: +            highlighted_code = await sdk.ide.getHighlightedCode() +            if highlighted_code is not None: +                for rif in highlighted_code: +                    if rif.range.start == rif.range.end: +                        range_in_files.append( +                            RangeInFileWithContents.from_range_in_file(rif, "")) + +        # If nothing highlighted, edit the first open file          if len(range_in_files) == 0:              # Get the full contents of all open files              files = await sdk.ide.getOpenFiles() @@ -263,7 +263,7 @@ class EditHighlightedCodeStep(Step):              for file in files:                  contents[file] = await sdk.ide.readFile(file) -            range_in_files = [RangeInFile.from_entire_file( +            range_in_files = [RangeInFileWithContents.from_entire_file(                  filepath, content) for filepath, content in contents.items()]          # If still no highlighted code, create a new file and edit there @@ -271,7 +271,12 @@ class EditHighlightedCodeStep(Step):              # Create a new file              new_file_path = "new_file.txt"              await sdk.add_file(new_file_path, "") -            range_in_files = [RangeInFile.from_entire_file(new_file_path, "")] +            range_in_files = [ +                RangeInFileWithContents.from_entire_file(new_file_path, "")] + +        range_in_files = list(map(lambda x: RangeInFile( +            filepath=x.filepath, range=x.range +        ), range_in_files))          await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files)) | 
