Diffstat (limited to 'continuedev')
38 files changed, 435 insertions, 264 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index a27b0cb7..5804ce6b 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -10,6 +10,7 @@ from aiohttp import ClientPayloadError
 from openai import error as openai_errors
 from pydantic import root_validator
 
+from ..libs.llm.prompts.chat import template_alpaca_messages
 from ..libs.util.create_async_task import create_async_task
 from ..libs.util.devdata import dev_data_logger
 from ..libs.util.edit_config import edit_config_property
@@ -201,7 +202,9 @@ class Autopilot(ContinueBaseModel):
             )
             or []
         )
-        return custom_commands + slash_commands
+        cmds = custom_commands + slash_commands
+        cmds.sort(key=lambda x: x["name"] == "edit", reverse=True)
+        return cmds
 
     async def clear_history(self):
         # Reset history
@@ -273,14 +276,16 @@ class Autopilot(ContinueBaseModel):
             await self._run_singular_step(step)
 
     async def handle_highlighted_code(
-        self, range_in_files: List[RangeInFileWithContents]
+        self,
+        range_in_files: List[RangeInFileWithContents],
+        edit: Optional[bool] = False,
     ):
         if "code" not in self.context_manager.context_providers:
             return
 
         # Add to context manager
         await self.context_manager.context_providers["code"].handle_highlighted_code(
-            range_in_files
+            range_in_files, edit
         )
 
         await self.update_subscribers()
@@ -292,7 +297,9 @@ class Autopilot(ContinueBaseModel):
         self._retry_queue.post(str(index), None)
 
     async def delete_at_index(self, index: int):
-        self.history.timeline[index].step.hide = True
+        if not self.history.timeline[index].active:
+            self.history.timeline[index].step.hide = True
+
+        self.history.timeline[index].deleted = True
         self.history.timeline[index].active = False
 
@@ -476,9 +483,43 @@ class Autopilot(ContinueBaseModel):
 
         create_async_task(
             update_description(),
-            on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)),
+            on_error=lambda e: self.continue_sdk.run_step(
+                DisplayErrorStep.from_exception(e)
+            ),
         )
 
+        # Create the session title if not done yet
+        if self.session_info is None or self.session_info.title is None:
+            visible_nodes = list(
+                filter(lambda node: not node.step.hide, self.history.timeline)
+            )
+
+            user_input = None
+            should_create_title = False
+            for visible_node in visible_nodes:
+                if isinstance(visible_node.step, UserInputStep):
+                    if user_input is None:
+                        user_input = visible_node.step.user_input
+                    else:
+                        # More than one user input, so don't create title
+                        should_create_title = False
+                        break
+                elif user_input is None:
+                    continue
+                else:
+                    # Already have user input, now have the next step
+                    should_create_title = True
+                    break
+
+            # Only create the title if the step after the first input is done
+            if should_create_title:
+                create_async_task(
+                    self.create_title(backup=user_input),
+                    on_error=lambda e: self.continue_sdk.run_step(
+                        DisplayErrorStep.from_exception(e)
+                    ),
+                )
+
         return observation
 
     async def run_from_step(self, step: "Step"):
@@ -523,41 +564,43 @@ class Autopilot(ContinueBaseModel):
         self._should_halt = False
         return None
 
-    async def accept_user_input(self, user_input: str):
-        self._main_user_input_queue.append(user_input)
-        await self.update_subscribers()
+    def set_current_session_title(self, title: str):
+        self.session_info = SessionInfo(
+            title=title,
+            session_id=self.ide.session_id,
+            date_created=str(time.time()),
+            workspace_directory=self.ide.workspace_directory,
+        )
 
-        # Use the first input to create title for session info, and make the session saveable
-        if self.session_info is None:
+    async def create_title(self, backup: str = None):
+        # Use the first input and first response to create title for session info, and make the session saveable
+        if self.session_info is not None and self.session_info.title is not None:
+            return
 
-            async def create_title():
-                if (
-                    self.session_info is not None
-                    and self.session_info.title is not None
-                ):
-                    return
+        if self.continue_sdk.config.disable_summaries:
+            if backup is not None:
+                title = backup
+            else:
+                title = "New Session"
+        else:
+            chat_history = list(
+                map(lambda x: x.dict(), await self.continue_sdk.get_chat_context())
+            )
+            chat_history_str = template_alpaca_messages(chat_history)
+            title = await self.continue_sdk.models.summarize.complete(
+                f"{chat_history_str}\n\nGive a short title to describe the above chat session. Do not put quotes around the title. Do not use more than 6 words. The title is: ",
+                max_tokens=20,
+                log=False,
+            )
+            title = remove_quotes_and_escapes(title)
 
-                if self.continue_sdk.config.disable_summaries:
-                    title = user_input
-                else:
-                    title = await self.continue_sdk.models.medium.complete(
-                        f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. The title is: ',
-                        max_tokens=20,
-                    )
-                    title = remove_quotes_and_escapes(title)
-
-                self.session_info = SessionInfo(
-                    title=title,
-                    session_id=self.ide.session_id,
-                    date_created=str(time.time()),
-                    workspace_directory=self.ide.workspace_directory,
-                )
-                dev_data_logger.capture("new_session", self.session_info.dict())
+        self.set_current_session_title(title)
+        await self.update_subscribers()
+        dev_data_logger.capture("new_session", self.session_info.dict())
 
-            create_async_task(
-                create_title(),
-                on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)),
-            )
+    async def accept_user_input(self, user_input: str):
+        self._main_user_input_queue.append(user_input)
+        await self.update_subscribers()
 
         if len(self._main_user_input_queue) > 1:
             return
@@ -579,6 +622,15 @@ class Autopilot(ContinueBaseModel):
         await self.reverse_to_index(index)
         await self.run_from_step(UserInputStep(user_input=user_input))
 
+    async def reject_diff(self, step_index: int):
+        # Hide the edit step and the UserInputStep before it
+        self.history.timeline[step_index].step.hide = True
+        for i in range(step_index - 1, -1, -1):
+            if isinstance(self.history.timeline[i].step, UserInputStep):
+                self.history.timeline[i].step.hide = True
+                break
+        await self.update_subscribers()
+
     async def select_context_item(self, id: str, query: str):
         await self.context_manager.select_context_item(id, query)
         await self.update_subscribers()
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index d431c704..2bbb42cc 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Type
 
 from pydantic import BaseModel, Field, validator
 
-from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from ..libs.llm.openai_free_trial import OpenAIFreeTrial
 from .context import ContextProvider
 from .main import Policy, Step
 from .models import Models
@@ -48,8 +48,8 @@ class ContinueConfig(BaseModel):
     )
     models: Models = Field(
         Models(
-            default=MaybeProxyOpenAI(model="gpt-4"),
-            medium=MaybeProxyOpenAI(model="gpt-3.5-turbo"),
+            default=OpenAIFreeTrial(model="gpt-4"),
+            summarize=OpenAIFreeTrial(model="gpt-3.5-turbo"),
         ),
         description="Configuration for the models used by Continue. Read more about how to configure models in the documentation.",
     )
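The renamed OpenAIFreeTrial class and the consolidated summarize role are what user config files are expected to contain from here on. A minimal sketch of such a config.py (the empty api_key is a placeholder):

    from continuedev.src.continuedev.core.config import ContinueConfig
    from continuedev.src.continuedev.core.models import Models
    from continuedev.src.continuedev.libs.llm.openai_free_trial import OpenAIFreeTrial

    config = ContinueConfig(
        models=Models(
            default=OpenAIFreeTrial(api_key="", model="gpt-4"),
            # "summarize" replaces the old small/medium/large roles
            summarize=OpenAIFreeTrial(api_key="", model="gpt-3.5-turbo"),
        ),
    )
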
description="Configuration for the models used by Continue. Read more about how to configure models in the documentation.", ) diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 63a3e6a9..cf41aab9 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -337,6 +337,12 @@ class Step(ContinueBaseModel): hide: bool = False description: Union[str, None] = None + class_name: str = "Step" + + @validator("class_name", pre=True, always=True) + def class_name_is_class_name(cls, class_name): + return cls.__name__ + system_message: Union[str, None] = None chat_context: List[ChatMessage] = [] manage_own_chat_context: bool = False diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index f24c81ca..2396a0db 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -5,13 +5,14 @@ from pydantic import BaseModel from ..libs.llm import LLM from ..libs.llm.anthropic import AnthropicLLM from ..libs.llm.ggml import GGML +from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI +from ..libs.llm.hf_tgi import HuggingFaceTGI from ..libs.llm.llamacpp import LlamaCpp -from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI from ..libs.llm.ollama import Ollama from ..libs.llm.openai import OpenAI +from ..libs.llm.openai_free_trial import OpenAIFreeTrial from ..libs.llm.replicate import ReplicateLLM from ..libs.llm.together import TogetherLLM -from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI class ContinueSDK(BaseModel): @@ -20,9 +21,7 @@ class ContinueSDK(BaseModel): ALL_MODEL_ROLES = [ "default", - "small", - "medium", - "large", + "summarize", "edit", "chat", ] @@ -31,7 +30,7 @@ MODEL_CLASSES = { cls.__name__: cls for cls in [ OpenAI, - MaybeProxyOpenAI, + OpenAIFreeTrial, GGML, TogetherLLM, AnthropicLLM, @@ -39,12 +38,13 @@ MODEL_CLASSES = { Ollama, LlamaCpp, HuggingFaceInferenceAPI, + HuggingFaceTGI, ] } MODEL_MODULE_NAMES = { "OpenAI": "openai", - "MaybeProxyOpenAI": "maybe_proxy_openai", + "OpenAIFreeTrial": "openai_free_trial", "GGML": "ggml", "TogetherLLM": "together", "AnthropicLLM": "anthropic", @@ -52,6 +52,7 @@ MODEL_MODULE_NAMES = { "Ollama": "ollama", "LlamaCpp": "llamacpp", "HuggingFaceInferenceAPI": "hf_inference_api", + "HuggingFaceTGI": "hf_tgi", } @@ -59,13 +60,11 @@ class Models(BaseModel): """Main class that holds the current model configuration""" default: LLM - small: Optional[LLM] = None - medium: Optional[LLM] = None - large: Optional[LLM] = None + summarize: Optional[LLM] = None edit: Optional[LLM] = None chat: Optional[LLM] = None - unused: List[LLM] = [] + saved: List[LLM] = [] # TODO namespace these away to not confuse readers, # or split Models into ModelsConfig, which gets turned into Models @@ -89,7 +88,8 @@ class Models(BaseModel): def set_system_message(self, msg: str): for model in self.all_models: - model.system_message = msg + if model.system_message is None: + model.system_message = msg async def start(self, sdk: "ContinueSDK"): """Start each of the LLMs, or fall back to default""" diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 12fce1c6..64fd784c 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -104,7 +104,7 @@ class ContinueSDK(AbstractContinueSDK): ) await sdk.lsp.start() except Exception as e: - logger.warning(f"Failed to start LSP client: {e}", 
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index f24c81ca..2396a0db 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -5,13 +5,14 @@ from pydantic import BaseModel
 from ..libs.llm import LLM
 from ..libs.llm.anthropic import AnthropicLLM
 from ..libs.llm.ggml import GGML
+from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
+from ..libs.llm.hf_tgi import HuggingFaceTGI
 from ..libs.llm.llamacpp import LlamaCpp
-from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
 from ..libs.llm.ollama import Ollama
 from ..libs.llm.openai import OpenAI
+from ..libs.llm.openai_free_trial import OpenAIFreeTrial
 from ..libs.llm.replicate import ReplicateLLM
 from ..libs.llm.together import TogetherLLM
-from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 
 
 class ContinueSDK(BaseModel):
@@ -20,9 +21,7 @@ class ContinueSDK(BaseModel):
 
 ALL_MODEL_ROLES = [
     "default",
-    "small",
-    "medium",
-    "large",
+    "summarize",
     "edit",
     "chat",
 ]
@@ -31,7 +30,7 @@ MODEL_CLASSES = {
     cls.__name__: cls
     for cls in [
         OpenAI,
-        MaybeProxyOpenAI,
+        OpenAIFreeTrial,
         GGML,
         TogetherLLM,
         AnthropicLLM,
@@ -39,12 +38,13 @@ MODEL_CLASSES = {
         Ollama,
         LlamaCpp,
         HuggingFaceInferenceAPI,
+        HuggingFaceTGI,
     ]
 }
 
 MODEL_MODULE_NAMES = {
     "OpenAI": "openai",
-    "MaybeProxyOpenAI": "maybe_proxy_openai",
+    "OpenAIFreeTrial": "openai_free_trial",
     "GGML": "ggml",
     "TogetherLLM": "together",
     "AnthropicLLM": "anthropic",
@@ -52,6 +52,7 @@ MODEL_MODULE_NAMES = {
     "Ollama": "ollama",
     "LlamaCpp": "llamacpp",
     "HuggingFaceInferenceAPI": "hf_inference_api",
+    "HuggingFaceTGI": "hf_tgi",
 }
 
 
@@ -59,13 +60,11 @@ class Models(BaseModel):
    """Main class that holds the current model configuration"""
 
     default: LLM
-    small: Optional[LLM] = None
-    medium: Optional[LLM] = None
-    large: Optional[LLM] = None
+    summarize: Optional[LLM] = None
     edit: Optional[LLM] = None
     chat: Optional[LLM] = None
 
-    unused: List[LLM] = []
+    saved: List[LLM] = []
 
     # TODO namespace these away to not confuse readers,
     # or split Models into ModelsConfig, which gets turned into Models
@@ -89,7 +88,8 @@ class Models(BaseModel):
 
     def set_system_message(self, msg: str):
         for model in self.all_models:
-            model.system_message = msg
+            if model.system_message is None:
+                model.system_message = msg
 
     async def start(self, sdk: "ContinueSDK"):
         """Start each of the LLMs, or fall back to default"""
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 12fce1c6..64fd784c 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -104,7 +104,7 @@ class ContinueSDK(AbstractContinueSDK):
             )
             await sdk.lsp.start()
         except Exception as e:
-            logger.warning(f"Failed to start LSP client: {e}", exc_info=True)
+            logger.warning(f"Failed to start LSP client: {e}", exc_info=False)
             sdk.lsp = None
 
         create_async_task(
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py b/continuedev/src/continuedev/libs/constants/default_config.py
index d93dffcd..a1b2de2c 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py
+++ b/continuedev/src/continuedev/libs/constants/default_config.py
@@ -8,16 +8,14 @@ See https://continue.dev/docs/customization to for documentation of the availabl
 
 from continuedev.src.continuedev.core.models import Models
 from continuedev.src.continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
 from continuedev.src.continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
-from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from continuedev.src.continuedev.libs.llm.openai_free_trial import OpenAIFreeTrial
 from continuedev.src.continuedev.plugins.steps.open_config import OpenConfigStep
 from continuedev.src.continuedev.plugins.steps.clear_history import ClearHistoryStep
-from continuedev.src.continuedev.plugins.steps.feedback import FeedbackStep
 from continuedev.src.continuedev.plugins.steps.comment_code import CommentCodeStep
 from continuedev.src.continuedev.plugins.steps.share_session import ShareSessionStep
 from continuedev.src.continuedev.plugins.steps.main import EditHighlightedCodeStep
 from continuedev.src.continuedev.plugins.steps.cmd import GenerateShellCommandStep
-from continuedev.src.continuedev.plugins.context_providers.search import SearchContextProvider
 from continuedev.src.continuedev.plugins.context_providers.diff import DiffContextProvider
 from continuedev.src.continuedev.plugins.context_providers.url import URLContextProvider
 from continuedev.src.continuedev.plugins.context_providers.terminal import TerminalContextProvider
@@ -25,8 +23,8 @@ from continuedev.src.continuedev.plugins.context_providers.terminal import Termi
 config = ContinueConfig(
     allow_anonymous_telemetry=True,
     models=Models(
-        default=MaybeProxyOpenAI(api_key="", model="gpt-4"),
-        medium=MaybeProxyOpenAI(api_key="", model="gpt-3.5-turbo")
+        default=OpenAIFreeTrial(api_key="", model="gpt-4"),
+        summarize=OpenAIFreeTrial(api_key="", model="gpt-3.5-turbo")
     ),
     system_message=None,
     temperature=0.5,
@@ -54,11 +52,6 @@ config = ContinueConfig(
             step=CommentCodeStep,
         ),
         SlashCommand(
-            name="feedback",
-            description="Send feedback to improve Continue",
-            step=FeedbackStep,
-        ),
-        SlashCommand(
             name="clear",
             description="Clear step history",
             step=ClearHistoryStep,
@@ -79,7 +72,6 @@ config = ContinueConfig(
         # repo_name="<your github username or organization>/<your repo name>",
         # auth_token="<your github auth token>"
         # ),
-        SearchContextProvider(),
         DiffContextProvider(),
         URLContextProvider(
             preset_urls = [
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index b2eecab6..28f614c7 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,5 +1,8 @@
+import ssl
 from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
 
+import aiohttp
+import certifi
 from pydantic import Field, validator
 
 from ...core.main import ChatMessage
@@ -83,6 +86,10 @@
         None,
         description="Path to a custom CA bundle to use when making the HTTP request",
     )
+    proxy: Optional[str] = Field(
+        None,
+        description="Proxy URL to use when making the HTTP request",
+    )
     prompt_templates: dict = Field(
         {},
         description='A dictionary of prompt templates that can be used to customize the behavior of the LLM in certain situations. For example, set the "edit" key in order to change the prompt that is used for the /edit slash command. Each value in the dictionary is a string templated in mustache syntax, and filled in at runtime with the variables specific to the situation. See the documentation for more information.',
@@ -134,6 +141,10 @@
         "verify_ssl": {
             "description": "Whether to verify SSL certificates for requests."
         },
+        "ca_bundle_path": {
+            "description": "Path to a custom CA bundle to use when making the HTTP request"
+        },
+        "proxy": {"description": "Proxy URL to use when making the HTTP request"},
     }
 
     def dict(self, **kwargs):
@@ -155,6 +166,22 @@
         """Stop the connection to the LLM."""
         pass
 
+    def create_client_session(self):
+        if self.verify_ssl is False:
+            return aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(verify_ssl=False),
+                timeout=aiohttp.ClientTimeout(total=self.timeout),
+            )
+        else:
+            ca_bundle_path = (
+                certifi.where() if self.ca_bundle_path is None else self.ca_bundle_path
+            )
+            ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
+            return aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(ssl_context=ssl_context),
+                timeout=aiohttp.ClientTimeout(total=self.timeout),
+            )
+
     def collect_args(self, options: CompletionOptions) -> Dict[str, Any]:
         """Collect the arguments for the LLM."""
         args = {**DEFAULT_ARGS.copy(), "model": self.model}
@@ -199,6 +226,7 @@
         stop: Optional[List[str]] = None,
         max_tokens: Optional[int] = None,
         functions: Optional[List[Any]] = None,
+        log: bool = True,
     ) -> Generator[Union[Any, List, Dict], None, None]:
         """Yield completion response, either streamed or not."""
         options = CompletionOptions(
@@ -220,14 +248,17 @@
         if not raw:
             prompt = self.template_prompt_like_messages(prompt)
 
-        self.write_log(f"Prompt: \n\n{prompt}")
+        if log:
+            self.write_log(prompt)
 
         completion = ""
         async for chunk in self._stream_complete(prompt=prompt, options=options):
             yield chunk
             completion += chunk
 
-        self.write_log(f"Completion: \n\n{completion}")
+        # if log:
+        #     self.write_log(f"Completion: \n\n{completion}")
+
         dev_data_logger.capture(
             "tokens_generated",
             {"model": self.model, "tokens": self.count_tokens(completion)},
@@ -246,6 +277,7 @@
         stop: Optional[List[str]] = None,
         max_tokens: Optional[int] = None,
         functions: Optional[List[Any]] = None,
+        log: bool = True,
     ) -> str:
         """Yield completion response, either streamed or not."""
         options = CompletionOptions(
@@ -267,11 +299,14 @@
         if not raw:
             prompt = self.template_prompt_like_messages(prompt)
 
-        self.write_log(f"Prompt: \n\n{prompt}")
+        if log:
+            self.write_log(prompt)
 
         completion = await self._complete(prompt=prompt, options=options)
 
-        self.write_log(f"Completion: \n\n{completion}")
+        # if log:
+        #     self.write_log(f"Completion: \n\n{completion}")
+
         dev_data_logger.capture(
             "tokens_generated",
             {"model": self.model, "tokens": self.count_tokens(completion)},
@@ -291,6 +326,7 @@
         stop: Optional[List[str]] = None,
         max_tokens: Optional[int] = None,
         functions: Optional[List[Any]] = None,
+        log: bool = True,
     ) -> Generator[Union[Any, List, Dict], None, None]:
         """Yield completion response, either streamed or not."""
         options = CompletionOptions(
@@ -313,7 +349,8 @@ class LLM(ContinueBaseModel):
         else:
             prompt = format_chat_messages(messages)
 
-        self.write_log(f"Prompt: \n\n{prompt}")
+        if log:
+            self.write_log(prompt)
 
         completion = ""
 
@@ -328,7 +365,9 @@ class LLM(ContinueBaseModel):
             yield {"role": "assistant", "content": chunk}
             completion += chunk
 
-        self.write_log(f"Completion: \n\n{completion}")
+        # if log:
+        #     self.write_log(f"Completion: \n\n{completion}")
+
         dev_data_logger.capture(
             "tokens_generated",
             {"model": self.model, "tokens": self.count_tokens(completion)},
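create_client_session and the new proxy field centralize the aiohttp setup that GGML and ProxyServer previously duplicated, honoring verify_ssl, ca_bundle_path, and timeout in one place. A rough sketch of how a subclass is expected to use it (the server URL and JSON shape here are illustrative, not from the diff):

    class MyServerLLM(LLM):  # hypothetical subclass
        server_url: str = "http://localhost:8000"

        async def _complete(self, prompt, options):
            async with self.create_client_session() as session:
                async with session.post(
                    f"{self.server_url}/complete",
                    json={"prompt": prompt},
                    proxy=self.proxy,  # None means no proxy
                ) as resp:
                    return await resp.text()
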
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 2fd123bd..27a55dfe 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,8 +1,6 @@
 import json
-import ssl
 from typing import Any, Callable, Coroutine, Dict, List, Optional
 
-import aiohttp
 from pydantic import Field
 
 from ...core.main import ChatMessage
@@ -38,10 +36,6 @@ class GGML(LLM):
         "http://localhost:8000",
         description="URL of the OpenAI-compatible server where the model is being served",
     )
-    proxy: Optional[str] = Field(
-        None,
-        description="Proxy URL to use when making the HTTP request",
-    )
     model: str = Field(
         "ggml", description="The name of the model to use (optional for the GGML class)"
     )
@@ -57,20 +51,6 @@ class GGML(LLM):
     class Config:
         arbitrary_types_allowed = True
 
-    def create_client_session(self):
-        if self.ca_bundle_path is None:
-            ssl_context = ssl.create_default_context(cafile=self.ca_bundle_path)
-            tcp_connector = aiohttp.TCPConnector(
-                verify_ssl=self.verify_ssl, ssl=ssl_context
-            )
-        else:
-            tcp_connector = aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
-
-        return aiohttp.ClientSession(
-            connector=tcp_connector,
-            timeout=aiohttp.ClientTimeout(total=self.timeout),
-        )
-
     def get_headers(self):
         headers = {
             "Content-Type": "application/json",
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index a3672fe2..27d71cb4 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -1,7 +1,6 @@
 import json
 from typing import Any, Callable, List
 
-import aiohttp
 from pydantic import Field
 
 from ...core.main import ChatMessage
@@ -36,14 +35,12 @@ class HuggingFaceTGI(LLM):
     async def _stream_complete(self, prompt, options):
         args = self.collect_args(options)
 
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-            timeout=aiohttp.ClientTimeout(total=self.timeout),
-        ) as client_session:
+        async with self.create_client_session() as client_session:
             async with client_session.post(
                 f"{self.server_url}/generate_stream",
                 json={"inputs": prompt, "parameters": args},
                 headers={"Content-Type": "application/json"},
+                proxy=self.proxy,
             ) as resp:
                 async for line in resp.content.iter_any():
                     if line:
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index c795bd15..0b4c9fb0 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -1,7 +1,6 @@
 import json
 from typing import Any, Callable, Dict
 
-import aiohttp
 from pydantic import Field
 
 from ..llm import LLM
@@ -70,14 +69,12 @@ class LlamaCpp(LLM):
         headers = {"Content-Type": "application/json"}
 
         async def server_generator():
-            async with aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
-                timeout=aiohttp.ClientTimeout(total=self.timeout),
-            ) as client_session:
+            async with self.create_client_session() as client_session:
                 async with client_session.post(
                     f"{self.server_url}/completion",
                     json={"prompt": prompt, "stream": True, **args},
                     headers=headers,
+                    proxy=self.proxy,
                 ) as resp:
                     async for line in resp.content:
                         content = line.decode("utf-8")
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index b699398b..19d48a2f 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -5,6 +5,7 @@ import aiohttp
 from pydantic import Field
 
 from ..llm import LLM
+from ..util.logging import logger
 from .prompts.chat import llama2_template_messages
 from .prompts.edit import simplified_edit_prompt
 
@@ -43,9 +44,19 @@ class Ollama(LLM):
 
     async def start(self, **kwargs):
         await super().start(**kwargs)
-        self._client_session = aiohttp.ClientSession(
-            timeout=aiohttp.ClientTimeout(total=self.timeout)
-        )
+        self._client_session = self.create_client_session()
+        try:
+            async with self._client_session.post(
+                f"{self.server_url}/api/generate",
+                proxy=self.proxy,
+                json={
+                    "prompt": "",
+                    "model": self.model,
+                },
+            ) as _:
+                pass
+        except Exception as e:
+            logger.warning(f"Error pre-loading Ollama model: {e}")
 
     async def stop(self):
         await self._client_session.close()
@@ -59,6 +70,7 @@ class Ollama(LLM):
                 "system": self.system_message,
                 "options": {"temperature": options.temperature},
             },
+            proxy=self.proxy,
         ) as resp:
             async for line in resp.content.iter_any():
                 if line:
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/openai_free_trial.py
index 3fdcb42e..367f2bbd 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai_free_trial.py
@@ -6,9 +6,9 @@ from .openai import OpenAI
 from .proxy_server import ProxyServer
 
 
-class MaybeProxyOpenAI(LLM):
+class OpenAIFreeTrial(LLM):
     """
-    With the `MaybeProxyOpenAI` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code.
+    With the `OpenAIFreeTrial` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code.
 
     Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps:
 
@@ -21,13 +21,13 @@
     config = ContinueConfig(
         ...
         models=Models(
-            default=MaybeProxyOpenAI(model="gpt-4", api_key=API_KEY),
-            medium=MaybeProxyOpenAI(model="gpt-3.5-turbo", api_key=API_KEY)
+            default=OpenAIFreeTrial(model="gpt-4", api_key=API_KEY),
+            summarize=OpenAIFreeTrial(model="gpt-3.5-turbo", api_key=API_KEY)
         )
     )
     ```
 
-    The `MaybeProxyOpenAI` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead.
+    The `OpenAIFreeTrial` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead.
 
     These classes support any models available through the OpenAI API, assuming your API key has access, including "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", and "gpt-4-32k".
     """
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
index 03230499..0bf8635b 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/chat.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -28,8 +28,8 @@ def template_alpaca_messages(msgs: List[Dict[str, str]]) -> str:
         prompt += f"{msgs[0]['content']}\n"
         msgs.pop(0)
 
-    prompt += "### Instruction:\n"
     for msg in msgs:
+        prompt += "### Instruction:\n" if msg["role"] == "user" else "### Response:\n"
         prompt += f"{msg['content']}\n"
 
     prompt += "### Response:\n"
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 294c1713..d741fee4 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,10 +1,8 @@
 import json
-import ssl
 import traceback
 from typing import List
 
 import aiohttp
-import certifi
 
 from ...core.main import ChatMessage
 from ..llm import LLM
@@ -32,20 +30,7 @@ class ProxyServer(LLM):
         **kwargs,
     ):
         await super().start(**kwargs)
-        if self.verify_ssl is False:
-            self._client_session = aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(verify_ssl=False),
-                timeout=aiohttp.ClientTimeout(total=self.timeout),
-            )
-        else:
-            ca_bundle_path = (
-                certifi.where() if self.ca_bundle_path is None else self.ca_bundle_path
-            )
-            ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
-            self._client_session = aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(ssl_context=ssl_context),
-                timeout=aiohttp.ClientTimeout(total=self.timeout),
-            )
+        self._client_session = self.create_client_session()
 
         self.context_length = MAX_TOKENS_FOR_MODEL[self.model]
 
@@ -62,6 +47,7 @@ class ProxyServer(LLM):
             f"{SERVER_URL}/complete",
             json={"messages": [{"role": "user", "content": prompt}], **args},
             headers=self.get_headers(),
+            proxy=self.proxy,
         ) as resp:
             resp_text = await resp.text()
             if resp.status != 200:
@@ -75,6 +61,7 @@ class ProxyServer(LLM):
             f"{SERVER_URL}/stream_chat",
             json={"messages": messages, **args},
             headers=self.get_headers(),
+            proxy=self.proxy,
         ) as resp:
             if resp.status != 200:
                 raise Exception(await resp.text())
@@ -110,6 +97,7 @@ class ProxyServer(LLM):
             f"{SERVER_URL}/stream_complete",
             json={"messages": [{"role": "user", "content": prompt}], **args},
             headers=self.get_headers(),
+            proxy=self.proxy,
         ) as resp:
             if resp.status != 200:
                 raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index 257f9a8f..b679351c 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -1,5 +1,5 @@
 import json
-from typing import Callable, Optional
+from typing import Callable
 
 import aiohttp
 from pydantic import Field
@@ -68,6 +68,7 @@ class TogetherLLM(LLM):
                 **args,
             },
             headers={"Authorization": f"Bearer {self.api_key}"},
+            proxy=self.proxy,
         ) as resp:
             async for line in resp.content.iter_chunks():
                 if line[1]:
@@ -99,6 +100,7 @@ class TogetherLLM(LLM):
             f"{self.base_url}/inference",
             json={"prompt": prompt, **args},
             headers={"Authorization": f"Bearer {self.api_key}"},
+            proxy=self.proxy,
         ) as resp:
             text = await resp.text()
             j = json.loads(text)
diff --git a/continuedev/src/continuedev/libs/util/logging.py b/continuedev/src/continuedev/libs/util/logging.py
index 4a550168..b4799abb 100644
--- a/continuedev/src/continuedev/libs/util/logging.py
+++ b/continuedev/src/continuedev/libs/util/logging.py
@@ -1,13 +1,31 @@
 import logging
+import os
 
 from .paths import getLogFilePath
 
+logfile_path = getLogFilePath()
+
+try:
+    # Truncate the logs that are more than a day old
+    if os.path.exists(logfile_path):
+        tail = None
+        with open(logfile_path, "rb") as f:
+            f.seek(-32 * 1024, os.SEEK_END)
+            tail = f.read().decode("utf-8")
+
+        if tail is not None:
+            with open(logfile_path, "w") as f:
+                f.write(tail)
+
+except Exception as e:
+    print("Error truncating log file: {}".format(e))
+
 # Create a logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 # Create a file handler
-file_handler = logging.FileHandler(getLogFilePath())
+file_handler = logging.FileHandler(logfile_path)
 file_handler.setLevel(logging.DEBUG)
 
 # Create a console handler
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index e8bbd4ba..9d4eccd6 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 
 from ..constants.default_config import default_config
 from ..constants.main import (
@@ -73,6 +74,22 @@ def getSessionsListFilePath():
     return path
 
 
+def migrateConfigFile(existing: str) -> Optional[str]:
+    if existing.strip() == "":
+        return default_config
+
+    migrated = (
+        existing.replace("MaybeProxyOpenAI", "OpenAIFreeTrial")
+        .replace("maybe_proxy_openai", "openai_free_trial")
+        .replace("unused=", "saved=")
+        .replace("medium=", "summarize=")
+    )
+    if migrated != existing:
+        return migrated
+
+    return None
+
+
 def getConfigFilePath() -> str:
     path = os.path.join(getGlobalFolderPath(), "config.py")
     os.makedirs(os.path.dirname(path), exist_ok=True)
@@ -81,12 +98,15 @@ def getConfigFilePath() -> str:
         with open(path, "w") as f:
             f.write(default_config)
     else:
+        # Make any necessary migrations
        with open(path, "r") as f:
            existing_content = f.read()
 
-        if existing_content.strip() == "":
+        migrated = migrateConfigFile(existing_content)
+
+        if migrated is not None:
             with open(path, "w") as f:
-                f.write(default_config)
+                f.write(migrated)
 
     return path
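migrateConfigFile is a plain text rewrite of the user's existing config.py, so configs written against the old names keep working after an upgrade. Roughly (hypothetical input):

    old = 'default=MaybeProxyOpenAI(model="gpt-4"), medium=MaybeProxyOpenAI(model="gpt-3.5-turbo")'
    print(migrateConfigFile(old))
    # default=OpenAIFreeTrial(model="gpt-4"), summarize=OpenAIFreeTrial(model="gpt-3.5-turbo")
    # Returns None when nothing needs to change, so the file is left untouched.
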
open(f"docs/docs/reference/Models/{module_name}.md", "w") as f: + with open(f"docs/docs/reference/Models/{module_title.lower()}.md", "w") as f: f.write(markdown_docs) config_module = importlib.import_module("continuedev.src.continuedev.core.config") @@ -130,7 +130,9 @@ for module_name, module_title in CONTEXT_PROVIDER_MODULES: ], inherited=ctx_properties, ) - with open(f"docs/docs/reference/Context Providers/{module_name}.md", "w") as f: + with open( + f"docs/docs/reference/Context Providers/{module_title.lower()}.md", "w" + ) as f: f.write(markdown_docs) # sdk_module = importlib.import_module("continuedev.src.continuedev.core.sdk") diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py index 157cbc33..05da3547 100644 --- a/continuedev/src/continuedev/plugins/context_providers/diff.py +++ b/continuedev/src/continuedev/plugins/context_providers/diff.py @@ -4,7 +4,12 @@ from typing import List from pydantic import Field from ...core.context import ContextProvider -from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...core.main import ( + ContextItem, + ContextItemDescription, + ContextItemId, + ContinueCustomException, +) class DiffContextProvider(ContextProvider): @@ -44,9 +49,24 @@ class DiffContextProvider(ContextProvider): if not id.provider_title == self.title: raise Exception("Invalid provider title for item") - diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode( - "utf-8" + result = subprocess.run( + ["git", "diff"], cwd=self.workspace_dir, capture_output=True, text=True ) + diff = result.stdout + error = result.stderr + if error.strip() != "": + if error.startswith("warning: Not a git repository"): + raise ContinueCustomException( + title="Not a git repository", + message="The @diff context provider only works in git repositories.", + ) + raise ContinueCustomException( + title="Error running git diff", + message=f"Error running git diff:\n\n{error}", + ) + + if diff.strip() == "": + diff = "No changes" ctx_item = self.BASE_CONTEXT_ITEM.copy() ctx_item.content = diff diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 0610a8c3..df82b1ab 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel @@ -35,7 +35,8 @@ class HighlightedCodeContextProvider(ContextProvider): ide: Any # IdeProtocolServer highlighted_ranges: List[HighlightedRangeContextItem] = [] - adding_highlighted_code: bool = False + adding_highlighted_code: bool = True + # Controls whether you can have more than one highlighted range. Now always True. 
diff --git a/continuedev/src/continuedev/models/reference/test.py b/continuedev/src/continuedev/models/reference/test.py
index 1cebfc36..87f01ede 100644
--- a/continuedev/src/continuedev/models/reference/test.py
+++ b/continuedev/src/continuedev/models/reference/test.py
@@ -14,7 +14,7 @@ LLM_MODULES = [
     ("together", "TogetherLLM"),
     ("hf_inference_api", "HuggingFaceInferenceAPI"),
     ("hf_tgi", "HuggingFaceTGI"),
-    ("maybe_proxy_openai", "MaybeProxyOpenAI"),
+    ("openai_free_trial", "OpenAIFreeTrial"),
     ("queued", "QueuedLLM"),
 ]
 
@@ -101,7 +101,7 @@ for module_name, module_title in LLM_MODULES:
     markdown_docs = docs_from_schema(
         schema, f"libs/llm/{module_name}.py", inherited=ctx_properties
     )
-    with open(f"docs/docs/reference/Models/{module_name}.md", "w") as f:
+    with open(f"docs/docs/reference/Models/{module_title.lower()}.md", "w") as f:
         f.write(markdown_docs)
 
 config_module = importlib.import_module("continuedev.src.continuedev.core.config")
@@ -130,7 +130,9 @@ for module_name, module_title in CONTEXT_PROVIDER_MODULES:
         ],
         inherited=ctx_properties,
     )
-    with open(f"docs/docs/reference/Context Providers/{module_name}.md", "w") as f:
+    with open(
+        f"docs/docs/reference/Context Providers/{module_title.lower()}.md", "w"
+    ) as f:
         f.write(markdown_docs)
 
 # sdk_module = importlib.import_module("continuedev.src.continuedev.core.sdk")
diff --git a/continuedev/src/continuedev/plugins/context_providers/diff.py b/continuedev/src/continuedev/plugins/context_providers/diff.py
index 157cbc33..05da3547 100644
--- a/continuedev/src/continuedev/plugins/context_providers/diff.py
+++ b/continuedev/src/continuedev/plugins/context_providers/diff.py
@@ -4,7 +4,12 @@ from typing import List
 from pydantic import Field
 
 from ...core.context import ContextProvider
-from ...core.main import ContextItem, ContextItemDescription, ContextItemId
+from ...core.main import (
+    ContextItem,
+    ContextItemDescription,
+    ContextItemId,
+    ContinueCustomException,
+)
 
 
 class DiffContextProvider(ContextProvider):
@@ -44,9 +49,24 @@ class DiffContextProvider(ContextProvider):
         if not id.provider_title == self.title:
             raise Exception("Invalid provider title for item")
 
-        diff = subprocess.check_output(["git", "diff"], cwd=self.workspace_dir).decode(
-            "utf-8"
+        result = subprocess.run(
+            ["git", "diff"], cwd=self.workspace_dir, capture_output=True, text=True
         )
+        diff = result.stdout
+        error = result.stderr
+        if error.strip() != "":
+            if error.startswith("warning: Not a git repository"):
+                raise ContinueCustomException(
+                    title="Not a git repository",
+                    message="The @diff context provider only works in git repositories.",
+                )
+            raise ContinueCustomException(
+                title="Error running git diff",
+                message=f"Error running git diff:\n\n{error}",
+            )
+
+        if diff.strip() == "":
+            diff = "No changes"
 
         ctx_item = self.BASE_CONTEXT_ITEM.copy()
         ctx_item.content = diff
diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
index 0610a8c3..df82b1ab 100644
--- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
+++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel
 
@@ -35,7 +35,8 @@ class HighlightedCodeContextProvider(ContextProvider):
     ide: Any  # IdeProtocolServer
 
     highlighted_ranges: List[HighlightedRangeContextItem] = []
-    adding_highlighted_code: bool = False
+    adding_highlighted_code: bool = True
+    # Controls whether you can have more than one highlighted range. Now always True.
 
     should_get_fallback_context_item: bool = True
     last_added_fallback: bool = False
@@ -177,7 +178,9 @@ class HighlightedCodeContextProvider(ContextProvider):
         )
 
     async def handle_highlighted_code(
-        self, range_in_files: List[RangeInFileWithContents]
+        self,
+        range_in_files: List[RangeInFileWithContents],
+        edit: Optional[bool] = False,
     ):
         self.should_get_fallback_context_item = True
         self.last_added_fallback = False
@@ -208,47 +211,47 @@ class HighlightedCodeContextProvider(ContextProvider):
             self.highlighted_ranges = [
                 HighlightedRangeContextItem(
                     rif=range_in_files[0],
-                    item=self._rif_to_context_item(range_in_files[0], 0, True),
+                    item=self._rif_to_context_item(range_in_files[0], 0, edit),
                 )
             ]
             return
 
-        # If current range overlaps with any others, delete them and only keep the new range
-        new_ranges = []
-        for i, hr in enumerate(self.highlighted_ranges):
-            found_overlap = False
-            for new_rif in range_in_files:
-                if hr.rif.filepath == new_rif.filepath and hr.rif.range.overlaps_with(
-                    new_rif.range
-                ):
-                    found_overlap = True
-                    break
+        # If editing, make sure none of the other ranges are editing
+        if edit:
+            for hr in self.highlighted_ranges:
+                hr.item.editing = False
 
-                # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids
-                # the bug where cmd+f causes repeated highlights
+        # If new range overlaps with any existing, keep the existing but merged
+        new_ranges = []
+        for i, new_hr in enumerate(range_in_files):
+            found_overlap_with = None
+            for existing_rif in self.highlighted_ranges:
                 if (
-                    hr.rif.filepath == new_rif.filepath
-                    and hr.rif.contents == new_rif.contents
+                    new_hr.filepath == existing_rif.rif.filepath
+                    and new_hr.range.overlaps_with(existing_rif.rif.range)
                 ):
-                    found_overlap = True
+                    existing_rif.rif.range = existing_rif.rif.range.merge_with(
+                        new_hr.range
+                    )
+                    found_overlap_with = existing_rif
                     break
 
-            if not found_overlap:
+            if found_overlap_with is None:
                 new_ranges.append(
                     HighlightedRangeContextItem(
-                        rif=hr.rif,
-                        item=self._rif_to_context_item(hr.rif, len(new_ranges), False),
+                        rif=new_hr,
+                        item=self._rif_to_context_item(
+                            new_hr, len(self.highlighted_ranges) + i, edit
+                        ),
                     )
                 )
+            elif edit:
+                # Want to update the range so it's only the newly selected portion
+                found_overlap_with.rif.range = new_hr.range
+                found_overlap_with.item.editing = True
 
-        self.highlighted_ranges = new_ranges + [
-            HighlightedRangeContextItem(
-                rif=rif,
-                item=self._rif_to_context_item(rif, len(new_ranges) + idx, False),
-            )
-            for idx, rif in enumerate(range_in_files)
-        ]
+        self.highlighted_ranges = self.highlighted_ranges + new_ranges
 
         self._make_sure_is_editing_range()
         self._disambiguate_highlighted_ranges()
diff --git a/continuedev/src/continuedev/plugins/policies/default.py b/continuedev/src/continuedev/plugins/policies/default.py
index ea3541e3..574d2a1c 100644
--- a/continuedev/src/continuedev/plugins/policies/default.py
+++ b/continuedev/src/continuedev/plugins/policies/default.py
@@ -1,13 +1,9 @@
-import os
-from textwrap import dedent
 from typing import Type, Union
 
 from ...core.config import ContinueConfig
 from ...core.main import History, Policy, Step
 from ...core.observation import UserInputObservation
-from ...libs.util.paths import getServerFolderPath
 from ..steps.chat import SimpleChatStep
-from ..steps.core.core import MessageStep
 from ..steps.custom_command import CustomCommandStep
 from ..steps.main import EditHighlightedCodeStep
 from ..steps.steps_on_startup import StepsOnStartupStep
@@ -59,24 +55,7 @@
     def next(self, config: ContinueConfig, history: History) -> Step:
         # At the very start, run initial Steps specified in the config
         if history.get_current() is None:
-            shown_welcome_file = os.path.join(getServerFolderPath(), ".shown_welcome")
-            if os.path.exists(shown_welcome_file):
-                return StepsOnStartupStep()
-
-            with open(shown_welcome_file, "w") as f:
-                f.write("")
-            return (
-                MessageStep(
-                    name="Welcome to Continue",
-                    message=dedent(
-                        """\
-                    - Highlight code section and ask a question or use `/edit`
-                    - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
-                    - [Customize Continue](https://continue.dev/docs/customization) by typing '/config' (e.g. use your own API key) """
-                    ),
-                )
-                >> StepsOnStartupStep()
-            )
+            return StepsOnStartupStep()
 
         observation = history.get_current().observation
         if observation is not None and isinstance(observation, UserInputObservation):
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 43a2b800..9a5ca2bb 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -30,7 +30,7 @@ class SetupPipelineStep(Step):
         sdk.context.set("api_description", self.api_description)
 
         source_name = (
-            await sdk.models.medium.complete(
+            await sdk.models.summarize.complete(
                 f"Write a snake_case name for the data source described by {self.api_description}: "
             )
         ).strip()
@@ -115,7 +115,7 @@ class ValidatePipelineStep(Step):
         if "Traceback" in output or "SyntaxError" in output:
             output = "Traceback" + output.split("Traceback")[-1]
             file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
-            suggestion = await sdk.models.medium.complete(
+            suggestion = await sdk.models.summarize.complete(
                 dedent(
                     f"""\
                 ```python
@@ -131,7 +131,7 @@ class ValidatePipelineStep(Step):
                 )
             )
 
-            api_documentation_url = await sdk.models.medium.complete(
+            api_documentation_url = await sdk.models.summarize.complete(
                dedent(
                    f"""\
                        The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
@@ -216,7 +216,7 @@ class RunQueryStep(Step):
         )
 
         if "Traceback" in output or "SyntaxError" in output:
-            suggestion = await sdk.models.medium.complete(
+            suggestion = await sdk.models.summarize.complete(
                 dedent(
                     f"""\
                     ```python
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index e2712746..63edabc6 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -45,7 +45,7 @@ class WritePytestsRecipe(Step):
             Here is a complete set of pytest unit tests:"""
         )
 
-        tests = await sdk.models.medium.complete(prompt)
+        tests = await sdk.models.summarize.complete(prompt)
 
         await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md
index 3f2f804c..a8cae90b 100644
--- a/continuedev/src/continuedev/plugins/steps/README.md
+++ b/continuedev/src/continuedev/plugins/steps/README.md
@@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th
 
 - Return a static string
 - Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.summarize.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
 
 Here's an example:
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index b00bf85b..179882bb 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -11,8 +11,8 @@ from pydantic import Field
 
 from ...core.main import ChatMessage, FunctionCall, Models, Step, step_to_json_schema
 from ...core.sdk import ContinueSDK
-from ...libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
 from ...libs.llm.openai import OpenAI
+from ...libs.llm.openai_free_trial import OpenAIFreeTrial
 from ...libs.util.devdata import dev_data_logger
 from ...libs.util.strings import remove_quotes_and_escapes
 from ...libs.util.telemetry import posthog_logger
@@ -41,7 +41,7 @@ class SimpleChatStep(Step):
     async def run(self, sdk: ContinueSDK):
         # Check if proxy server API key
         if (
-            isinstance(sdk.models.default, MaybeProxyOpenAI)
+            isinstance(sdk.models.default, OpenAIFreeTrial)
             and (
                 sdk.models.default.api_key is None
                 or sdk.models.default.api_key.strip() == ""
@@ -70,8 +70,8 @@ class SimpleChatStep(Step):
                 config=ContinueConfig(
                     ...
                     models=Models(
-                        default=MaybeProxyOpenAI(api_key="<API_KEY>", model="gpt-4"),
-                        medium=MaybeProxyOpenAI(api_key="<API_KEY>", model="gpt-3.5-turbo")
+                        default=OpenAIFreeTrial(api_key="<API_KEY>", model="gpt-4"),
+                        summarize=OpenAIFreeTrial(api_key="<API_KEY>", model="gpt-3.5-turbo")
                     )
                 )
                 ```
@@ -129,9 +129,10 @@ class SimpleChatStep(Step):
         await sdk.update_ui()
         self.name = add_ellipsis(
             remove_quotes_and_escapes(
-                await sdk.models.medium.complete(
+                await sdk.models.summarize.complete(
                     f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
                     max_tokens=20,
+                    log=False,
                 )
             ),
             200,
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index 25633942..39b0741f 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -58,7 +58,7 @@ class AnswerQuestionChroma(Step):
             Here is the answer:"""
         )
 
-        answer = await sdk.models.medium.complete(prompt)
+        answer = await sdk.models.summarize.complete(prompt)
         # Make paths relative to the workspace directory
         answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 61de6578..ad2e88e2 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -1,16 +1,13 @@
 # These steps are depended upon by ContinueSDK
 import difflib
 import subprocess
-import traceback
 from textwrap import dedent
-from typing import Any, Coroutine, List, Optional, Union
-
-from pydantic import validator
+from typing import Coroutine, List, Optional, Union
 
 from ....core.main import ChatMessage, ContinueCustomException, Step
 from ....core.observation import Observation, TextObservation, UserInputObservation
 from ....libs.llm import LLM
-from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from ....libs.llm.openai_free_trial import OpenAIFreeTrial
 from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
 from ....libs.util.devdata import dev_data_logger
 from ....libs.util.strings import (
@@ -57,21 +54,25 @@ class MessageStep(Step):
 
 class DisplayErrorStep(Step):
     name: str = "Error in the Continue server"
-    e: Any
+
+    title: str = "Error in the Continue server"
+    message: str = "There was an error in the Continue server."
+
+    @staticmethod
+    def from_exception(e: Exception) -> "DisplayErrorStep":
+        if isinstance(e, ContinueCustomException):
+            return DisplayErrorStep(title=e.title, message=e.message, name=e.title)
+
+        return DisplayErrorStep(message=str(e))
 
     class Config:
         arbitrary_types_allowed = True
 
-    @validator("e", pre=True, always=True)
-    def validate_e(cls, v):
-        if isinstance(v, Exception):
-            return "\n".join(traceback.format_exception(v))
-
     async def describe(self, models: Models) -> Coroutine[str, None, None]:
-        return self.e
+        return self.message
 
     async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
-        raise ContinueCustomException(message=self.e, title=self.name)
+        raise ContinueCustomException(message=self.message, title=self.title)
 
 
 class FileSystemEditStep(ReversibleStep):
@@ -109,7 +110,7 @@ class ShellCommandsStep(Step):
             return f"Error when running shell commands:\n```\n{self._err_text}\n```"
 
         cmds_str = "\n".join(self.cmds)
-        return await models.medium.complete(
+        return await models.summarize.complete(
             f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:"
         )
 
@@ -180,10 +181,10 @@ class DefaultModelEditCodeStep(Step):
     _new_contents: str = ""
     _prompt_and_completion: str = ""
 
-    summary_prompt: str = "Please give brief a description of the changes made above using markdown bullet points. Be concise:"
+    summary_prompt: str = "Please briefly explain the changes made to the code above. Give no more than 2-3 sentences, and use markdown bullet points:"
 
     async def describe(self, models: Models) -> Coroutine[str, None, None]:
-        name = await models.medium.complete(
+        name = await models.summarize.complete(
             f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:"
         )
         self.name = remove_quotes_and_escapes(name)
@@ -231,7 +232,7 @@ class DefaultModelEditCodeStep(Step):
         # If using 3.5 and overflows, upgrade to 3.5.16k
         if model_to_use.model == "gpt-3.5-turbo":
             if total_tokens > model_to_use.context_length:
-                model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
+                model_to_use = OpenAIFreeTrial(model="gpt-3.5-turbo-0613")
                 await sdk.start_model(model_to_use)
 
         # Remove tokens from the end first, and then the start to clear space
@@ -829,7 +830,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user
         else:
             self.name = "Generating summary"
             self.description = ""
-            async for chunk in sdk.models.medium.stream_complete(
+            async for chunk in sdk.models.summarize.stream_complete(
                 dedent(
                     f"""\
                     Diff summary: "{self.user_input}"
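DisplayErrorStep now carries a plain title/message pair instead of a raw exception, and callers construct it through from_exception, as the autopilot and GUI changes elsewhere in this diff show. A sketch of the calling pattern:

    try:
        ...  # some server-side work
    except Exception as e:
        # ContinueCustomException keeps its own title and message;
        # any other exception falls back to str(e) with the generic title.
        await sdk.run_step(DisplayErrorStep.from_exception(e))
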
self._prompt = prompt @@ -180,7 +180,7 @@ class StarCoderEditHighlightedCodeStep(Step): _prompt_and_completion: str = "" async def describe(self, models: Models) -> Coroutine[str, None, None]: - return await models.medium.complete( + return await models.summarize.complete( f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:" ) diff --git a/continuedev/src/continuedev/plugins/steps/on_traceback.py b/continuedev/src/continuedev/plugins/steps/on_traceback.py index 3a96a8c7..86894818 100644 --- a/continuedev/src/continuedev/plugins/steps/on_traceback.py +++ b/continuedev/src/continuedev/plugins/steps/on_traceback.py @@ -2,7 +2,7 @@ import os from textwrap import dedent from typing import Dict, List, Optional, Tuple -from ...core.main import ChatMessage, Step +from ...core.main import ChatMessage, ContinueCustomException, Step from ...core.sdk import ContinueSDK from ...libs.util.filter_files import should_filter_path from ...libs.util.traceback.traceback_parsers import ( @@ -51,6 +51,12 @@ class DefaultOnTracebackStep(Step): # And this function is where you can get arbitrarily fancy about adding context async def run(self, sdk: ContinueSDK): + if self.output.strip() == "": + raise ContinueCustomException( + title="No terminal open", + message="You must have a terminal open in order to automatically debug with Continue.", + ) + if get_python_traceback(self.output) is not None and sdk.lsp is not None: await sdk.run_step(SolvePythonTracebackStep(output=self.output)) return diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index a2612731..1b9bc265 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -29,7 +29,7 @@ class NLDecisionStep(Step): Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""" ) - resp = (await sdk.models.medium.complete(prompt)).lower() + resp = (await sdk.models.summarize.complete(prompt)).lower() step_to_run = None for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 7ca8a2be..83516719 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -46,7 +46,7 @@ class WriteRegexPatternStep(Step): async def run(self, sdk: ContinueSDK): # Ask the user for a regex pattern - pattern = await sdk.models.medium.complete( + pattern = await sdk.models.summarize.complete( dedent( f"""\ This is the user request: diff --git a/continuedev/src/continuedev/plugins/steps/setup_model.py b/continuedev/src/continuedev/plugins/steps/setup_model.py index 7fa34907..83e616b0 100644 --- a/continuedev/src/continuedev/plugins/steps/setup_model.py +++ b/continuedev/src/continuedev/plugins/steps/setup_model.py @@ -6,14 +6,14 @@ from ...models.main import Range MODEL_CLASS_TO_MESSAGE = { "OpenAI": "Obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. 
Then reload the VS Code window for changes to take effect.", - "MaybeProxyOpenAI": "To get started with OpenAI models, obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.", + "OpenAIFreeTrial": "To get started with OpenAI models, obtain your OpenAI API key from [here](https://platform.openai.com/account/api-keys) and paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.", "AnthropicLLM": "To get started with Anthropic, you first need to sign up for the beta [here](https://claude.ai/login) to obtain an API key. Once you have the key, paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then reload the VS Code window for changes to take effect.", "ReplicateLLM": "To get started with Replicate, sign up to obtain an API key [here](https://replicate.ai/), then paste it into the `api_key` field at config.models.default.api_key in `config.py`.", "Ollama": "To get started with Ollama, download the Mac app from [ollama.ai](https://ollama.ai/). Once it is downloaded, be sure to pull at least one model and use its name in the model field in config.py (e.g. `model='codellama'`).", "GGML": "GGML models can be run locally using the `llama-cpp-python` library. To learn how to set up a local llama-cpp-python server, read [here](https://github.com/continuedev/ggml-server-example). Once it is started on port 8000, you're all set!", "TogetherLLM": "To get started using models from Together, first obtain your Together API key from [here](https://together.ai). Paste it into the `api_key` field at config.models.default.api_key in `config.py`. Then, on their models page, press 'start' on the model of your choice and make sure the `model=` parameter in the config file for the `TogetherLLM` class reflects the name of this model. Finally, reload the VS Code window for changes to take effect.", "LlamaCpp": "To get started with this model, clone the [`llama.cpp` repo](https://github.com/ggerganov/llama.cpp) and follow the instructions to set up the server [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#build). Any of the parameters described in the README can be passed to the `llama_cpp_args` field in the `LlamaCpp` class in `config.py`.", - "HuggingFaceInferenceAPI": "To get started with the HuggingFace Inference API, first deploy a model and obtain your API key from [here](https://huggingface.co/inference-api). Paste it into the `hf_token` field at config.models.default.hf_token in `config.py`. Finally, reload the VS Code window for changes to take effect." + "HuggingFaceInferenceAPI": "To get started with the HuggingFace Inference API, first deploy a model and obtain your API key from [here](https://huggingface.co/inference-api). Paste it into the `hf_token` field at config.models.default.hf_token in `config.py`. 
Finally, reload the VS Code window for changes to take effect.", } @@ -29,7 +29,7 @@ class SetupModelStep(Step): config_contents = await sdk.ide.readFile(getConfigFilePath()) start = config_contents.find("default=") + len("default=") - end = config_contents.find("unused=") - 1 + end = config_contents.find("saved=") - 1 range = Range.from_indices(config_contents, start, end) range.end.line -= 1 await sdk.ide.highlightCode( diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 770065ac..9d2ea47a 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -82,7 +82,9 @@ class GUIProtocolServer: return resp_model.parse_obj(resp) def on_error(self, e: Exception): - return self.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + return self.session.autopilot.continue_sdk.run_step( + DisplayErrorStep.from_exception(e) + ) def handle_json(self, message_type: str, data: Any): if message_type == "main_input": @@ -97,6 +99,8 @@ class GUIProtocolServer: self.on_retry_at_index(data["index"]) elif message_type == "clear_history": self.on_clear_history() + elif message_type == "set_current_session_title": + self.set_current_session_title(data["title"]) elif message_type == "delete_at_index": self.on_delete_at_index(data["index"]) elif message_type == "delete_context_with_ids": @@ -107,6 +111,8 @@ class GUIProtocolServer: self.on_set_editing_at_ids(data["ids"]) elif message_type == "show_logs_at_index": self.on_show_logs_at_index(data["index"]) + elif message_type == "show_context_virtual_file": + self.show_context_virtual_file() elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) elif message_type == "load_session": @@ -180,11 +186,9 @@ class GUIProtocolServer: create_async_task(self.session.autopilot.set_editing_at_ids(ids), self.on_error) def on_show_logs_at_index(self, index: int): - name = "continue_logs.txt" + name = "Continue Context" logs = "\n\n############################################\n\n".join( - [ - "This is a log of the prompt/completion pairs sent/received from the LLM during this step" - ] + ["This is the prompt sent to the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs ) create_async_task( @@ -192,6 +196,22 @@ class GUIProtocolServer: ) posthog_logger.capture_event("show_logs_at_index", {}) + def show_context_virtual_file(self): + async def async_stuff(): + msgs = await self.session.autopilot.continue_sdk.get_chat_context() + ctx = "\n\n-----------------------------------\n\n".join( + ["This is the exact context that will be passed to the LLM"] + + list(map(lambda x: x.content, msgs)) + ) + await self.session.autopilot.ide.showVirtualFile( + "Continue - Selected Context", ctx + ) + + create_async_task( + async_stuff(), + self.on_error, + ) + def select_context_item(self, id: str, query: str): """Called when user selects an item from the dropdown""" create_async_task( @@ -211,6 +231,9 @@ class GUIProtocolServer: posthog_logger.capture_event("load_session", {"session_id": session_id}) + def set_current_session_title(self, title: str): + self.session.autopilot.set_current_session_title(title) + def set_system_message(self, message: str): self.session.autopilot.continue_sdk.config.system_message = message self.session.autopilot.continue_sdk.models.set_system_message(message) @@ -239,14 +262,14 @@ class GUIProtocolServer: # Set models in SDK temp = models.default - models.default = 
@@ -239,14 +262,14 @@
         # Set models in SDK
         temp = models.default
-        models.default = models.unused[index]
-        models.unused[index] = temp
+        models.default = models.saved[index]
+        models.saved[index] = temp
         await self.session.autopilot.continue_sdk.start_model(models.default)
 
         # Set models in config.py
         JOINER = ",\n\t\t"
         models_args = {
-            "unused": f"[{JOINER.join([display_llm_class(llm) for llm in models.unused])}]",
+            "saved": f"[{JOINER.join([display_llm_class(llm) for llm in models.saved])}]",
             ("default" if role == "*" else role): display_llm_class(models.default),
         }
@@ -265,48 +288,59 @@
     def add_model_for_role(self, role: str, model_class: str, model: Any):
         models = self.session.autopilot.continue_sdk.config.models
-        unused_models = models.unused
         if role == "*":
 
             async def async_stuff():
-                for role in ALL_MODEL_ROLES:
-                    models.__setattr__(role, None)
-
-                # Set and start the default model if didn't already exist from unused
-                models.default = MODEL_CLASSES[model_class](**model)
-                await self.session.autopilot.continue_sdk.run_step(
-                    SetupModelStep(model_class=model_class)
+                # Remove all previous models in roles and place in saved
+                saved_models = models.saved
+                existing_saved_models = set(
+                    [display_llm_class(llm) for llm in saved_models]
                 )
-
-                await self.session.autopilot.continue_sdk.start_model(models.default)
-
-                models_args = {}
-
                 for role in ALL_MODEL_ROLES:
                     val = models.__getattribute__(role)
-                    if val is None:
-                        continue  # no pun intended
+                    if (
+                        val is not None
+                        and display_llm_class(val) not in existing_saved_models
+                    ):
+                        saved_models.append(val)
+                        existing_saved_models.add(display_llm_class(val))
+                    models.__setattr__(role, None)
 
-                    models_args[role] = display_llm_class(val, True)
+                # Set and start the new default model
+                new_model = MODEL_CLASSES[model_class](**model)
+                models.default = new_model
+                await self.session.autopilot.continue_sdk.start_model(models.default)
 
+                # Construct and set the new models object
                 JOINER = ",\n\t\t"
-                models_args[
-                    "unused"
-                ] = f"[{JOINER.join([display_llm_class(llm) for llm in unused_models])}]"
+                saved_model_strings = set(
+                    [display_llm_class(llm) for llm in saved_models]
+                )
+                models_args = {
+                    "default": display_llm_class(models.default, True),
+                    "saved": f"[{JOINER.join(saved_model_strings)}]",
+                }
 
                 await self.session.autopilot.set_config_attr(
                     ["models"],
                     create_obj_node("Models", models_args),
                 )
 
+                # Add the requisite import to config.py
                 add_config_import(
                     f"from continuedev.src.continuedev.libs.llm.{MODEL_MODULE_NAMES[model_class]} import {model_class}"
                 )
 
+                # Set all roles (in-memory) to the new default model
                 for role in ALL_MODEL_ROLES:
                     if role != "default":
                         models.__setattr__(role, models.default)
 
+                # Display setup help
+                await self.session.autopilot.continue_sdk.run_step(
+                    SetupModelStep(model_class=model_class)
+                )
+
             create_async_task(async_stuff(), self.on_error)
         else:
             # TODO
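The `add_model_for_role` rewrite above replaces the old `unused` list with a `saved` list: every model currently bound to a role is parked in `saved` (deduplicated by its serialized form) before the new default takes over all roles. A stand-alone sketch of that bookkeeping, with plain strings standing in for model objects and for `display_llm_class`, and with illustrative role names:

from typing import Dict, List, Optional

ROLES = ["default", "summarize", "edit", "chat"]  # illustrative role names


def set_new_default(roles: Dict[str, Optional[str]], saved: List[str], new_default: str) -> None:
    # Park every currently assigned model in `saved`, skipping duplicates.
    already_saved = set(saved)
    for role in ROLES:
        val = roles.get(role)
        if val is not None and val not in already_saved:
            saved.append(val)
            already_saved.add(val)
        roles[role] = None
    # Point every role at the new default, as the diff does in-memory.
    for role in ROLES:
        roles[role] = new_default


roles: Dict[str, Optional[str]] = {"default": "OpenAIFreeTrial(model='gpt-4')"}
saved: List[str] = []
set_new_default(roles, saved, "Ollama(model='codellama')")
print(roles["default"])  # the new default is used for every role
print(saved)             # the previous default is preserved in `saved`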
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 7396b1db..d4f0690b 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -4,7 +4,7 @@
 import json
 import os
 import traceback
 import uuid
-from typing import Any, Callable, Coroutine, List, Type, TypeVar, Union
+from typing import Any, Callable, Coroutine, List, Optional, Type, TypeVar, Union
 
 import nest_asyncio
 from fastapi import APIRouter, WebSocket
@@ -232,7 +232,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
             self.onFileEdits(fileEdits)
         elif message_type == "highlightedCodePush":
             self.onHighlightedCodeUpdate(
-                [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]]
+                [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]],
+                edit=data.get("edit", None),
             )
         elif message_type == "commandOutput":
             output = data["output"]
@@ -243,7 +244,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
         elif message_type == "acceptRejectSuggestion":
             self.onAcceptRejectSuggestion(data["accepted"])
         elif message_type == "acceptRejectDiff":
-            self.onAcceptRejectDiff(data["accepted"])
+            self.onAcceptRejectDiff(data["accepted"], data["stepIndex"])
         elif message_type == "mainUserInput":
             self.onMainUserInput(data["input"])
         elif message_type == "deleteAtIndex":
@@ -349,10 +350,17 @@
         posthog_logger.capture_event("accept_reject_suggestion", {"accepted": accepted})
         dev_data_logger.capture("accept_reject_suggestion", {"accepted": accepted})
 
-    def onAcceptRejectDiff(self, accepted: bool):
+    def onAcceptRejectDiff(self, accepted: bool, step_index: int):
         posthog_logger.capture_event("accept_reject_diff", {"accepted": accepted})
         dev_data_logger.capture("accept_reject_diff", {"accepted": accepted})
 
+        if not accepted:
+            if autopilot := self.__get_autopilot():
+                create_async_task(
+                    autopilot.reject_diff(step_index),
+                    self.on_error,
+                )
+
     def onFileSystemUpdate(self, update: FileSystemEdit):
         # Access to Autopilot (so SessionManager)
         pass
@@ -387,10 +395,14 @@
         if autopilot := self.__get_autopilot():
             create_async_task(autopilot.handle_debug_terminal(content), self.on_error)
 
-    def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
+    def onHighlightedCodeUpdate(
+        self,
+        range_in_files: List[RangeInFileWithContents],
+        edit: Optional[bool] = False,
+    ):
         if autopilot := self.__get_autopilot():
             create_async_task(
-                autopilot.handle_highlighted_code(range_in_files), self.on_error
+                autopilot.handle_highlighted_code(range_in_files, edit), self.on_error
             )
 
     ## Subscriptions ##
@@ -456,7 +468,7 @@
         resp = await self._send_and_receive_json(
             {"commands": commands}, TerminalContentsResponse, "getTerminalContents"
         )
-        return resp.contents
+        return resp.contents.strip()
 
     async def getHighlightedCode(self) -> List[RangeInFile]:
         resp = await self._send_and_receive_json(
@@ -640,7 +652,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
 
     if session_id is not None and session_id in session_manager.sessions:
         await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(
-            DisplayErrorStep(e=e)
+            DisplayErrorStep.from_exception(e)
        )
     elif ideProtocolServer is not None:
         await ideProtocolServer.showMessage(f"Error in Continue server: {err_msg}")
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index 34030047..015da767 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, List, Union
+from typing import Any, Callable, List, Optional, Union
 
 from fastapi import WebSocket
 
@@ -104,7 +104,11 @@ class AbstractIdeProtocolServer(ABC):
         """Run a command"""
 
     @abstractmethod
-    def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
+    def onHighlightedCodeUpdate(
+        self,
+        range_in_files: List[RangeInFileWithContents],
+        edit: Optional[bool] = False,
+    ):
         """Called when highlighted code is updated"""
 
     @abstractmethod
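The ide.py and ide_protocol.py changes above thread an optional `edit` flag from the `highlightedCodePush` message through `onHighlightedCodeUpdate` into `handle_highlighted_code`. A simplified sketch of that path follows; the class is a stand-in for the real protocol server, and the interpretation of `edit` (selection intended for an edit request rather than plain chat context) is an assumption based on the surrounding changes:

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RangeInFileWithContents:
    filepath: str
    contents: str


class HighlightSketch:
    async def handle_highlighted_code(
        self, rifs: List[RangeInFileWithContents], edit: Optional[bool] = False
    ) -> None:
        # Assumption: edit=True marks the selection for an edit-style request.
        mode = "edit" if edit else "context"
        print(f"{len(rifs)} highlighted range(s) received in {mode} mode")

    async def on_highlighted_code_push(self, data: dict) -> None:
        # Mirrors the message handling added in ide.py: the flag is optional.
        await self.handle_highlighted_code(
            [RangeInFileWithContents(**rif) for rif in data["highlightedCode"]],
            edit=data.get("edit", None),
        )


asyncio.run(
    HighlightSketch().on_highlighted_code_push(
        {"highlightedCode": [{"filepath": "main.py", "contents": "print('hi')"}], "edit": True}
    )
)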
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 40d46b18..5e6cdd53 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -78,7 +78,7 @@ async def ensure_meilisearch_installed() -> bool:
             pass
 
         existing_paths.remove(meilisearchPath)
-        await download_meilisearch()
+        await download_meilisearch()
 
     # Clear the existing directories
     for p in existing_paths:
diff --git a/continuedev/src/continuedev/tests/util/config.py b/continuedev/src/continuedev/tests/util/config.py
index 73d3aeff..dd0e1f13 100644
--- a/continuedev/src/continuedev/tests/util/config.py
+++ b/continuedev/src/continuedev/tests/util/config.py
@@ -1,12 +1,12 @@
 from continuedev.src.continuedev.core.config import ContinueConfig
 from continuedev.src.continuedev.core.models import Models
-from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from continuedev.src.continuedev.libs.llm.openai_free_trial import OpenAIFreeTrial
 
 config = ContinueConfig(
     allow_anonymous_telemetry=False,
     models=Models(
-        default=MaybeProxyOpenAI(api_key="", model="gpt-4"),
-        medium=MaybeProxyOpenAI(
+        default=OpenAIFreeTrial(api_key="", model="gpt-4"),
+        summarize=OpenAIFreeTrial(
             api_key="",
             model="gpt-3.5-turbo",
         ),
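Taken together with the gui.py changes earlier in this diff, the test config above reflects three renames: `MaybeProxyOpenAI` becomes `OpenAIFreeTrial`, the `medium` role becomes `summarize`, and the `unused` model list becomes `saved`. A sketch of how an existing user `config.py` written against the old names would presumably be updated; the `saved` keyword on `Models` is inferred from the `models_args` handling above rather than confirmed here:

from continuedev.src.continuedev.core.models import Models
# old import: from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from continuedev.src.continuedev.libs.llm.openai_free_trial import OpenAIFreeTrial

models = Models(
    default=OpenAIFreeTrial(api_key="", model="gpt-4"),
    # was: medium=MaybeProxyOpenAI(api_key="", model="gpt-3.5-turbo"),
    summarize=OpenAIFreeTrial(api_key="", model="gpt-3.5-turbo"),
    # was: unused=[...]; assumption based on the "saved" field used elsewhere in this diff
    saved=[],
)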