diff options
Diffstat (limited to 'continuedev/src')
17 files changed, 126 insertions, 36 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index cee7a2f9..a943a35f 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -177,6 +177,7 @@ class Autopilot(ContinueBaseModel):              session_info=self.session_info,              config=self.continue_sdk.config,              saved_context_groups=self._saved_context_groups, +            context_providers=self.context_manager.get_provider_descriptions(),          )          self.full_state = full_state          return full_state diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 571e5dc8..f1f309ba 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -11,7 +11,13 @@ from ..libs.util.devdata import dev_data_logger  from ..libs.util.logging import logger  from ..libs.util.telemetry import posthog_logger  from ..server.meilisearch_server import poll_meilisearch_running, restart_meilisearch -from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId +from .main import ( +    ChatMessage, +    ContextItem, +    ContextItemDescription, +    ContextItemId, +    ContextProviderDescription, +)  class ContinueSDK(BaseModel): @@ -36,6 +42,10 @@ class ContextProvider(BaseModel):      delete_documents: Callable[[List[str]], Awaitable] = None      update_documents: Callable[[List[ContextItem], str], Awaitable] = None +    display_title: str +    description: str +    dynamic: bool +      selected_items: List[ContextItem] = []      def dict(self, *args, **kwargs): @@ -168,6 +178,20 @@ class ContextManager:      It is responsible for compiling all of this information into a single prompt without exceeding the token limit.      """ +    def get_provider_descriptions(self) -> List[ContextProviderDescription]: +        """ +        Returns a list of ContextProviderDescriptions for each context provider. +        """ +        return [ +            ContextProviderDescription( +                title=provider.title, +                display_title=provider.display_title, +                description=provider.description, +                dynamic=provider.dynamic, +            ) +            for provider in self.context_providers.values() +        ] +      async def get_selected_items(self) -> List[ContextItem]:          """          Returns all of the selected ContextItems. @@ -242,6 +266,7 @@ class ContextManager:                  "description": item.description.description,                  "content": item.content,                  "workspace_dir": workspace_dir, +                "provider_name": item.description.id.provider_title,              }              for item in context_items          ] @@ -282,7 +307,9 @@ class ContextManager:                  await globalSearchIndex.update_searchable_attributes(                      ["name", "description"]                  ) -                await globalSearchIndex.update_filterable_attributes(["workspace_dir"]) +                await globalSearchIndex.update_filterable_attributes( +                    ["workspace_dir", "provider_name"] +                )                  async def load_context_provider(provider: ContextProvider):                      context_items = await provider.provide_context_items(workspace_dir) @@ -293,6 +320,7 @@ class ContextManager:                              "description": item.description.description,                              "content": item.content,                              "workspace_dir": workspace_dir, +                            "provider_name": provider.title,                          }                          for item in context_items                      ] diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index ec2e2a07..3d3bef15 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -291,6 +291,13 @@ class ContinueConfig(ContinueBaseModel):          return original_dict +class ContextProviderDescription(BaseModel): +    title: str +    display_title: str +    description: str +    dynamic: bool + +  class FullState(ContinueBaseModel):      """A full state of the program, including the history""" @@ -303,6 +310,7 @@ class FullState(ContinueBaseModel):      session_info: Optional[SessionInfo] = None      config: ContinueConfig      saved_context_groups: Dict[str, List[ContextItem]] = {} +    context_providers: List[ContextProviderDescription] = []  class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 85ca1969..653d2d6b 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -137,6 +137,7 @@ class LLM(ContinueBaseModel):      async def stream_complete(          self,          prompt: str, +        raw: bool = False,          model: str = None,          temperature: float = None,          top_p: float = None, @@ -163,7 +164,9 @@ class LLM(ContinueBaseModel):          prompt = prune_raw_prompt_from_top(              self.model, self.context_length, prompt, options.max_tokens          ) -        prompt = self.template_prompt_like_messages(prompt) + +        if not raw: +            prompt = self.template_prompt_like_messages(prompt)          self.write_log(f"Prompt: \n\n{prompt}") @@ -181,6 +184,7 @@ class LLM(ContinueBaseModel):      async def complete(          self,          prompt: str, +        raw: bool = False,          model: str = None,          temperature: float = None,          top_p: float = None, @@ -207,7 +211,9 @@ class LLM(ContinueBaseModel):          prompt = prune_raw_prompt_from_top(              self.model, self.context_length, prompt, options.max_tokens          ) -        prompt = self.template_prompt_like_messages(prompt) + +        if not raw: +            prompt = self.template_prompt_like_messages(prompt)          self.write_log(f"Prompt: \n\n{prompt}") diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py index a234fa61..275473bc 100644 --- a/continuedev/src/continuedev/libs/llm/prompts/edit.py +++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py @@ -11,3 +11,5 @@ simplified_edit_prompt = dedent(              Output nothing except for the code. No code block, no English explanation, no start/end tags.              [/INST]"""  ) + +codellama_infill_edit_prompt = "{{file_prefix}}<FILL>{{file_suffix}}" diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py index c8345d02..4c16cabf 100644 --- a/continuedev/src/continuedev/plugins/context_providers/diff.py +++ b/continuedev/src/continuedev/plugins/context_providers/diff.py @@ -7,6 +7,9 @@ from ...core.main import ContextItem, ContextItemDescription, ContextItemId  class DiffContextProvider(ContextProvider):      title = "diff" +    display_title = "Diff" +    description = "Output of 'git diff' in current repo" +    dynamic = True      DIFF_CONTEXT_ITEM_ID = "diff" @@ -30,8 +33,8 @@ class DiffContextProvider(ContextProvider):          return [self.BASE_CONTEXT_ITEM]      async def get_item(self, id: ContextItemId, query: str) -> ContextItem: -        if not id.item_id == self.DIFF_CONTEXT_ITEM_ID: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode(              "utf-8" diff --git a/continuedev/src/continuedev/plugins/context_providers/dynamic.py b/continuedev/src/continuedev/plugins/context_providers/dynamic.py index e6ea0e88..50567621 100644 --- a/continuedev/src/continuedev/plugins/context_providers/dynamic.py +++ b/continuedev/src/continuedev/plugins/context_providers/dynamic.py @@ -13,16 +13,12 @@ class DynamicProvider(ContextProvider, ABC):      """      title: str -    """ -    A name representing the provider. Probably use capitalized version of title -    """ +    """A name representing the provider. Probably use capitalized version of title""" +      name: str -    """ -    A description for the provider -    """ -    description: str      workspace_dir: str = None +    dynamic: bool = True      @property      def BASE_CONTEXT_ITEM(self): @@ -41,8 +37,8 @@ class DynamicProvider(ContextProvider, ABC):          return [self.BASE_CONTEXT_ITEM]      async def get_item(self, id: ContextItemId, query: str) -> ContextItem: -        if id.item_id != self.title: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          query = query.lstrip(self.title + " ")          results = await self.get_content(query) diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py index 3f37232e..bd63eab8 100644 --- a/continuedev/src/continuedev/plugins/context_providers/embeddings.py +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -17,6 +17,10 @@ class EmbeddingResult(BaseModel):  class EmbeddingsProvider(ContextProvider):      title = "embed" +    display_title = "Embeddings Search" +    description = "Search the codebase using embeddings" +    dynamic = True +      workspace_directory: str      EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" @@ -62,8 +66,8 @@ class EmbeddingsProvider(ContextProvider):          return [self.BASE_CONTEXT_ITEM]      async def add_context_item(self, id: ContextItemId, query: str): -        if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          results = await self._get_query_results(query) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 6b27a889..c4a61193 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -27,6 +27,10 @@ class FileContextProvider(ContextProvider):      title = "file"      ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS +    display_title = "Files" +    description = "Reference files in the current workspace" +    dynamic = False +      async def start(self, *args):          await super().start(*args) diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py index 959a0a66..968b761d 100644 --- a/continuedev/src/continuedev/plugins/context_providers/filetree.py +++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py @@ -35,6 +35,9 @@ def split_path(path: str, with_root=None) -> List[str]:  class FileTreeContextProvider(ContextProvider):      title = "tree" +    display_title = "File Tree" +    description = "Add a formatted file tree of this directory to the context" +    dynamic = True      workspace_dir: str = None @@ -78,7 +81,7 @@ class FileTreeContextProvider(ContextProvider):          return [await self._filetree_context_item()]      async def get_item(self, id: ContextItemId, query: str) -> ContextItem: -        if not id.item_id == self.title: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          return await self._filetree_context_item() diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py index d394add1..7a16d3c9 100644 --- a/continuedev/src/continuedev/plugins/context_providers/github.py +++ b/continuedev/src/continuedev/plugins/context_providers/github.py @@ -20,6 +20,10 @@ class GitHubIssuesContextProvider(ContextProvider):      repo_name: str      auth_token: str +    display_title = "GitHub Issues" +    description = "Reference GitHub issues" +    dynamic = False +      async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:          auth = Auth.Token(self.auth_token)          gh = Github(auth=auth) diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index fc9d4555..06681db0 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -10,6 +10,9 @@ from .util import remove_meilisearch_disallowed_chars  class GoogleContextProvider(ContextProvider):      title = "google" +    display_title = "Google" +    description = "Search Google" +    dynamic = True      serper_api_key: str @@ -42,8 +45,8 @@ class GoogleContextProvider(ContextProvider):          return [self.BASE_CONTEXT_ITEM]      async def get_item(self, id: ContextItemId, query: str) -> ContextItem: -        if not id.item_id == self.GOOGLE_CONTEXT_ITEM_ID: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          results = await self._google_search(query)          json_results = json.loads(results) diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 504764b9..0610a8c3 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -28,6 +28,9 @@ class HighlightedCodeContextProvider(ContextProvider):      """      title = "code" +    display_title = "Highlighted Code" +    description = "Highlight code" +    dynamic = True      ide: Any  # IdeProtocolServer diff --git a/continuedev/src/continuedev/plugins/context_providers/search.py b/continuedev/src/continuedev/plugins/context_providers/search.py index 8a3a3689..19fc15bc 100644 --- a/continuedev/src/continuedev/plugins/context_providers/search.py +++ b/continuedev/src/continuedev/plugins/context_providers/search.py @@ -11,6 +11,9 @@ from .util import remove_meilisearch_disallowed_chars  class SearchContextProvider(ContextProvider):      title = "search" +    display_title = "Search" +    description = "Search the workspace for all matches of an exact string (e.g. '@search console.log')" +    dynamic = True      SEARCH_CONTEXT_ITEM_ID = "search" @@ -86,8 +89,8 @@ class SearchContextProvider(ContextProvider):          return [self.BASE_CONTEXT_ITEM]      async def get_item(self, id: ContextItemId, query: str) -> ContextItem: -        if not id.item_id == self.SEARCH_CONTEXT_ITEM_ID: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          query = query.lstrip("search ")          results = await self._search(query) diff --git a/continuedev/src/continuedev/plugins/context_providers/terminal.py b/continuedev/src/continuedev/plugins/context_providers/terminal.py index 57d06a4f..f63ed676 100644 --- a/continuedev/src/continuedev/plugins/context_providers/terminal.py +++ b/continuedev/src/continuedev/plugins/context_providers/terminal.py @@ -6,6 +6,9 @@ from ...core.main import ChatMessage, ContextItem, ContextItemDescription, Conte  class TerminalContextProvider(ContextProvider):      title = "terminal" +    display_title = "Terminal" +    description = "Reference the contents of the terminal" +    dynamic = True      workspace_dir: str = None      get_last_n_commands: int = 3 @@ -31,8 +34,8 @@ class TerminalContextProvider(ContextProvider):          return [self._terminal_context_item()]      async def get_item(self, id: ContextItemId, query: str) -> ContextItem: -        if not id.item_id == self.title: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          terminal_contents = await self.sdk.ide.getTerminalContents(              self.get_last_n_commands diff --git a/continuedev/src/continuedev/plugins/context_providers/url.py b/continuedev/src/continuedev/plugins/context_providers/url.py index a5ec8990..b9dc0e1d 100644 --- a/continuedev/src/continuedev/plugins/context_providers/url.py +++ b/continuedev/src/continuedev/plugins/context_providers/url.py @@ -10,6 +10,9 @@ from .util import remove_meilisearch_disallowed_chars  class URLContextProvider(ContextProvider):      title = "url" +    display_title = "URL" +    description = "Reference the contents of a webpage" +    dynamic = True      # Allows users to provide a list of preset urls      preset_urls: List[str] = [] @@ -78,8 +81,8 @@ class URLContextProvider(ContextProvider):              return matching_static_item          # Check if the item is the dynamic item -        if not id.item_id == self.DYNAMIC_URL_CONTEXT_ITEM_ID: -            raise Exception("Invalid item id") +        if not id.provider_title == self.title: +            raise Exception("Invalid provider title for item")          # Generate the dynamic item          url = query.lstrip("url ").strip() diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 97235e6f..bf5eb144 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -604,7 +604,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user              rendered = render_prompt_template(                  template,                  messages[:-1], -                {"code_to_edit": rif.contents, "user_input": self.user_input}, +                { +                    "code_to_edit": rif.contents, +                    "user_input": self.user_input, +                    "file_prefix": file_prefix, +                    "file_suffix": file_suffix, +                },              )              if isinstance(rendered, str):                  messages = [ @@ -617,11 +622,25 @@ Please output the code to be inserted at the cursor in order to fulfill the user              else:                  messages = rendered -        generator = model_to_use.stream_chat( -            messages, -            temperature=sdk.config.temperature, -            max_tokens=min(max_tokens, model_to_use.context_length // 2), -        ) +            generator = model_to_use.stream_complete( +                rendered, +                raw=True, +                temperature=sdk.config.temperature, +                max_tokens=min(max_tokens, model_to_use.context_length // 2), +            ) + +        else: + +            async def gen(): +                async for chunk in model_to_use.stream_chat( +                    messages, +                    temperature=sdk.config.temperature, +                    max_tokens=min(max_tokens, model_to_use.context_length // 2), +                ): +                    if "content" in chunk: +                        yield chunk["content"] + +            generator = gen()          posthog_logger.capture_event(              "model_use", @@ -641,9 +660,6 @@ Please output the code to be inserted at the cursor in order to fulfill the user                      return                  # Accumulate lines -                if "content" not in chunk: -                    continue  # ayo -                chunk = chunk["content"]                  chunk_lines = chunk.split("\n")                  chunk_lines[0] = unfinished_line + chunk_lines[0]                  if chunk.endswith("\n"): | 
