diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/main.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file_context_provider.py | 5 |
3 files changed, 18 insertions, 8 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index bb9ca323..565c617d 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -56,11 +56,11 @@ class ContinueConfig(BaseModel): # Want to force these to be the slash commands for now @validator('slash_commands', pre=True) def default_slash_commands_validator(cls, v): - from ..steps.open_config import OpenConfigStep - from ..steps.clear_history import ClearHistoryStep - from ..steps.feedback import FeedbackStep - from ..steps.comment_code import CommentCodeStep - from ..steps.main import EditHighlightedCodeStep + from ..plugins.steps.open_config import OpenConfigStep + from ..plugins.steps.clear_history import ClearHistoryStep + from ..plugins.steps.feedback import FeedbackStep + from ..plugins.steps.comment_code import CommentCodeStep + from ..plugins.steps.main import EditHighlightedCodeStep DEFAULT_SLASH_COMMANDS = [ SlashCommand( diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 6c6adccc..df9b98ef 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -207,12 +207,21 @@ class ContextItemId(BaseModel): provider_title: str item_id: str + @validator('provider_title', 'item_id') + def must_be_valid_id(cls, v): + import re + if not re.match(r'^[0-9a-zA-Z_-]*$', v): + raise ValueError( + "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _") + return v + def to_string(self) -> str: return f"{self.provider_title}-{self.item_id}" @staticmethod def from_string(string: str) -> 'ContextItemId': - provider_title, item_id = string.split('-') + provider_title, *rest = string.split('-') + item_id = '-'.join(rest) return ContextItemId(provider_title=provider_title, item_id=item_id) diff --git a/continuedev/src/continuedev/plugins/context_providers/file_context_provider.py b/continuedev/src/continuedev/plugins/context_providers/file_context_provider.py index 854310b1..632a876c 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file_context_provider.py +++ b/continuedev/src/continuedev/plugins/context_providers/file_context_provider.py @@ -1,4 +1,5 @@ import os +import re from typing import List from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider @@ -45,11 +46,11 @@ class FileContextProvider(ContextProvider): content=get_file_contents(file)[:min( 2000, len(get_file_contents(file)))], description=ContextItemDescription( - name=f"File {os.path.basename(file)}", + name=os.path.basename(file), description=file, id=ContextItemId( provider_title=self.title, - item_id=file + item_id=re.sub(r'[^0-9a-zA-Z_-]', '', file) ) ) ) for file in filepaths] |