summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/main.yaml8
-rw-r--r--continuedev/src/continuedev/core/autopilot.py1
-rw-r--r--continuedev/src/continuedev/core/context.py34
-rw-r--r--continuedev/src/continuedev/core/main.py19
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py10
-rw-r--r--continuedev/src/continuedev/libs/llm/prompts/edit.py2
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/diff.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/dynamic.py14
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/embeddings.py9
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py4
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/filetree.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/github.py4
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/google.py8
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/highlighted_code.py3
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/search.py8
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/terminal.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/url.py8
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py39
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py80
-rw-r--r--extension/react-app/src/components/ComboBox.tsx446
-rw-r--r--extension/react-app/src/components/EditableDiv.tsx84
-rw-r--r--extension/react-app/src/components/Layout.tsx22
-rw-r--r--extension/react-app/src/components/ProgressBar.tsx77
-rw-r--r--extension/react-app/src/components/index.ts4
-rw-r--r--extension/react-app/src/pages/gui.tsx45
-rw-r--r--extension/react-app/src/pages/settings.tsx9
-rw-r--r--extension/src/commands.ts13
-rw-r--r--extension/src/diffs.ts2
-rw-r--r--extension/src/lang-server/codeLens.ts4
29 files changed, 819 insertions, 159 deletions
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 6674d296..1c4fb4dc 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -256,10 +256,10 @@ jobs:
cd extension
npm ci
- # - name: Publish (Open VSX Registry)
- # run: |
- # cd extension
- # npx ovsx publish -p ${{ secrets.VSX_REGISTRY_TOKEN }} --packagePath ./build/*.vsix
+ - name: Publish (Open VSX Registry)
+ run: |
+ cd extension
+ npx ovsx publish -p ${{ secrets.VSX_REGISTRY_TOKEN }} --packagePath ./build/*.vsix
- name: Publish
run: |
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index cee7a2f9..a943a35f 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -177,6 +177,7 @@ class Autopilot(ContinueBaseModel):
session_info=self.session_info,
config=self.continue_sdk.config,
saved_context_groups=self._saved_context_groups,
+ context_providers=self.context_manager.get_provider_descriptions(),
)
self.full_state = full_state
return full_state
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 571e5dc8..25f6be14 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -11,7 +11,13 @@ from ..libs.util.devdata import dev_data_logger
from ..libs.util.logging import logger
from ..libs.util.telemetry import posthog_logger
from ..server.meilisearch_server import poll_meilisearch_running, restart_meilisearch
-from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
+from .main import (
+ ChatMessage,
+ ContextItem,
+ ContextItemDescription,
+ ContextItemId,
+ ContextProviderDescription,
+)
class ContinueSDK(BaseModel):
@@ -36,6 +42,11 @@ class ContextProvider(BaseModel):
delete_documents: Callable[[List[str]], Awaitable] = None
update_documents: Callable[[List[ContextItem], str], Awaitable] = None
+ display_title: str
+ description: str
+ dynamic: bool
+ requires_query: bool = False
+
selected_items: List[ContextItem] = []
def dict(self, *args, **kwargs):
@@ -168,6 +179,21 @@ class ContextManager:
It is responsible for compiling all of this information into a single prompt without exceeding the token limit.
"""
+ def get_provider_descriptions(self) -> List[ContextProviderDescription]:
+ """
+ Returns a list of ContextProviderDescriptions for each context provider.
+ """
+ return [
+ ContextProviderDescription(
+ title=provider.title,
+ display_title=provider.display_title,
+ description=provider.description,
+ dynamic=provider.dynamic,
+ requires_query=provider.requires_query,
+ )
+ for provider in self.context_providers.values()
+ ]
+
async def get_selected_items(self) -> List[ContextItem]:
"""
Returns all of the selected ContextItems.
@@ -242,6 +268,7 @@ class ContextManager:
"description": item.description.description,
"content": item.content,
"workspace_dir": workspace_dir,
+ "provider_name": item.description.id.provider_title,
}
for item in context_items
]
@@ -282,7 +309,9 @@ class ContextManager:
await globalSearchIndex.update_searchable_attributes(
["name", "description"]
)
- await globalSearchIndex.update_filterable_attributes(["workspace_dir"])
+ await globalSearchIndex.update_filterable_attributes(
+ ["workspace_dir", "provider_name"]
+ )
async def load_context_provider(provider: ContextProvider):
context_items = await provider.provide_context_items(workspace_dir)
@@ -293,6 +322,7 @@ class ContextManager:
"description": item.description.description,
"content": item.content,
"workspace_dir": workspace_dir,
+ "provider_name": provider.title,
}
for item in context_items
]
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index ec2e2a07..63a3e6a9 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -1,5 +1,5 @@
import json
-from typing import Coroutine, Dict, List, Literal, Optional, Union
+from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, validator
from pydantic.schema import schema
@@ -291,6 +291,14 @@ class ContinueConfig(ContinueBaseModel):
return original_dict
+class ContextProviderDescription(BaseModel):
+ title: str
+ display_title: str
+ description: str
+ dynamic: bool
+ requires_query: bool
+
+
class FullState(ContinueBaseModel):
"""A full state of the program, including the history"""
@@ -303,6 +311,7 @@ class FullState(ContinueBaseModel):
session_info: Optional[SessionInfo] = None
config: ContinueConfig
saved_context_groups: Dict[str, List[ContextItem]] = {}
+ context_providers: List[ContextProviderDescription] = []
class ContinueSDK:
@@ -392,13 +401,13 @@ class Validator(Step):
class Context:
- key_value: Dict[str, str] = {}
+ key_value: Dict[str, Any] = {}
- def set(self, key: str, value: str):
+ def set(self, key: str, value: Any):
self.key_value[key] = value
- def get(self, key: str) -> str:
- return self.key_value[key]
+ def get(self, key: str) -> Any:
+ return self.key_value.get(key, None)
class ContinueCustomException(Exception):
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 85ca1969..653d2d6b 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -137,6 +137,7 @@ class LLM(ContinueBaseModel):
async def stream_complete(
self,
prompt: str,
+ raw: bool = False,
model: str = None,
temperature: float = None,
top_p: float = None,
@@ -163,7 +164,9 @@ class LLM(ContinueBaseModel):
prompt = prune_raw_prompt_from_top(
self.model, self.context_length, prompt, options.max_tokens
)
- prompt = self.template_prompt_like_messages(prompt)
+
+ if not raw:
+ prompt = self.template_prompt_like_messages(prompt)
self.write_log(f"Prompt: \n\n{prompt}")
@@ -181,6 +184,7 @@ class LLM(ContinueBaseModel):
async def complete(
self,
prompt: str,
+ raw: bool = False,
model: str = None,
temperature: float = None,
top_p: float = None,
@@ -207,7 +211,9 @@ class LLM(ContinueBaseModel):
prompt = prune_raw_prompt_from_top(
self.model, self.context_length, prompt, options.max_tokens
)
- prompt = self.template_prompt_like_messages(prompt)
+
+ if not raw:
+ prompt = self.template_prompt_like_messages(prompt)
self.write_log(f"Prompt: \n\n{prompt}")
diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py
index a234fa61..275473bc 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/edit.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py
@@ -11,3 +11,5 @@ simplified_edit_prompt = dedent(
Output nothing except for the code. No code block, no English explanation, no start/end tags.
[/INST]"""
)
+
+codellama_infill_edit_prompt = "{{file_prefix}}<FILL>{{file_suffix}}"
diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py
index c8345d02..4c16cabf 100644
--- a/continuedev/src/continuedev/plugins/context_providers/diff.py
+++ b/continuedev/src/continuedev/plugins/context_providers/diff.py
@@ -7,6 +7,9 @@ from ...core.main import ContextItem, ContextItemDescription, ContextItemId
class DiffContextProvider(ContextProvider):
title = "diff"
+ display_title = "Diff"
+ description = "Output of 'git diff' in current repo"
+ dynamic = True
DIFF_CONTEXT_ITEM_ID = "diff"
@@ -30,8 +33,8 @@ class DiffContextProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.DIFF_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode(
"utf-8"
diff --git a/continuedev/src/continuedev/plugins/context_providers/dynamic.py b/continuedev/src/continuedev/plugins/context_providers/dynamic.py
index e6ea0e88..50567621 100644
--- a/continuedev/src/continuedev/plugins/context_providers/dynamic.py
+++ b/continuedev/src/continuedev/plugins/context_providers/dynamic.py
@@ -13,16 +13,12 @@ class DynamicProvider(ContextProvider, ABC):
"""
title: str
- """
- A name representing the provider. Probably use capitalized version of title
- """
+ """A name representing the provider. Probably use capitalized version of title"""
+
name: str
- """
- A description for the provider
- """
- description: str
workspace_dir: str = None
+ dynamic: bool = True
@property
def BASE_CONTEXT_ITEM(self):
@@ -41,8 +37,8 @@ class DynamicProvider(ContextProvider, ABC):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if id.item_id != self.title:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
query = query.lstrip(self.title + " ")
results = await self.get_content(query)
diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
index 3f37232e..86cba311 100644
--- a/continuedev/src/continuedev/plugins/context_providers/embeddings.py
+++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
@@ -17,6 +17,11 @@ class EmbeddingResult(BaseModel):
class EmbeddingsProvider(ContextProvider):
title = "embed"
+ display_title = "Embeddings Search"
+ description = "Search the codebase using embeddings"
+ dynamic = True
+ requires_query = True
+
workspace_directory: str
EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings"
@@ -62,8 +67,8 @@ class EmbeddingsProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def add_context_item(self, id: ContextItemId, query: str):
- if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
results = await self._get_query_results(query)
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 6b27a889..c4a61193 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -27,6 +27,10 @@ class FileContextProvider(ContextProvider):
title = "file"
ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS
+ display_title = "Files"
+ description = "Reference files in the current workspace"
+ dynamic = False
+
async def start(self, *args):
await super().start(*args)
diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py
index 959a0a66..968b761d 100644
--- a/continuedev/src/continuedev/plugins/context_providers/filetree.py
+++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py
@@ -35,6 +35,9 @@ def split_path(path: str, with_root=None) -> List[str]:
class FileTreeContextProvider(ContextProvider):
title = "tree"
+ display_title = "File Tree"
+ description = "Add a formatted file tree of this directory to the context"
+ dynamic = True
workspace_dir: str = None
@@ -78,7 +81,7 @@ class FileTreeContextProvider(ContextProvider):
return [await self._filetree_context_item()]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.title:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
return await self._filetree_context_item()
diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py
index d394add1..7a16d3c9 100644
--- a/continuedev/src/continuedev/plugins/context_providers/github.py
+++ b/continuedev/src/continuedev/plugins/context_providers/github.py
@@ -20,6 +20,10 @@ class GitHubIssuesContextProvider(ContextProvider):
repo_name: str
auth_token: str
+ display_title = "GitHub Issues"
+ description = "Reference GitHub issues"
+ dynamic = False
+
async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
auth = Auth.Token(self.auth_token)
gh = Github(auth=auth)
diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py
index fc9d4555..493806cc 100644
--- a/continuedev/src/continuedev/plugins/context_providers/google.py
+++ b/continuedev/src/continuedev/plugins/context_providers/google.py
@@ -10,6 +10,10 @@ from .util import remove_meilisearch_disallowed_chars
class GoogleContextProvider(ContextProvider):
title = "google"
+ display_title = "Google"
+ description = "Search Google"
+ dynamic = True
+ requires_query = True
serper_api_key: str
@@ -42,8 +46,8 @@ class GoogleContextProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.GOOGLE_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
results = await self._google_search(query)
json_results = json.loads(results)
diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
index 504764b9..0610a8c3 100644
--- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
+++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
@@ -28,6 +28,9 @@ class HighlightedCodeContextProvider(ContextProvider):
"""
title = "code"
+ display_title = "Highlighted Code"
+ description = "Highlight code"
+ dynamic = True
ide: Any # IdeProtocolServer
diff --git a/continuedev/src/continuedev/plugins/context_providers/search.py b/continuedev/src/continuedev/plugins/context_providers/search.py
index 8a3a3689..6df6f66c 100644
--- a/continuedev/src/continuedev/plugins/context_providers/search.py
+++ b/continuedev/src/continuedev/plugins/context_providers/search.py
@@ -11,6 +11,10 @@ from .util import remove_meilisearch_disallowed_chars
class SearchContextProvider(ContextProvider):
title = "search"
+ display_title = "Search"
+ description = "Search the workspace for all matches of an exact string (e.g. '@search console.log')"
+ dynamic = True
+ requires_query = True
SEARCH_CONTEXT_ITEM_ID = "search"
@@ -86,8 +90,8 @@ class SearchContextProvider(ContextProvider):
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.SEARCH_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
query = query.lstrip("search ")
results = await self._search(query)
diff --git a/continuedev/src/continuedev/plugins/context_providers/terminal.py b/continuedev/src/continuedev/plugins/context_providers/terminal.py
index 57d06a4f..f63ed676 100644
--- a/continuedev/src/continuedev/plugins/context_providers/terminal.py
+++ b/continuedev/src/continuedev/plugins/context_providers/terminal.py
@@ -6,6 +6,9 @@ from ...core.main import ChatMessage, ContextItem, ContextItemDescription, Conte
class TerminalContextProvider(ContextProvider):
title = "terminal"
+ display_title = "Terminal"
+ description = "Reference the contents of the terminal"
+ dynamic = True
workspace_dir: str = None
get_last_n_commands: int = 3
@@ -31,8 +34,8 @@ class TerminalContextProvider(ContextProvider):
return [self._terminal_context_item()]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
- if not id.item_id == self.title:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
terminal_contents = await self.sdk.ide.getTerminalContents(
self.get_last_n_commands
diff --git a/continuedev/src/continuedev/plugins/context_providers/url.py b/continuedev/src/continuedev/plugins/context_providers/url.py
index a5ec8990..c2c19cfb 100644
--- a/continuedev/src/continuedev/plugins/context_providers/url.py
+++ b/continuedev/src/continuedev/plugins/context_providers/url.py
@@ -10,6 +10,10 @@ from .util import remove_meilisearch_disallowed_chars
class URLContextProvider(ContextProvider):
title = "url"
+ display_title = "URL"
+ description = "Reference the contents of a webpage"
+ dynamic = True
+ requires_query = True
# Allows users to provide a list of preset urls
preset_urls: List[str] = []
@@ -78,8 +82,8 @@ class URLContextProvider(ContextProvider):
return matching_static_item
# Check if the item is the dynamic item
- if not id.item_id == self.DYNAMIC_URL_CONTEXT_ITEM_ID:
- raise Exception("Invalid item id")
+ if not id.provider_title == self.title:
+ raise Exception("Invalid provider title for item")
# Generate the dynamic item
url = query.lstrip("url ").strip()
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 97235e6f..1d7ffdd7 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -604,7 +604,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user
rendered = render_prompt_template(
template,
messages[:-1],
- {"code_to_edit": rif.contents, "user_input": self.user_input},
+ {
+ "code_to_edit": rif.contents,
+ "user_input": self.user_input,
+ "file_prefix": file_prefix,
+ "file_suffix": file_suffix,
+ },
)
if isinstance(rendered, str):
messages = [
@@ -617,11 +622,25 @@ Please output the code to be inserted at the cursor in order to fulfill the user
else:
messages = rendered
- generator = model_to_use.stream_chat(
- messages,
- temperature=sdk.config.temperature,
- max_tokens=min(max_tokens, model_to_use.context_length // 2),
- )
+ generator = model_to_use.stream_complete(
+ rendered,
+ raw=True,
+ temperature=sdk.config.temperature,
+ max_tokens=min(max_tokens, model_to_use.context_length // 2),
+ )
+
+ else:
+
+ async def gen():
+ async for chunk in model_to_use.stream_chat(
+ messages,
+ temperature=sdk.config.temperature,
+ max_tokens=min(max_tokens, model_to_use.context_length // 2),
+ ):
+ if "content" in chunk:
+ yield chunk["content"]
+
+ generator = gen()
posthog_logger.capture_event(
"model_use",
@@ -641,9 +660,6 @@ Please output the code to be inserted at the cursor in order to fulfill the user
return
# Accumulate lines
- if "content" not in chunk:
- continue # ayo
- chunk = chunk["content"]
chunk_lines = chunk.split("\n")
chunk_lines[0] = unfinished_line + chunk_lines[0]
if chunk.endswith("\n"):
@@ -796,7 +812,6 @@ Please output the code to be inserted at the cursor in order to fulfill the user
rif_dict[rif.filepath] = rif.contents
for rif in rif_with_contents:
- await sdk.ide.setFileOpen(rif.filepath)
await sdk.ide.setSuggestionsLocked(rif.filepath, True)
await self.stream_rif(rif, sdk)
await sdk.ide.setSuggestionsLocked(rif.filepath, False)
@@ -826,6 +841,10 @@ Please output the code to be inserted at the cursor in order to fulfill the user
self.description += chunk
await sdk.update_ui()
+ sdk.context.set("last_edit_user_input", self.user_input)
+ sdk.context.set("last_edit_diff", changes)
+ sdk.context.set("last_edit_range", self.range_in_files[-1].range)
+
class EditFileStep(Step):
filepath: str
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ca15aaab..cd3b30e0 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -1,4 +1,5 @@
import os
+import urllib.parse
from textwrap import dedent
from typing import Coroutine, List, Optional, Union
@@ -235,6 +236,58 @@ class StarCoderEditHighlightedCodeStep(Step):
await sdk.ide.setFileOpen(rif.filepath)
+class EditAlreadyEditedRangeStep(Step):
+ hide = True
+ model: Optional[LLM] = None
+ range_in_file: RangeInFile
+
+ user_input: str
+
+ _prompt = dedent(
+ """\
+ You were previously asked to edit this code. The request was:
+
+ "{prev_user_input}"
+
+ And you generated this diff:
+
+ {diff}
+
+ Could you please re-edit this code to follow these secondary instructions?
+
+ "{user_input}"
+ """
+ )
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ if os.path.basename(self.range_in_file.filepath) in os.listdir(
+ os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ ):
+ decoded_basename = urllib.parse.unquote(
+ os.path.basename(self.range_in_file.filepath)
+ )
+ self.range_in_file.filepath = decoded_basename
+
+ self.range_in_file.range = sdk.context.get("last_edit_range")
+
+ if self.range_in_file.range.start == self.range_in_file.range.end:
+ self.range_in_file.range = Range.from_entire_file(
+ await sdk.ide.readFile(self.range_in_file.filepath)
+ )
+
+ await sdk.run_step(
+ DefaultModelEditCodeStep(
+ model=self.model,
+ user_input=self._prompt.format(
+ prev_user_input=sdk.context.get("last_edit_user_input"),
+ diff=sdk.context.get("last_edit_diff"),
+ user_input=self.user_input,
+ ),
+ range_in_files=[self.range_in_file],
+ )
+ )
+
+
class EditHighlightedCodeStep(Step):
user_input: str = Field(
...,
@@ -258,13 +311,6 @@ class EditHighlightedCodeStep(Step):
highlighted_code = await sdk.ide.getHighlightedCode()
if highlighted_code is not None:
for rif in highlighted_code:
- if os.path.dirname(rif.filepath) == os.path.expanduser(
- os.path.join("~", ".continue", "diffs")
- ):
- raise ContinueCustomException(
- message="Please accept or reject the change before making another edit in this file.",
- title="Accept/Reject First",
- )
if rif.range.start == rif.range.end:
range_in_files.append(
RangeInFileWithContents.from_range_in_file(rif, "")
@@ -289,10 +335,22 @@ class EditHighlightedCodeStep(Step):
)
for range_in_file in range_in_files:
- if os.path.dirname(range_in_file.filepath) == os.path.expanduser(
- os.path.join("~", ".continue", "diffs")
- ):
- self.description = "Please accept or reject the change before making another edit in this file."
+ # Check whether re-editing
+ if (
+ os.path.dirname(range_in_file.filepath)
+ == os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ or urllib.parse.quote_plus(range_in_file.filepath)
+ in os.listdir(
+ os.path.expanduser(os.path.join("~", ".continue", "diffs"))
+ )
+ ) and sdk.context.get("last_edit_user_input") is not None:
+ await sdk.run_step(
+ EditAlreadyEditedRangeStep(
+ range_in_file=range_in_file,
+ user_input=self.user_input,
+ model=self.model,
+ )
+ )
return
args = {
diff --git a/extension/react-app/src/components/ComboBox.tsx b/extension/react-app/src/components/ComboBox.tsx
index 41b44684..c216e7d1 100644
--- a/extension/react-app/src/components/ComboBox.tsx
+++ b/extension/react-app/src/components/ComboBox.tsx
@@ -1,4 +1,5 @@
import React, {
+ useCallback,
useContext,
useEffect,
useImperativeHandle,
@@ -7,6 +8,8 @@ import React, {
import { useCombobox } from "downshift";
import styled from "styled-components";
import {
+ StyledTooltip,
+ buttonColor,
defaultBorderRadius,
lightGray,
secondaryDark,
@@ -19,6 +22,9 @@ import {
BookmarkIcon,
DocumentPlusIcon,
FolderArrowDownIcon,
+ ArrowLeftIcon,
+ PlusIcon,
+ ArrowRightIcon,
} from "@heroicons/react/24/outline";
import { ContextItem } from "../../../schema/FullState";
import { postVscMessage } from "../vscode";
@@ -60,7 +66,7 @@ const EmptyPillDiv = styled.div`
}
`;
-const MainTextInput = styled.textarea`
+const MainTextInput = styled.textarea<{ inQueryForDynamicProvider: boolean }>`
resize: none;
padding: 8px;
@@ -73,11 +79,16 @@ const MainTextInput = styled.textarea`
background-color: ${secondaryDark};
color: ${vscForeground};
z-index: 1;
- border: 1px solid transparent;
+ border: 1px solid
+ ${(props) =>
+ props.inQueryForDynamicProvider ? buttonColor : "transparent"};
&:focus {
- outline: 1px solid ${lightGray};
+ outline: 1px solid
+ ${(props) => (props.inQueryForDynamicProvider ? buttonColor : lightGray)};
border: 1px solid transparent;
+ background-color: ${(props) =>
+ props.inQueryForDynamicProvider ? `${buttonColor}22` : secondaryDark};
}
&::placeholder {
@@ -85,6 +96,37 @@ const MainTextInput = styled.textarea`
}
`;
+const DynamicQueryTitleDiv = styled.div`
+ position: absolute;
+ right: 0px;
+ top: 0px;
+ height: fit-content;
+ padding: 2px 4px;
+ border-radius: ${defaultBorderRadius};
+ z-index: 2;
+ color: white;
+ font-size: 12px;
+
+ background-color: ${buttonColor};
+`;
+
+const StyledPlusIcon = styled(PlusIcon)`
+ position: absolute;
+ right: 0px;
+ top: 0px;
+ height: fit-content;
+ padding: 0;
+ cursor: pointer;
+ border-radius: ${defaultBorderRadius};
+ z-index: 2;
+
+ background-color: ${vscBackground};
+
+ &:hover {
+ background-color: ${secondaryDark};
+ }
+`;
+
const UlMaxHeight = 300;
const Ul = styled.ul<{
hidden: boolean;
@@ -129,7 +171,8 @@ const Li = styled.li<{
${({ selected }) => selected && "font-weight: bold;"}
padding: 0.5rem 0.75rem;
display: flex;
- flex-direction: column;
+ flex-direction: row;
+ align-items: center;
${({ isLastItem }) => isLastItem && "border-bottom: 1px solid gray;"}
/* border-top: 1px solid gray; */
cursor: pointer;
@@ -164,37 +207,66 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
const [items, setItems] = React.useState(props.items);
const inputRef = React.useRef<HTMLInputElement>(null);
- const [inputBoxHeight, setInputBoxHeight] = useState<string | undefined>(
- undefined
- );
// Whether the current input follows an '@' and should be treated as context query
const [currentlyInContextQuery, setCurrentlyInContextQuery] = useState(false);
+ const [nestedContextProvider, setNestedContextProvider] = useState<
+ any | undefined
+ >(undefined);
+ const [inQueryForContextProvider, setInQueryForContextProvider] = useState<
+ any | undefined
+ >(undefined);
- const { getInputProps, ...downshiftProps } = useCombobox({
- onSelectedItemChange: ({ selectedItem }) => {
- if (selectedItem?.id) {
- // Get the query from the input value
- const segs = downshiftProps.inputValue.split("@");
- const query = segs[segs.length - 1];
- const restOfInput = segs.splice(0, segs.length - 1).join("@");
+ useEffect(() => {
+ if (!currentlyInContextQuery) {
+ setNestedContextProvider(undefined);
+ setInQueryForContextProvider(undefined);
+ }
+ }, [currentlyInContextQuery]);
- // Tell server the context item was selected
- client?.selectContextItem(selectedItem.id, query);
+ const contextProviders = useSelector(
+ (state: RootStore) => state.serverState.context_providers
+ ) as any[];
- // Remove the '@' and the context query from the input
- if (downshiftProps.inputValue.includes("@")) {
- downshiftProps.setInputValue(restOfInput);
- }
- }
- },
- onInputValueChange({ inputValue, highlightedIndex }) {
+ const goBackToContextProviders = () => {
+ setCurrentlyInContextQuery(false);
+ setNestedContextProvider(undefined);
+ setInQueryForContextProvider(undefined);
+ downshiftProps.setInputValue("@");
+ };
+
+ useEffect(() => {
+ if (!nestedContextProvider) {
+ console.log("setting items", nestedContextProvider);
+ setItems(
+ contextProviders?.map((provider) => ({
+ name: provider.display_title,
+ description: provider.description,
+ id: provider.title,
+ })) || []
+ );
+ }
+ }, [nestedContextProvider]);
+
+ const onInputValueChangeCallback = useCallback(
+ ({ inputValue, highlightedIndex }: any) => {
+ // Clear the input
if (!inputValue) {
setItems([]);
+ setNestedContextProvider(undefined);
+ setCurrentlyInContextQuery(false);
return;
}
+ if (
+ inQueryForContextProvider &&
+ !inputValue.startsWith(`@${inQueryForContextProvider.title}`)
+ ) {
+ setInQueryForContextProvider(undefined);
+ }
+
props.onInputValueChange(inputValue);
+ // Handle context selection
if (inputValue.endsWith("@") || currentlyInContextQuery) {
const segs = inputValue?.split("@") || [];
@@ -202,46 +274,124 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
// Get search results and return
setCurrentlyInContextQuery(true);
const providerAndQuery = segs[segs.length - 1] || "";
- // Only return context items from the current workspace - the index is currently shared between all sessions
- const workspaceFilter =
- workspacePaths && workspacePaths.length > 0
- ? `workspace_dir IN [ ${workspacePaths
- .map((path) => `"${path}"`)
- .join(", ")} ]`
- : undefined;
- searchClient
- .index(SEARCH_INDEX_NAME)
- .search(providerAndQuery, {
- filter: workspaceFilter,
- })
- .then((res) => {
- setItems(
- res.hits.map((hit) => {
- return {
- name: hit.name,
- description: hit.description,
- id: hit.id,
- content: hit.content,
- };
- })
- );
- })
- .catch(() => {
- // Swallow errors, because this simply is not supported on Windows at the moment
+
+ if (nestedContextProvider && !inputValue.endsWith("@")) {
+ // Search only within this specific context provider
+ getFilteredContextItemsForProvider(
+ nestedContextProvider.title,
+ providerAndQuery
+ ).then((res) => {
+ setItems(res);
});
+ } else {
+ // Search through the list of context providers
+ const filteredItems =
+ contextProviders
+ ?.filter(
+ (provider) =>
+ `@${provider.title}`
+ .toLowerCase()
+ .startsWith(inputValue.toLowerCase()) ||
+ `@${provider.display_title}`
+ .toLowerCase()
+ .startsWith(inputValue.toLowerCase())
+ )
+ .map((provider) => ({
+ name: provider.display_title,
+ description: provider.description,
+ id: provider.title,
+ })) || [];
+ setItems(filteredItems);
+ setCurrentlyInContextQuery(true);
+ }
return;
} else {
// Exit the '@' context menu
setCurrentlyInContextQuery(false);
- setItems;
+ setNestedContextProvider(undefined);
}
}
+
+ setNestedContextProvider(undefined);
+
+ // Handle slash commands
setItems(
props.items.filter((item) =>
item.name.toLowerCase().startsWith(inputValue.toLowerCase())
)
);
},
+ [
+ props.items,
+ currentlyInContextQuery,
+ nestedContextProvider,
+ inQueryForContextProvider,
+ ]
+ );
+
+ const onSelectedItemChangeCallback = useCallback(
+ ({ selectedItem }: any) => {
+ if (!selectedItem) return;
+ if (selectedItem.id) {
+ // Get the query from the input value
+ const segs = downshiftProps.inputValue.split("@");
+ const query = segs[segs.length - 1];
+
+ // Tell server the context item was selected
+ client?.selectContextItem(selectedItem.id, query);
+ if (downshiftProps.inputValue.includes("@")) {
+ const selectedNestedContextProvider = contextProviders.find(
+ (provider) => provider.title === selectedItem.id
+ );
+ if (
+ !nestedContextProvider &&
+ !selectedNestedContextProvider?.dynamic
+ ) {
+ downshiftProps.setInputValue(`@${selectedItem.id} `);
+ setNestedContextProvider(selectedNestedContextProvider);
+ } else {
+ downshiftProps.setInputValue("");
+ }
+ }
+ }
+ },
+ [nestedContextProvider, contextProviders, client]
+ );
+
+ const getFilteredContextItemsForProvider = async (
+ provider: string,
+ query: string
+ ) => {
+ // Only return context items from the current workspace - the index is currently shared between all sessions
+ const workspaceFilter =
+ workspacePaths && workspacePaths.length > 0
+ ? `workspace_dir IN [ ${workspacePaths
+ .map((path) => `"${path}"`)
+ .join(", ")} ] AND provider_name = '${provider}'`
+ : undefined;
+ try {
+ const res = await searchClient.index(SEARCH_INDEX_NAME).search(query, {
+ filter: workspaceFilter,
+ });
+ return (
+ res?.hits.map((hit) => {
+ return {
+ name: hit.name,
+ description: hit.description,
+ id: hit.id,
+ content: hit.content,
+ };
+ }) || []
+ );
+ } catch (e) {
+ console.log("Error searching context items", e);
+ return [];
+ }
+ };
+
+ const { getInputProps, ...downshiftProps } = useCombobox({
+ onSelectedItemChange: onSelectedItemChangeCallback,
+ onInputValueChange: onInputValueChangeCallback,
items,
itemToString(item) {
return item ? item.name : "";
@@ -348,6 +498,42 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
};
}, [inputRef.current]);
+ const selectContextItemFromDropdown = useCallback(
+ (event: any) => {
+ const newProviderName = items[downshiftProps.highlightedIndex].name;
+ const newProvider = contextProviders.find(
+ (provider) => provider.display_title === newProviderName
+ );
+
+ if (!newProvider) {
+ (event.nativeEvent as any).preventDownshiftDefault = true;
+ return;
+ } else if (newProvider.dynamic && newProvider.requires_query) {
+ setInQueryForContextProvider(newProvider);
+ downshiftProps.setInputValue(`@${newProvider.title} `);
+ (event.nativeEvent as any).preventDownshiftDefault = true;
+ event.preventDefault();
+ return;
+ } else if (newProvider.dynamic) {
+ return;
+ }
+
+ setNestedContextProvider(newProvider);
+ downshiftProps.setInputValue(`@${newProvider.title} `);
+ (event.nativeEvent as any).preventDownshiftDefault = true;
+ event.preventDefault();
+ getFilteredContextItemsForProvider(newProvider.title, "").then((items) =>
+ setItems(items)
+ );
+ },
+ [
+ items,
+ downshiftProps.highlightedIndex,
+ contextProviders,
+ nestedContextProvider,
+ ]
+ );
+
const showSelectContextGroupDialog = () => {
dispatch(setDialogMessage(<SelectContextGroupDialog />));
dispatch(setShowDialog(true));
@@ -409,21 +595,6 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
</HeaderButtonWithText>
{props.selectedContextItems.length > 0 && (
<>
- <HeaderButtonWithText
- text="Bookmark context"
- onClick={() => {
- showDialogToSaveContextGroup();
- }}
- className="pill-button focus:outline-none focus:border-red-600 focus:border focus:border-solid"
- onKeyDown={(e: KeyboardEvent) => {
- e.preventDefault();
- if (e.key === "Enter") {
- showDialogToSaveContextGroup();
- }
- }}
- >
- <BookmarkIcon width="1.4em" height="1.4em" />
- </HeaderButtonWithText>
{props.addingHighlightedCode ? (
<EmptyPillDiv
onClick={() => {
@@ -449,11 +620,33 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
<DocumentPlusIcon width="1.4em" height="1.4em" />
</HeaderButtonWithText>
)}
+ <HeaderButtonWithText
+ text="Bookmark context"
+ onClick={() => {
+ showDialogToSaveContextGroup();
+ }}
+ className="pill-button focus:outline-none focus:border-red-600 focus:border focus:border-solid"
+ onKeyDown={(e: KeyboardEvent) => {
+ e.preventDefault();
+ if (e.key === "Enter") {
+ showDialogToSaveContextGroup();
+ }
+ }}
+ >
+ <BookmarkIcon width="1.4em" height="1.4em" />
+ </HeaderButtonWithText>
</>
)}
</div>
- <div className="flex px-2" ref={divRef} hidden={!downshiftProps.isOpen}>
+ <div
+ className="flex px-2 relative"
+ ref={divRef}
+ hidden={!downshiftProps.isOpen}
+ >
<MainTextInput
+ inQueryForDynamicProvider={
+ typeof inQueryForContextProvider !== "undefined"
+ }
disabled={props.disabled}
placeholder={`Ask a question, give instructions, type '/' for slash commands, or '@' to add context`}
{...getInputProps({
@@ -467,7 +660,6 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
target.scrollHeight,
300
).toString()}px`;
- setInputBoxHeight(target.style.height);
// setShowContextDropdown(target.value.endsWith("@"));
},
@@ -487,17 +679,34 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
!isComposing
) {
const value = downshiftProps.inputValue;
- if (value !== "") {
- setPositionInHistory(history.length + 1);
- setHistory([...history, value]);
- }
- // Prevent Downshift's default 'Enter' behavior.
- (event.nativeEvent as any).preventDownshiftDefault = true;
+ if (inQueryForContextProvider) {
+ const segs = value.split("@");
+ client?.selectContextItem(
+ inQueryForContextProvider.title,
+ segs[segs.length - 1]
+ );
+ setCurrentlyInContextQuery(false);
+ downshiftProps.setInputValue("");
+ return;
+ } else {
+ if (value !== "") {
+ setPositionInHistory(history.length + 1);
+ setHistory([...history, value]);
+ }
+ // Prevent Downshift's default 'Enter' behavior.
+ (event.nativeEvent as any).preventDownshiftDefault = true;
- if (props.onEnter) {
- props.onEnter(event);
+ if (props.onEnter) {
+ props.onEnter(event);
+ }
}
setCurrentlyInContextQuery(false);
+ } else if (
+ event.key === "Enter" &&
+ currentlyInContextQuery &&
+ nestedContextProvider === undefined
+ ) {
+ selectContextItemFromDropdown(event);
} else if (event.key === "Tab" && items.length > 0) {
downshiftProps.setInputValue(items[0].name);
event.preventDefault();
@@ -545,6 +754,16 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
);
setCurrentlyInContextQuery(false);
} else if (event.key === "Escape") {
+ if (nestedContextProvider) {
+ goBackToContextProviders();
+ (event.nativeEvent as any).preventDownshiftDefault = true;
+ return;
+ } else if (inQueryForContextProvider) {
+ goBackToContextProviders();
+ (event.nativeEvent as any).preventDownshiftDefault = true;
+ return;
+ }
+
setCurrentlyInContextQuery(false);
if (downshiftProps.isOpen && items.length > 0) {
downshiftProps.closeMenu();
@@ -570,6 +789,27 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
ref: inputRef,
})}
/>
+ {inQueryForContextProvider ? (
+ <DynamicQueryTitleDiv>
+ Enter {inQueryForContextProvider.display_title} Query
+ </DynamicQueryTitleDiv>
+ ) : (
+ <>
+ <StyledPlusIcon
+ width="1.4em"
+ height="1.4em"
+ data-tooltip-id="add-context-button"
+ onClick={() => {
+ downshiftProps.setInputValue("@");
+ inputRef.current?.focus();
+ }}
+ />
+ <StyledTooltip id="add-context-button" place="bottom">
+ Add Context to Prompt
+ </StyledTooltip>
+ </>
+ )}
+
<Ul
{...downshiftProps.getMenuProps({
ref: ulRef,
@@ -578,20 +818,72 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {
ulHeightPixels={ulRef.current?.getBoundingClientRect().height || 0}
hidden={!downshiftProps.isOpen || items.length === 0}
>
+ {nestedContextProvider && (
+ <div
+ style={{
+ backgroundColor: secondaryDark,
+ borderBottom: `1px solid ${lightGray}`,
+ display: "flex",
+ gap: "4px",
+ position: "sticky",
+ top: "0px",
+ }}
+ className="py-2 px-4 my-0"
+ >
+ <ArrowLeftIcon
+ width="1.4em"
+ height="1.4em"
+ className="cursor-pointer"
+ onClick={() => {
+ goBackToContextProviders();
+ }}
+ />
+ {nestedContextProvider.display_title} -{" "}
+ {nestedContextProvider.description}
+ </div>
+ )}
{downshiftProps.isOpen &&
items.map((item, index) => (
<Li
- style={{ borderTop: index === 0 ? "none" : undefined }}
+ style={{
+ borderTop: index === 0 ? "none" : undefined,
+ }}
key={`${item.name}${index}`}
{...downshiftProps.getItemProps({ item, index })}
highlighted={downshiftProps.highlightedIndex === index}
selected={downshiftProps.selectedItem === item}
+ onClick={(e) => {
+ // e.stopPropagation();
+ // e.preventDefault();
+ // (e.nativeEvent as any).preventDownshiftDefault = true;
+ // downshiftProps.selectItem(item);
+ selectContextItemFromDropdown(e);
+ onSelectedItemChangeCallback({ selectedItem: item });
+ }}
>
<span>
{item.name}
{" "}
- <span style={{ color: lightGray }}>{item.description}</span>
+ <span
+ style={{
+ color: lightGray,
+ }}
+ >
+ {item.description}
+ </span>
</span>
+ {contextProviders
+ .filter(
+ (provider) => !provider.dynamic || provider.requires_query
+ )
+ .find((provider) => provider.title === item.id) && (
+ <ArrowRightIcon
+ width="1.2em"
+ height="1.2em"
+ color={lightGray}
+ className="ml-2"
+ />
+ )}
</Li>
))}
</Ul>
diff --git a/extension/react-app/src/components/EditableDiv.tsx b/extension/react-app/src/components/EditableDiv.tsx
new file mode 100644
index 00000000..a86bd692
--- /dev/null
+++ b/extension/react-app/src/components/EditableDiv.tsx
@@ -0,0 +1,84 @@
+import styled from "styled-components";
+import {
+ defaultBorderRadius,
+ lightGray,
+ secondaryDark,
+ vscForeground,
+} from ".";
+
+const Div = styled.div`
+ resize: none;
+
+ padding: 8px;
+ font-size: 13px;
+ font-family: inherit;
+ border-radius: ${defaultBorderRadius};
+ margin: 8px auto;
+ height: auto;
+ width: 100%;
+ background-color: ${secondaryDark};
+ color: ${vscForeground};
+ z-index: 1;
+ border: 1px solid transparent;
+
+ &:focus {
+ outline: 1px solid ${lightGray};
+ border: 1px solid transparent;
+ }
+
+ &::placeholder {
+ color: ${lightGray}80;
+ }
+`;
+
+const Span = styled.span<{ color?: string }>`
+ background-color: ${(props) => props.color || "#2cf8"};
+ border-radius: ${defaultBorderRadius};
+ padding: 2px 4px;
+`;
+
+interface EditableDivProps {
+ onChange: (e: any) => void;
+ value?: string;
+}
+
+function EditableDiv(props: EditableDivProps) {
+ return (
+ <Div
+ suppressContentEditableWarning={true}
+ contentEditable={true}
+ onChange={(e) => {
+ const target = e.target as HTMLTextAreaElement;
+ // Update the height of the textarea to match the content, up to a max of 200px.
+ target.style.height = "auto";
+ target.style.height = `${Math.min(
+ target.scrollHeight,
+ 300
+ ).toString()}px`;
+
+ // setShowContextDropdown(target.value.endsWith("@"));
+ props.onChange(e);
+ }}
+ onKeyDown={(e) => {
+ // if (e.key === "Delete") {
+ // // Delete spans if they are last child
+ // const selection = window.getSelection();
+ // const range = selection?.getRangeAt(0);
+ // const node = range?.startContainer;
+ // console.log("Del");
+ // if (node?.nodeName === "SPAN") {
+ // console.log("span");
+ // const parent = node.parentNode;
+ // if (parent?.childNodes.length === 1) {
+ // parent.removeChild(node);
+ // }
+ // }
+ // }
+ }}
+ >
+ {props.value ? props.value : <Span contentEditable={false}>testing</Span>}
+ </Div>
+ );
+}
+
+export default EditableDiv;
diff --git a/extension/react-app/src/components/Layout.tsx b/extension/react-app/src/components/Layout.tsx
index 17100c7f..6410db8a 100644
--- a/extension/react-app/src/components/Layout.tsx
+++ b/extension/react-app/src/components/Layout.tsx
@@ -21,8 +21,9 @@ import {
Cog6ToothIcon,
} from "@heroicons/react/24/outline";
import HeaderButtonWithText from "./HeaderButtonWithText";
-import { useNavigate } from "react-router-dom";
+import { useNavigate, useLocation } from "react-router-dom";
import ModelSelect from "./ModelSelect";
+import ProgressBar from "./ProgressBar";
// #region Styled Components
const FOOTER_HEIGHT = "1.8em";
@@ -74,6 +75,7 @@ const GridDiv = styled.div`
const Layout = () => {
const navigate = useNavigate();
+ const location = useLocation();
const client = useContext(GUIClientContext);
const dispatch = useDispatch();
const dialogMessage = useSelector(
@@ -82,10 +84,11 @@ const Layout = () => {
const showDialog = useSelector(
(state: RootStore) => state.uiState.showDialog
);
- const dialogEntryOn = useSelector(
- (state: RootStore) => state.uiState.dialogEntryOn
- );
+ const defaultModel = useSelector(
+ (state: RootStore) =>
+ (state.serverState.config as any).models?.default?.class_name
+ );
// #region Selectors
const bottomMessage = useSelector(
@@ -175,6 +178,17 @@ const Layout = () => {
)}
<ModelSelect />
+ {defaultModel === "MaybeProxyOpenAI" &&
+ (location.pathname === "/settings" ||
+ parseInt(localStorage.getItem("freeTrialCounter") || "0") >=
+ 125) && (
+ <ProgressBar
+ completed={parseInt(
+ localStorage.getItem("freeTrialCounter") || "0"
+ )}
+ total={250}
+ />
+ )}
</div>
<HeaderButtonWithText
onClick={() => {
diff --git a/extension/react-app/src/components/ProgressBar.tsx b/extension/react-app/src/components/ProgressBar.tsx
new file mode 100644
index 00000000..b4a2efc9
--- /dev/null
+++ b/extension/react-app/src/components/ProgressBar.tsx
@@ -0,0 +1,77 @@
+import React from "react";
+import styled from "styled-components";
+import { StyledTooltip, lightGray, vscForeground } from ".";
+
+const ProgressBarWrapper = styled.div`
+ width: 100px;
+ height: 6px;
+ border-radius: 6px;
+ border: 0.5px solid ${lightGray};
+ margin-top: 6px;
+`;
+
+const ProgressBarFill = styled.div<{ completed: number; color?: string }>`
+ height: 100%;
+ background-color: ${(props) => props.color || vscForeground};
+ border-radius: inherit;
+ transition: width 0.2s ease-in-out;
+ width: ${(props) => props.completed}%;
+`;
+
+const GridDiv = styled.div`
+ display: grid;
+ grid-template-rows: 1fr auto;
+ align-items: center;
+ justify-items: center;
+`;
+
+const P = styled.p`
+ margin: 0;
+ margin-top: 2px;
+ font-size: 12px;
+ color: ${lightGray};
+ text-align: center;
+`;
+
+interface ProgressBarProps {
+ completed: number;
+ total: number;
+}
+
+const ProgressBar = ({ completed, total }: ProgressBarProps) => {
+ const fillPercentage = Math.min(100, Math.max(0, (completed / total) * 100));
+
+ return (
+ <>
+ <a
+ href="https://continue.dev/docs/customization"
+ className="no-underline"
+ >
+ <GridDiv data-tooltip-id="usage_progress_bar">
+ <ProgressBarWrapper>
+ <ProgressBarFill
+ completed={fillPercentage}
+ color={
+ completed / total > 0.75
+ ? completed / total > 0.95
+ ? "#f00"
+ : "#fc0"
+ : undefined
+ }
+ />
+ </ProgressBarWrapper>
+ <P>
+ Free Usage: {completed} / {total}
+ </P>
+ </GridDiv>
+ </a>
+ <StyledTooltip id="usage_progress_bar" place="bottom">
+ {
+ "Continue allows you to use our OpenAI API key for up to 250 inputs. After this, you can either use your own API key, or use a local LLM. Click the progress bar to go to the docs and learn more."
+ }
+ </StyledTooltip>
+ </>
+ );
+};
+
+export default ProgressBar;
diff --git a/extension/react-app/src/components/index.ts b/extension/react-app/src/components/index.ts
index 25e35dd1..1f418c94 100644
--- a/extension/react-app/src/components/index.ts
+++ b/extension/react-app/src/components/index.ts
@@ -6,8 +6,8 @@ export const lightGray = "#646464";
// export const secondaryDark = "rgb(45 45 45)";
// export const vscBackground = "rgb(30 30 30)";
export const vscBackgroundTransparent = "#1e1e1ede";
-export const buttonColor = "rgb(27 190 132)";
-export const buttonColorHover = "rgb(27 190 132 0.67)";
+export const buttonColor = "#1bbe84";
+export const buttonColorHover = "1bbe84a8";
export const secondaryDark = "var(--vscode-list-hoverBackground)";
export const vscBackground = "var(--vscode-editor-background)";
diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx
index cb62f7ed..a52e1ffc 100644
--- a/extension/react-app/src/pages/gui.tsx
+++ b/extension/react-app/src/pages/gui.tsx
@@ -66,6 +66,10 @@ function GUI(props: GUIProps) {
// #region Selectors
const history = useSelector((state: RootStore) => state.serverState.history);
+ const defaultModel = useSelector(
+ (state: RootStore) =>
+ (state.serverState.config as any).models?.default?.class_name
+ );
const user_input_queue = useSelector(
(state: RootStore) => state.serverState.user_input_queue
);
@@ -240,6 +244,43 @@ function GUI(props: GUIProps) {
return;
}
+ // Increment localstorage counter for usage of free trial
+ if (
+ defaultModel === "MaybeProxyOpenAI" &&
+ (!input.startsWith("/") || input.startsWith("/edit"))
+ ) {
+ const freeTrialCounter = localStorage.getItem("freeTrialCounter");
+ if (freeTrialCounter) {
+ const usages = parseInt(freeTrialCounter);
+ localStorage.setItem("freeTrialCounter", (usages + 1).toString());
+
+ if (usages >= 250) {
+ console.log("Free trial limit reached");
+ dispatch(setShowDialog(true));
+ dispatch(
+ setDialogMessage(
+ <div className="p-4">
+ <h3>Free Trial Limit Reached</h3>
+ You've reached the free trial limit of 250 free inputs with
+ Continue's OpenAI API key. To keep using Continue, you can
+ either use your own API key, or use a local LLM. To read more
+ about the options, see our{" "}
+ <a href="https://continue.dev/docs/customization">
+ documentation
+ </a>
+ . If you're just looking for fastest way to keep going, type
+ '/config' to open your Continue config file and paste your API
+ key into the MaybeProxyOpenAI object.
+ </div>
+ )
+ );
+ return;
+ }
+ } else {
+ localStorage.setItem("freeTrialCounter", "1");
+ }
+ }
+
setWaitingForSteps(true);
if (
@@ -266,7 +307,7 @@ function GUI(props: GUIProps) {
client.sendMainInput(input);
dispatch(temporarilyPushToUserInputQueue(input));
- // Increment localstorage counter
+ // Increment localstorage counter for popup
const counter = localStorage.getItem("mainTextEntryCounter");
if (counter) {
let currentCount = parseInt(counter);
@@ -274,7 +315,7 @@ function GUI(props: GUIProps) {
"mainTextEntryCounter",
(currentCount + 1).toString()
);
- if (currentCount === 100) {
+ if (currentCount === -300) {
dispatch(
setDialogMessage(
<div className="text-center p-4">
diff --git a/extension/react-app/src/pages/settings.tsx b/extension/react-app/src/pages/settings.tsx
index 9a3d3cc2..8b3d9c5b 100644
--- a/extension/react-app/src/pages/settings.tsx
+++ b/extension/react-app/src/pages/settings.tsx
@@ -4,19 +4,12 @@ import { useSelector } from "react-redux";
import { RootStore } from "../redux/store";
import { useNavigate } from "react-router-dom";
import { ContinueConfig } from "../../../schema/ContinueConfig";
-import {
- Button,
- Select,
- TextArea,
- lightGray,
- secondaryDark,
-} from "../components";
+import { Button, TextArea, lightGray, secondaryDark } from "../components";
import styled from "styled-components";
import { ArrowLeftIcon } from "@heroicons/react/24/outline";
import Loader from "../components/Loader";
import InfoHover from "../components/InfoHover";
import { FormProvider, useForm } from "react-hook-form";
-import ModelSettings from "../components/ModelSettings";
const Hr = styled.hr`
border: 0.5px solid ${lightGray};
diff --git a/extension/src/commands.ts b/extension/src/commands.ts
index 4ceac25d..7d190634 100644
--- a/extension/src/commands.ts
+++ b/extension/src/commands.ts
@@ -28,14 +28,11 @@ const commandsMap: { [command: string]: (...args: any) => any } = {
}
},
"continue.focusContinueInput": async () => {
- if (focusedOnContinueInput) {
- vscode.commands.executeCommand("workbench.action.focusActiveEditorGroup");
- } else {
- vscode.commands.executeCommand("continue.continueGUIView.focus");
- debugPanelWebview?.postMessage({
- type: "focusContinueInput",
- });
- }
+ vscode.commands.executeCommand("continue.continueGUIView.focus");
+ debugPanelWebview?.postMessage({
+ type: "focusContinueInput",
+ });
+
focusedOnContinueInput = !focusedOnContinueInput;
},
"continue.focusContinueInputWithEdit": async () => {
diff --git a/extension/src/diffs.ts b/extension/src/diffs.ts
index 98b8753a..b7acd109 100644
--- a/extension/src/diffs.ts
+++ b/extension/src/diffs.ts
@@ -62,7 +62,7 @@ class DiffManager {
}
private escapeFilepath(filepath: string): string {
- return filepath.replace(/\\/g, "_").replace(/\//g, "_");
+ return encodeURIComponent(filepath);
}
private remoteTmpDir: string = "/tmp/continue";
diff --git a/extension/src/lang-server/codeLens.ts b/extension/src/lang-server/codeLens.ts
index ba80e557..ec03f73e 100644
--- a/extension/src/lang-server/codeLens.ts
+++ b/extension/src/lang-server/codeLens.ts
@@ -69,6 +69,10 @@ class DiffViewerCodeLensProvider implements vscode.CodeLensProvider {
title: `Reject All ❌ (${getMetaKeyLabel()}⇧⌫)`,
command: "continue.rejectDiff",
arguments: [document.uri.fsPath],
+ }),
+ new vscode.CodeLens(range, {
+ title: `Further Edit ✏️ (${getMetaKeyLabel()}⇧M)`,
+ command: "continue.focusContinueInputWithEdit",
})
);
return codeLenses;