summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py1
-rw-r--r--continuedev/src/continuedev/core/context.py32
-rw-r--r--continuedev/src/continuedev/core/main.py8
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py10
-rw-r--r--continuedev/src/continuedev/libs/llm/prompts/edit.py2
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/diff.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/dynamic.py14
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/embeddings.py8
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py4
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/filetree.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/github.py4
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/google.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/highlighted_code.py3
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/search.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/terminal.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/url.py7
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py34
17 files changed, 126 insertions, 36 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index cee7a2f9..a943a35f 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -177,6 +177,7 @@ class Autopilot(ContinueBaseModel):
session_info=self.session_info,
config=self.continue_sdk.config,
saved_context_groups=self._saved_context_groups,
+ context_providers=self.context_manager.get_provider_descriptions(),
)
self.full_state = full_state
return full_state
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 571e5dc8..f1f309ba 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -11,7 +11,13 @@ from ..libs.util.devdata import dev_data_logger
from ..libs.util.logging import logger
from ..libs.util.telemetry import posthog_logger
from ..server.meilisearch_server import poll_meilisearch_running, restart_meilisearch
-from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
+from .main import (
+ ChatMessage,
+ ContextItem,
+ ContextItemDescription,
+ ContextItemId,
+ ContextProviderDescription,
+)
class ContinueSDK(BaseModel):
@@ -36,6 +42,10 @@ class ContextProvider(BaseModel):
delete_documents: Callable[[List[str]], Awaitable] = None
update_documents: Callable[[List[ContextItem], str], Awaitable] = None
+ display_title: str
+ description: str
+ dynamic: bool
+
selected_items: List[ContextItem] = []
def dict(self, *args, **kwargs):
@@ -168,6 +178,20 @@ class ContextManager:
It is responsible for compiling all of this information into a single prompt without exceeding the token limit.
"""
+ def get_provider_descriptions(self) -> List[ContextProviderDescription]:
+ """
+ Returns a list of ContextProviderDescriptions for each context provider.
+ """
+ return [
+ ContextProviderDescription(
+ title=provider.title,
+ display_title=provider.display_title,
+ description=provider.description,
+ dynamic=provider.dynamic,
+ )
+ for provider in self.context_providers.values()
+ ]
+
async def get_selected_items(self) -> List[ContextItem]:
"""
Returns all of the selected ContextItems.
@@ -242,6 +266,7 @@ class ContextManager:
"description": item.description.description,
"content": item.content,
"workspace_dir": workspace_dir,
+ "provider_name": item.description.id.provider_title,
}
for item in context_items
]
@@ -282,7 +307,9 @@ class ContextManager:
await globalSearchIndex.update_searchable_attributes(
["name", "description"]
)
- await globalSearchIndex.update_filterable_attributes(["workspace_dir"])
+ await globalSearchIndex.update_filterable_attributes(
+ ["workspace_dir", "provider_name"]
+ )
async def load_context_provider(provider: ContextProvider):
context_items = await provider.provide_context_items(workspace_dir)
@@ -293,6 +320,7 @@ class ContextManager:
"description": item.description.description,
"content": item.content,
"workspace_dir": workspace_dir,
+ "provider_name": provider.title,
}
for item in context_items
]
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index ec2e2a07..3d3bef15 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -291,6 +291,13 @@ class ContinueConfig(ContinueBaseModel):
return original_dict
+class ContextProviderDescription(BaseModel):
+ title: str
+ display_title: str
+ description: str
+ dynamic: bool
+
+
class FullState(ContinueBaseModel):
"""A full state of the program, including the history"""
@@ -303,6 +310,7 @@ class FullState(ContinueBaseModel):
session_info: Optional[SessionInfo] = None
config: ContinueConfig
saved_context_groups: Dict[str, List[ContextItem]] = {}
+ context_providers: List[ContextProviderDescription] = []
class ContinueSDK:
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 85ca1969..653d2d6b 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -137,6 +137,7 @@ class LLM(ContinueBaseModel):
async def stream_complete(
self,
prompt: str,
+ raw: bool = False,
model: str = None,
temperature: float = None,
top_p: float = None,
@@ -163,7 +164,9 @@ class LLM(ContinueBaseModel):
prompt = prune_raw_prompt_from_top(
self.model, self.context_length, prompt, options.max_tokens
)
- prompt = self.template_prompt_like_messages(prompt)
+
+ if not raw:
+ prompt = self.template_prompt_like_messages(prompt)
self.write_log(f"Prompt: \n\n{prompt}")
@@ -181,6 +184,7 @@ class LLM(ContinueBaseModel):
async def complete(
self,
prompt: str,
+ raw: bool = False,
model: str = None,
temperature: float = None,
top_p: float = None,
@@ -207,7 +211,9 @@ class LLM(ContinueBaseModel):
prompt = prune_raw_prompt_from_top(
self.model, self.context_length, prompt, options.max_tokens
)
- prompt = self.template_prompt_like_messages(prompt)
+
+ if not raw:
+ prompt = self.template_prompt_like_messages(prompt)
self.write_log(f"Prompt: \n\n{prompt}")
diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py
index a234fa61..275473bc 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/edit.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py
@@ -11,3 +11,5 @@ simplified_edit_prompt = dedent(
Output nothing except for the code. No code block, no English explanation, no start/end tags.
[/INST]"""
)
+
+codellama_infill_edit_prompt = "{{file_prefix}}<FILL>{{file_suffix}}"
diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py
index c8345d02..4c16cabf 100644
--- a/continuedev/src/continuedev/plugins/context_providers/diff.py
+++ b/continuedev/src/continuedev/plugins/context_providers/diff.py
@@ -7,6 +7,9 @@ from ...core.main import ContextItem, ContextItemDescription, ContextItemId
class DiffContextProvider(ContextProvider):
title = "diff"
+ display_title = "Diff"
+ description = "Output of 'git diff' in current repo"
+ dynamic = True
DIFF_CONTEXT_ITEM_ID = "diff"
@@ -30,8 +33,8 @@ class DiffContextProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.DIFF_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode(
"utf-8"
diff --git a/continuedev/src/continuedev/plugins/context_providers/dynamic.py b/continuedev/src/continuedev/plugins/context_providers/dynamic.py
index e6ea0e88..50567621 100644
--- a/continuedev/src/continuedev/plugins/context_providers/dynamic.py
+++ b/continuedev/src/continuedev/plugins/context_providers/dynamic.py
@@ -13,16 +13,12 @@ class DynamicProvider(ContextProvider, ABC):
"""
title: str
- """
- A name representing the provider. Probably use capitalized version of title
- """
+ """A name representing the provider. Probably use capitalized version of title"""
+
name: str
- """
- A description for the provider
- """
- description: str
workspace_dir: str = None
+ dynamic: bool = True
@property
def BASE_CONTEXT_ITEM(self):
@@ -41,8 +37,8 @@ class DynamicProvider(ContextProvider, ABC):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if id.item_id != self.title:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
query = query.lstrip(self.title + " ")
results = await self.get_content(query)
diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
index 3f37232e..bd63eab8 100644
--- a/continuedev/src/continuedev/plugins/context_providers/embeddings.py
+++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
@@ -17,6 +17,10 @@ class EmbeddingResult(BaseModel):
class EmbeddingsProvider(ContextProvider):
title = "embed"
+ display_title = "Embeddings Search"
+ description = "Search the codebase using embeddings"
+ dynamic = True
+
workspace_directory: str
EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings"
@@ -62,8 +66,8 @@ class EmbeddingsProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def add_context_item(self, id: ContextItemId, query: str):
- if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
results = await self._get_query_results(query)
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 6b27a889..c4a61193 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -27,6 +27,10 @@ class FileContextProvider(ContextProvider):
title = "file"
ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS
+ display_title = "Files"
+ description = "Reference files in the current workspace"
+ dynamic = False
+
async def start(self, *args):
await super().start(*args)
diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py
index 959a0a66..968b761d 100644
--- a/continuedev/src/continuedev/plugins/context_providers/filetree.py
+++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py
@@ -35,6 +35,9 @@ def split_path(path: str, with_root=None) -> List[str]:
class FileTreeContextProvider(ContextProvider):
title = "tree"
+ display_title = "File Tree"
+ description = "Add a formatted file tree of this directory to the context"
+ dynamic = True
workspace_dir: str = None
@@ -78,7 +81,7 @@ class FileTreeContextProvider(ContextProvider):
return [await self._filetree_context_item()]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.title:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
return await self._filetree_context_item()
diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py
index d394add1..7a16d3c9 100644
--- a/continuedev/src/continuedev/plugins/context_providers/github.py
+++ b/continuedev/src/continuedev/plugins/context_providers/github.py
@@ -20,6 +20,10 @@ class GitHubIssuesContextProvider(ContextProvider):
repo_name: str
auth_token: str
+ display_title = "GitHub Issues"
+ description = "Reference GitHub issues"
+ dynamic = False
+
async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
auth = Auth.Token(self.auth_token)
gh = Github(auth=auth)
diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py
index fc9d4555..06681db0 100644
--- a/continuedev/src/continuedev/plugins/context_providers/google.py
+++ b/continuedev/src/continuedev/plugins/context_providers/google.py
@@ -10,6 +10,9 @@ from .util import remove_meilisearch_disallowed_chars
class GoogleContextProvider(ContextProvider):
title = "google"
+ display_title = "Google"
+ description = "Search Google"
+ dynamic = True
serper_api_key: str
@@ -42,8 +45,8 @@ class GoogleContextProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.GOOGLE_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
results = await self._google_search(query)
json_results = json.loads(results)
diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
index 504764b9..0610a8c3 100644
--- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
+++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
@@ -28,6 +28,9 @@ class HighlightedCodeContextProvider(ContextProvider):
"""
title = "code"
+ display_title = "Highlighted Code"
+ description = "Highlight code"
+ dynamic = True
ide: Any # IdeProtocolServer
diff --git a/continuedev/src/continuedev/plugins/context_providers/search.py b/continuedev/src/continuedev/plugins/context_providers/search.py
index 8a3a3689..19fc15bc 100644
--- a/continuedev/src/continuedev/plugins/context_providers/search.py
+++ b/continuedev/src/continuedev/plugins/context_providers/search.py
@@ -11,6 +11,9 @@ from .util import remove_meilisearch_disallowed_chars
class SearchContextProvider(ContextProvider):
title = "search"
+ display_title = "Search"
+ description = "Search the workspace for all matches of an exact string (e.g. '@search console.log')"
+ dynamic = True
SEARCH_CONTEXT_ITEM_ID = "search"
@@ -86,8 +89,8 @@ class SearchContextProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.SEARCH_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
query = query.lstrip("search ")
results = await self._search(query)
diff --git a/continuedev/src/continuedev/plugins/context_providers/terminal.py b/continuedev/src/continuedev/plugins/context_providers/terminal.py
index 57d06a4f..f63ed676 100644
--- a/continuedev/src/continuedev/plugins/context_providers/terminal.py
+++ b/continuedev/src/continuedev/plugins/context_providers/terminal.py
@@ -6,6 +6,9 @@ from ...core.main import ChatMessage, ContextItem, ContextItemDescription, Conte
class TerminalContextProvider(ContextProvider):
title = "terminal"
+ display_title = "Terminal"
+ description = "Reference the contents of the terminal"
+ dynamic = True
workspace_dir: str = None
get_last_n_commands: int = 3
@@ -31,8 +34,8 @@ class TerminalContextProvider(ContextProvider):
return [self._terminal_context_item()]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.title:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
terminal_contents = await self.sdk.ide.getTerminalContents(
self.get_last_n_commands
diff --git a/continuedev/src/continuedev/plugins/context_providers/url.py b/continuedev/src/continuedev/plugins/context_providers/url.py
index a5ec8990..b9dc0e1d 100644
--- a/continuedev/src/continuedev/plugins/context_providers/url.py
+++ b/continuedev/src/continuedev/plugins/context_providers/url.py
@@ -10,6 +10,9 @@ from .util import remove_meilisearch_disallowed_chars
class URLContextProvider(ContextProvider):
title = "url"
+ display_title = "URL"
+ description = "Reference the contents of a webpage"
+ dynamic = True
# Allows users to provide a list of preset urls
preset_urls: List[str] = []
@@ -78,8 +81,8 @@ class URLContextProvider(ContextProvider):
return matching_static_item
# Check if the item is the dynamic item
- if not id.item_id == self.DYNAMIC_URL_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
# Generate the dynamic item
url = query.lstrip("url ").strip()
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 97235e6f..bf5eb144 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -604,7 +604,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user
rendered = render_prompt_template(
template,
messages[:-1],
- {"code_to_edit": rif.contents, "user_input": self.user_input},
+ {
+ "code_to_edit": rif.contents,
+ "user_input": self.user_input,
+ "file_prefix": file_prefix,
+ "file_suffix": file_suffix,
+ },
)
if isinstance(rendered, str):
messages = [
@@ -617,11 +622,25 @@ Please output the code to be inserted at the cursor in order to fulfill the user
else:
messages = rendered
- generator = model_to_use.stream_chat(
- messages,
- temperature=sdk.config.temperature,
- max_tokens=min(max_tokens, model_to_use.context_length // 2),
- )
+ generator = model_to_use.stream_complete(
+ rendered,
+ raw=True,
+ temperature=sdk.config.temperature,
+ max_tokens=min(max_tokens, model_to_use.context_length // 2),
+ )
+
+ else:
+
+ async def gen():
+ async for chunk in model_to_use.stream_chat(
+ messages,
+ temperature=sdk.config.temperature,
+ max_tokens=min(max_tokens, model_to_use.context_length // 2),
+ ):
+ if "content" in chunk:
+ yield chunk["content"]
+
+ generator = gen()
posthog_logger.capture_event(
"model_use",
@@ -641,9 +660,6 @@ Please output the code to be inserted at the cursor in order to fulfill the user
return
# Accumulate lines
- if "content" not in chunk:
- continue # ayo
- chunk = chunk["content"]
chunk_lines = chunk.split("\n")
chunk_lines[0] = unfinished_line + chunk_lines[0]
if chunk.endswith("\n"):