From f09150617ed2454f3074bcf93f53aae5ae637d40 Mon Sep 17 00:00:00 2001 From: Nate Sesti <33237525+sestinj@users.noreply.github.com> Date: Mon, 9 Oct 2023 18:37:27 -0700 Subject: Preview (#541) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Strong typing (#533) * refactor: :recycle: get rid of continuedev.src.continuedev structure * refactor: :recycle: switching back to server folder * feat: :sparkles: make config.py imports shorter * feat: :bookmark: publish as pre-release vscode extension * refactor: :recycle: refactor and add more completion params to ui * build: :building_construction: download from preview S3 * fix: :bug: fix paths * fix: :green_heart: package:pre-release * ci: :green_heart: more time for tests * fix: :green_heart: fix build scripts * fix: :bug: fix import in run.py * fix: :bookmark: update version to try again * ci: ๐Ÿ’š Update package.json version [skip ci] * refactor: :fire: don't check for old extensions version * fix: :bug: small bug fixes * fix: :bug: fix config.py import paths * ci: ๐Ÿ’š Update package.json version [skip ci] * ci: :green_heart: platform-specific builds test #1 * feat: :green_heart: ship with binary * fix: :green_heart: fix copy statement to include.exe for windows * fix: :green_heart: cd extension before packaging * chore: :loud_sound: count tokens generated * fix: :green_heart: remove npm_config_arch * fix: :green_heart: publish as pre-release! * chore: :bookmark: update version * perf: :green_heart: hardcode distro paths * fix: :bug: fix yaml syntax error * chore: :bookmark: update version * fix: :green_heart: update permissions and version * feat: :bug: kill old server if needed * feat: :lipstick: update marketplace icon for pre-release * ci: ๐Ÿ’š Update package.json version [skip ci] * feat: :sparkles: auto-reload for config.py * feat: :wrench: update default config.py imports * feat: :sparkles: codelens in config.py * feat: :sparkles: select model param count from UI * ci: ๐Ÿ’š Update package.json version [skip ci] * feat: :sparkles: more model options, ollama error handling * perf: :zap: don't show server loading immediately * fix: :bug: fixing small UI details * ci: ๐Ÿ’š Update package.json version [skip ci] * feat: :rocket: headers param on LLM class * fix: :bug: fix headers for openai.;y * feat: :sparkles: highlight code on cmd+shift+L * ci: ๐Ÿ’š Update package.json version [skip ci] * feat: :lipstick: sticky top bar in gui.tsx * fix: :loud_sound: websocket logging and horizontal scrollbar * ci: ๐Ÿ’š Update package.json version [skip ci] * feat: :sparkles: allow AzureOpenAI Service through GGML * ci: ๐Ÿ’š Update package.json version [skip ci] * fix: :bug: fix automigration * ci: ๐Ÿ’š Update package.json version [skip ci] * ci: :green_heart: upload binaries in ci, download apple silicon * chore: :fire: remove notes * fix: :green_heart: use curl to download binary * fix: :green_heart: set permissions on apple silicon binary * fix: :green_heart: testing * fix: :green_heart: cleanup file * fix: :green_heart: fix preview.yaml * fix: :green_heart: only upload once per binary * fix: :green_heart: install rosetta * ci: :green_heart: download binary after tests * ci: ๐Ÿ’š Update package.json version [skip ci] * ci: :green_heart: prepare ci for merge to main --------- Co-authored-by: GitHub Action --- server/README.md | 88 + server/continuedev/__init__.py | 19 + server/continuedev/__main__.py | 30 + server/continuedev/core/abstract_sdk.py | 82 + server/continuedev/core/autopilot.py | 746 ++++++ server/continuedev/core/config.py | 114 + server/continuedev/core/context.py | 516 +++++ server/continuedev/core/env.py | 31 + server/continuedev/core/lsp.py | 416 ++++ server/continuedev/core/main.py | 437 ++++ server/continuedev/core/models.py | 113 + server/continuedev/core/observation.py | 41 + server/continuedev/core/sdk.py | 309 +++ server/continuedev/core/steps.py | 963 ++++++++ server/continuedev/headless/__init__.py | 20 + server/continuedev/headless/headless_ide.py | 181 ++ server/continuedev/libs/__init__.py | 0 server/continuedev/libs/chroma/.gitignore | 1 + server/continuedev/libs/chroma/query.py | 218 ++ server/continuedev/libs/chroma/update.py | 66 + .../continuedev/libs/constants/default_config.py | 88 + server/continuedev/libs/constants/main.py | 6 + server/continuedev/libs/llm/__init__.py | 14 + server/continuedev/libs/llm/anthropic.py | 74 + server/continuedev/libs/llm/base.py | 458 ++++ server/continuedev/libs/llm/ggml.py | 226 ++ server/continuedev/libs/llm/google_palm_api.py | 50 + server/continuedev/libs/llm/hf_inference_api.py | 78 + server/continuedev/libs/llm/hf_tgi.py | 65 + server/continuedev/libs/llm/hugging_face.py | 19 + server/continuedev/libs/llm/llamacpp.py | 86 + server/continuedev/libs/llm/ollama.py | 106 + server/continuedev/libs/llm/openai.py | 156 ++ server/continuedev/libs/llm/openai_free_trial.py | 83 + server/continuedev/libs/llm/prompt_utils.py | 76 + server/continuedev/libs/llm/prompts/chat.py | 174 ++ server/continuedev/libs/llm/prompts/edit.py | 27 + server/continuedev/libs/llm/proxy_server.py | 108 + server/continuedev/libs/llm/queued.py | 77 + server/continuedev/libs/llm/replicate.py | 78 + server/continuedev/libs/llm/text_gen_interface.py | 114 + server/continuedev/libs/llm/together.py | 125 + server/continuedev/libs/util/calculate_diff.py | 154 ++ server/continuedev/libs/util/commonregex.py | 144 ++ server/continuedev/libs/util/copy_codebase.py | 121 + server/continuedev/libs/util/count_tokens.py | 206 ++ server/continuedev/libs/util/create_async_task.py | 38 + server/continuedev/libs/util/devdata.py | 67 + server/continuedev/libs/util/edit_config.py | 149 ++ server/continuedev/libs/util/errors.py | 2 + server/continuedev/libs/util/filter_files.py | 33 + server/continuedev/libs/util/logging.py | 47 + server/continuedev/libs/util/map_path.py | 16 + server/continuedev/libs/util/paths.py | 148 ++ server/continuedev/libs/util/queue.py | 17 + server/continuedev/libs/util/ripgrep.py | 25 + server/continuedev/libs/util/step_name_to_steps.py | 47 + server/continuedev/libs/util/strings.py | 64 + server/continuedev/libs/util/telemetry.py | 108 + server/continuedev/libs/util/templating.py | 76 + .../libs/util/traceback/traceback_parsers.py | 56 + server/continuedev/models/__init__.py | 0 server/continuedev/models/filesystem.py | 398 ++++ server/continuedev/models/filesystem_edit.py | 164 ++ server/continuedev/models/generate_json_schema.py | 54 + server/continuedev/models/main.py | 229 ++ server/continuedev/models/reference/generate.py | 144 ++ .../plugins/context_providers/__init__.py | 7 + .../continuedev/plugins/context_providers/diff.py | 73 + .../plugins/context_providers/dynamic.py | 75 + .../plugins/context_providers/embeddings.py | 81 + .../continuedev/plugins/context_providers/file.py | 136 ++ .../plugins/context_providers/filetree.py | 89 + .../plugins/context_providers/github.py | 49 + .../plugins/context_providers/google.py | 70 + .../plugins/context_providers/highlighted_code.py | 293 +++ .../plugins/context_providers/search.py | 90 + .../plugins/context_providers/terminal.py | 49 + .../continuedev/plugins/context_providers/url.py | 104 + .../continuedev/plugins/context_providers/util.py | 5 + server/continuedev/plugins/policies/commit.py | 77 + server/continuedev/plugins/policies/default.py | 85 + server/continuedev/plugins/policies/headless.py | 18 + .../plugins/recipes/AddTransformRecipe/README.md | 9 + .../AddTransformRecipe/dlt_transform_docs.md | 142 ++ .../plugins/recipes/AddTransformRecipe/main.py | 31 + .../plugins/recipes/AddTransformRecipe/steps.py | 106 + .../plugins/recipes/ContinueRecipeRecipe/README.md | 7 + .../plugins/recipes/ContinueRecipeRecipe/main.py | 43 + .../plugins/recipes/CreatePipelineRecipe/README.md | 0 .../plugins/recipes/CreatePipelineRecipe/main.py | 40 + .../plugins/recipes/CreatePipelineRecipe/steps.py | 243 ++ .../plugins/recipes/DDtoBQRecipe/README.md | 3 + .../DDtoBQRecipe/dlt_duckdb_to_bigquery_docs.md | 85 + .../plugins/recipes/DDtoBQRecipe/main.py | 31 + .../plugins/recipes/DDtoBQRecipe/steps.py | 119 + .../recipes/DeployPipelineAirflowRecipe/README.md | 0 .../recipes/DeployPipelineAirflowRecipe/main.py | 86 + .../recipes/DeployPipelineAirflowRecipe/steps.py | 125 + server/continuedev/plugins/recipes/README.md | 19 + .../plugins/recipes/TemplateRecipe/README.md | 7 + .../plugins/recipes/TemplateRecipe/main.py | 29 + .../plugins/recipes/WritePytestsRecipe/README.md | 7 + .../plugins/recipes/WritePytestsRecipe/main.py | 52 + server/continuedev/plugins/steps/README.md | 50 + server/continuedev/plugins/steps/__init__.py | 13 + server/continuedev/plugins/steps/chat.py | 379 +++ server/continuedev/plugins/steps/chroma.py | 86 + server/continuedev/plugins/steps/clear_history.py | 10 + server/continuedev/plugins/steps/cmd.py | 30 + server/continuedev/plugins/steps/comment_code.py | 16 + server/continuedev/plugins/steps/custom_command.py | 29 + .../plugins/steps/draft/abstract_method.py | 21 + server/continuedev/plugins/steps/draft/redux.py | 50 + server/continuedev/plugins/steps/draft/typeorm.py | 54 + server/continuedev/plugins/steps/feedback.py | 14 + .../continuedev/plugins/steps/find_and_replace.py | 30 + server/continuedev/plugins/steps/help.py | 70 + .../plugins/steps/input/nl_multiselect.py | 32 + server/continuedev/plugins/steps/main.py | 422 ++++ server/continuedev/plugins/steps/on_traceback.py | 206 ++ server/continuedev/plugins/steps/open_config.py | 17 + server/continuedev/plugins/steps/react.py | 44 + server/continuedev/plugins/steps/refactor.py | 136 ++ .../continuedev/plugins/steps/search_directory.py | 84 + server/continuedev/plugins/steps/setup_model.py | 38 + server/continuedev/plugins/steps/share_session.py | 52 + .../continuedev/plugins/steps/steps_on_startup.py | 19 + server/continuedev/plugins/steps/welcome.py | 40 + server/continuedev/server/gui.py | 459 ++++ server/continuedev/server/ide.py | 680 ++++++ server/continuedev/server/ide_protocol.py | 170 ++ server/continuedev/server/main.py | 109 + server/continuedev/server/meilisearch_server.py | 196 ++ server/continuedev/server/session_manager.py | 192 ++ server/dev_requirements.txt | 2 + server/install-dependencies.sh | 16 + server/main.py | 5 + server/notes.md | 101 + server/poetry.lock | 2414 ++++++++++++++++++++ server/poetry.toml | 2 + server/pyproject.toml | 47 + server/requirements.txt | 27 + server/tests/__init__.py | 0 server/tests/llm_test.py | 179 ++ server/tests/step_test.py | 68 + server/tests/util/__init__.py | 0 server/tests/util/config.py | 19 + server/tests/util/openai_mock.py | 139 ++ server/tests/util/prompts.py | 2 + 150 files changed, 18440 insertions(+) create mode 100644 server/README.md create mode 100644 server/continuedev/__init__.py create mode 100644 server/continuedev/__main__.py create mode 100644 server/continuedev/core/abstract_sdk.py create mode 100644 server/continuedev/core/autopilot.py create mode 100644 server/continuedev/core/config.py create mode 100644 server/continuedev/core/context.py create mode 100644 server/continuedev/core/env.py create mode 100644 server/continuedev/core/lsp.py create mode 100644 server/continuedev/core/main.py create mode 100644 server/continuedev/core/models.py create mode 100644 server/continuedev/core/observation.py create mode 100644 server/continuedev/core/sdk.py create mode 100644 server/continuedev/core/steps.py create mode 100644 server/continuedev/headless/__init__.py create mode 100644 server/continuedev/headless/headless_ide.py create mode 100644 server/continuedev/libs/__init__.py create mode 100644 server/continuedev/libs/chroma/.gitignore create mode 100644 server/continuedev/libs/chroma/query.py create mode 100644 server/continuedev/libs/chroma/update.py create mode 100644 server/continuedev/libs/constants/default_config.py create mode 100644 server/continuedev/libs/constants/main.py create mode 100644 server/continuedev/libs/llm/__init__.py create mode 100644 server/continuedev/libs/llm/anthropic.py create mode 100644 server/continuedev/libs/llm/base.py create mode 100644 server/continuedev/libs/llm/ggml.py create mode 100644 server/continuedev/libs/llm/google_palm_api.py create mode 100644 server/continuedev/libs/llm/hf_inference_api.py create mode 100644 server/continuedev/libs/llm/hf_tgi.py create mode 100644 server/continuedev/libs/llm/hugging_face.py create mode 100644 server/continuedev/libs/llm/llamacpp.py create mode 100644 server/continuedev/libs/llm/ollama.py create mode 100644 server/continuedev/libs/llm/openai.py create mode 100644 server/continuedev/libs/llm/openai_free_trial.py create mode 100644 server/continuedev/libs/llm/prompt_utils.py create mode 100644 server/continuedev/libs/llm/prompts/chat.py create mode 100644 server/continuedev/libs/llm/prompts/edit.py create mode 100644 server/continuedev/libs/llm/proxy_server.py create mode 100644 server/continuedev/libs/llm/queued.py create mode 100644 server/continuedev/libs/llm/replicate.py create mode 100644 server/continuedev/libs/llm/text_gen_interface.py create mode 100644 server/continuedev/libs/llm/together.py create mode 100644 server/continuedev/libs/util/calculate_diff.py create mode 100644 server/continuedev/libs/util/commonregex.py create mode 100644 server/continuedev/libs/util/copy_codebase.py create mode 100644 server/continuedev/libs/util/count_tokens.py create mode 100644 server/continuedev/libs/util/create_async_task.py create mode 100644 server/continuedev/libs/util/devdata.py create mode 100644 server/continuedev/libs/util/edit_config.py create mode 100644 server/continuedev/libs/util/errors.py create mode 100644 server/continuedev/libs/util/filter_files.py create mode 100644 server/continuedev/libs/util/logging.py create mode 100644 server/continuedev/libs/util/map_path.py create mode 100644 server/continuedev/libs/util/paths.py create mode 100644 server/continuedev/libs/util/queue.py create mode 100644 server/continuedev/libs/util/ripgrep.py create mode 100644 server/continuedev/libs/util/step_name_to_steps.py create mode 100644 server/continuedev/libs/util/strings.py create mode 100644 server/continuedev/libs/util/telemetry.py create mode 100644 server/continuedev/libs/util/templating.py create mode 100644 server/continuedev/libs/util/traceback/traceback_parsers.py create mode 100644 server/continuedev/models/__init__.py create mode 100644 server/continuedev/models/filesystem.py create mode 100644 server/continuedev/models/filesystem_edit.py create mode 100644 server/continuedev/models/generate_json_schema.py create mode 100644 server/continuedev/models/main.py create mode 100644 server/continuedev/models/reference/generate.py create mode 100644 server/continuedev/plugins/context_providers/__init__.py create mode 100644 server/continuedev/plugins/context_providers/diff.py create mode 100644 server/continuedev/plugins/context_providers/dynamic.py create mode 100644 server/continuedev/plugins/context_providers/embeddings.py create mode 100644 server/continuedev/plugins/context_providers/file.py create mode 100644 server/continuedev/plugins/context_providers/filetree.py create mode 100644 server/continuedev/plugins/context_providers/github.py create mode 100644 server/continuedev/plugins/context_providers/google.py create mode 100644 server/continuedev/plugins/context_providers/highlighted_code.py create mode 100644 server/continuedev/plugins/context_providers/search.py create mode 100644 server/continuedev/plugins/context_providers/terminal.py create mode 100644 server/continuedev/plugins/context_providers/url.py create mode 100644 server/continuedev/plugins/context_providers/util.py create mode 100644 server/continuedev/plugins/policies/commit.py create mode 100644 server/continuedev/plugins/policies/default.py create mode 100644 server/continuedev/plugins/policies/headless.py create mode 100644 server/continuedev/plugins/recipes/AddTransformRecipe/README.md create mode 100644 server/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md create mode 100644 server/continuedev/plugins/recipes/AddTransformRecipe/main.py create mode 100644 server/continuedev/plugins/recipes/AddTransformRecipe/steps.py create mode 100644 server/continuedev/plugins/recipes/ContinueRecipeRecipe/README.md create mode 100644 server/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py create mode 100644 server/continuedev/plugins/recipes/CreatePipelineRecipe/README.md create mode 100644 server/continuedev/plugins/recipes/CreatePipelineRecipe/main.py create mode 100644 server/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py create mode 100644 server/continuedev/plugins/recipes/DDtoBQRecipe/README.md create mode 100644 server/continuedev/plugins/recipes/DDtoBQRecipe/dlt_duckdb_to_bigquery_docs.md create mode 100644 server/continuedev/plugins/recipes/DDtoBQRecipe/main.py create mode 100644 server/continuedev/plugins/recipes/DDtoBQRecipe/steps.py create mode 100644 server/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/README.md create mode 100644 server/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py create mode 100644 server/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py create mode 100644 server/continuedev/plugins/recipes/README.md create mode 100644 server/continuedev/plugins/recipes/TemplateRecipe/README.md create mode 100644 server/continuedev/plugins/recipes/TemplateRecipe/main.py create mode 100644 server/continuedev/plugins/recipes/WritePytestsRecipe/README.md create mode 100644 server/continuedev/plugins/recipes/WritePytestsRecipe/main.py create mode 100644 server/continuedev/plugins/steps/README.md create mode 100644 server/continuedev/plugins/steps/__init__.py create mode 100644 server/continuedev/plugins/steps/chat.py create mode 100644 server/continuedev/plugins/steps/chroma.py create mode 100644 server/continuedev/plugins/steps/clear_history.py create mode 100644 server/continuedev/plugins/steps/cmd.py create mode 100644 server/continuedev/plugins/steps/comment_code.py create mode 100644 server/continuedev/plugins/steps/custom_command.py create mode 100644 server/continuedev/plugins/steps/draft/abstract_method.py create mode 100644 server/continuedev/plugins/steps/draft/redux.py create mode 100644 server/continuedev/plugins/steps/draft/typeorm.py create mode 100644 server/continuedev/plugins/steps/feedback.py create mode 100644 server/continuedev/plugins/steps/find_and_replace.py create mode 100644 server/continuedev/plugins/steps/help.py create mode 100644 server/continuedev/plugins/steps/input/nl_multiselect.py create mode 100644 server/continuedev/plugins/steps/main.py create mode 100644 server/continuedev/plugins/steps/on_traceback.py create mode 100644 server/continuedev/plugins/steps/open_config.py create mode 100644 server/continuedev/plugins/steps/react.py create mode 100644 server/continuedev/plugins/steps/refactor.py create mode 100644 server/continuedev/plugins/steps/search_directory.py create mode 100644 server/continuedev/plugins/steps/setup_model.py create mode 100644 server/continuedev/plugins/steps/share_session.py create mode 100644 server/continuedev/plugins/steps/steps_on_startup.py create mode 100644 server/continuedev/plugins/steps/welcome.py create mode 100644 server/continuedev/server/gui.py create mode 100644 server/continuedev/server/ide.py create mode 100644 server/continuedev/server/ide_protocol.py create mode 100644 server/continuedev/server/main.py create mode 100644 server/continuedev/server/meilisearch_server.py create mode 100644 server/continuedev/server/session_manager.py create mode 100644 server/dev_requirements.txt create mode 100755 server/install-dependencies.sh create mode 100644 server/main.py create mode 100644 server/notes.md create mode 100644 server/poetry.lock create mode 100644 server/poetry.toml create mode 100644 server/pyproject.toml create mode 100644 server/requirements.txt create mode 100644 server/tests/__init__.py create mode 100644 server/tests/llm_test.py create mode 100644 server/tests/step_test.py create mode 100644 server/tests/util/__init__.py create mode 100644 server/tests/util/config.py create mode 100644 server/tests/util/openai_mock.py create mode 100644 server/tests/util/prompts.py (limited to 'server') diff --git a/server/README.md b/server/README.md new file mode 100644 index 00000000..25fb640e --- /dev/null +++ b/server/README.md @@ -0,0 +1,88 @@ +# Continue PyPI Package + +This package contains the [Continue](https://github.com/continuedev/continue) server and core classes needed to build your own recipes. + +Continue is a Python library for automating repetitive sequences of software development tasks using language models. Using our VS Code extension, you can build, run, and refine these recipes as they natively interact with your codebase. Read the docs [here](https://continue.dev/docs) or download the VS Code extension [here](https://marketplace.visualstudio.com/items?itemName=Continue.continue). + +## Continue Server + +The Continue server acts as a bridge between the Continue React app and your IDE, running your recipes and acting on the codebase. + +Start it by running the following commands: + +1. `cd server` +2. Make sure packages are installed with `poetry install` + - If poetry is not installed, you can install with + ```bash + curl -sSL https://install.python-poetry.org | python3 - + ``` + (official instructions [here](https://python-poetry.org/docs/#installing-with-the-official-installer)) +3. `poetry shell` to activate the virtual environment +4. `python3 -m continuedev.server.main` to start the server + +Once you've validated that this works, you'll often want to use a debugger, in which case we've provided a launch configuration for VS Code in `.vscode/launch.json`. To start the debugger in VS Code, ensure that the workspace directory is the root of the `continue` repo, then press F5. + +> [!NOTE] +> To start the debugger, you'll have to select the poetry Python interpreter +> (`/path-to-poetry-venv/bin/python3`) in the bottom right of the VS Code window. If you +> don't see this, you may have to install the [Python +> extension](https://marketplace.visualstudio.com/items?itemName=ms-python.python). + +## Scripts + +`poetry run typegen` to generate JSONSchema .json files from the Pydantic types defined in the `models` directory. + +`poetry build` will output wheel and tarball files in `./dist`. + +## Writing Steps + +See the `continuedev/libs/steps` folder for examples of writing a Continue step. See our documentation for tutorials. + +## How to contribute + +Open a [new GitHub Issue](https://github.com/continuedev/continue/issues/new) or comment on [an existing one](https://github.com/continuedev/continue/issues). Let us know what you would like to contribute, and we will help you make it happen! + +For more a more detailed contributing guide, see [CONTRIBUTING.md](../CONTRIBUTING.md). + +## Install from source + +#### 1. Clone this repo + +Recommended: Run this command to use SSH + +```bash +git clone git@github.com:continuedev/continue.git +``` + +Alternative: Run this command to use HTTPS + +```bash +git clone https://github.com/continuedev/continue +``` + +#### 2. Install Continue + +Run this command to use the install script + +```bash +cd continue/extension/scripts && python3 install_from_source.py +``` + +> [!IMPORTANT] +> Ensure you have a Java Runtime Environment (JRE) installed. Verify this by typing `java +-version` in your command prompt or terminal. If a version number appears, you're set. +> If not, download and install a JRE from Oracle's website or through a package manager, +> for example Homebrew. +> +> ```sh +> brew install openjdk@11 +> ``` + +# Understanding the codebase + +- [Continue Server README](./README.md): learn about the core of Continue, which can be downloaded as a [PyPI package](https://pypi.org/project/continuedev/) +- [VS Code Extension README](../extension/README.md): learn about the capabilities of our extensionโ€”the first implementation of Continue's IDE Protocolโ€”which makes it possible to use use Continue in VS Code and GitHub Codespaces +- [Continue GUI README](../extension/react-app/): learn about the React app that lets users interact with the server and is placed adjacent to the text editor in any supported IDE +- [Schema README](../schema/README.md): learn about the JSON Schema types generated from Pydantic models, which we use across the `server/` and `extension/` directories +- [Continue Docs README](../docs/README.md): learn how our [docs](https://continue.dev/docs) are written and built +- [How to debug the VS Code Extension README](../extension/src/README.md): learn how to set up the VS Code extension, so you can debug it diff --git a/server/continuedev/__init__.py b/server/continuedev/__init__.py new file mode 100644 index 00000000..1b4776a8 --- /dev/null +++ b/server/continuedev/__init__.py @@ -0,0 +1,19 @@ +import asyncio +from typing import Union + +from .core.config import ContinueConfig +from .core.main import Step +from .headless import start_headless_session + + +def run(step_or_config: Union[Step, ContinueConfig]): + if isinstance(step_or_config, ContinueConfig): + config = step_or_config + else: + config = ContinueConfig() + config.steps_on_startup = [step_or_config] + + loop = asyncio.get_event_loop() + loop.run_until_complete(start_headless_session(config=config)) + tasks = asyncio.all_tasks(loop) + loop.run_until_complete(asyncio.gather(*tasks)) diff --git a/server/continuedev/__main__.py b/server/continuedev/__main__.py new file mode 100644 index 00000000..caaba117 --- /dev/null +++ b/server/continuedev/__main__.py @@ -0,0 +1,30 @@ +from typing import Optional + +import typer + +from . import run +from .server.main import run_server + +app = typer.Typer() + + +@app.command() +def main( + port: int = typer.Option(65432, help="server port"), + host: str = typer.Option("127.0.0.1", help="server host"), + meilisearch_url: Optional[str] = typer.Option( + None, help="The URL of the MeiliSearch server if running manually" + ), + config: Optional[str] = typer.Option( + None, help="The path to the configuration file" + ), + headless: bool = typer.Option(False, help="Run in headless mode"), +): + if headless: + run(config) + else: + run_server(port=port, host=host, meilisearch_url=meilisearch_url) + + +if __name__ == "__main__": + app() diff --git a/server/continuedev/core/abstract_sdk.py b/server/continuedev/core/abstract_sdk.py new file mode 100644 index 00000000..fdb99d47 --- /dev/null +++ b/server/continuedev/core/abstract_sdk.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from typing import Coroutine, List, Union + +from ..models.filesystem_edit import FileSystemEdit +from .config import ContinueConfig +from .main import ChatMessage, History, Step +from .observation import Observation + +""" +[[Generate]] +[Prompt] +Write an abstract class AbstractContinueSDK(ABC) that has all of the same methods as the ContinueSDK class, but without any implementation. +All methods should be documented with the same docstrings as the ContinueSDK class and have the same types. +[Context] +./sdk.py:ContinueSDK +""" + + +class AbstractContinueSDK(ABC): + """The SDK provided as parameters to a step""" + + @property + def history(self) -> History: + return self.__autopilot.history + + @abstractmethod + async def _ensure_absolute_path(self, path: str) -> str: + pass + + @abstractmethod + async def run_step(self, step: Step) -> Coroutine[Observation, None, None]: + pass + + @abstractmethod + async def apply_filesystem_edit(self, edit: FileSystemEdit): + pass + + @abstractmethod + async def wait_for_user_input(self) -> str: + pass + + @abstractmethod + async def wait_for_user_confirmation(self, prompt: str): + pass + + @abstractmethod + async def run(self, commands: Union[List[str], str], cwd: str = None): + pass + + @abstractmethod + async def edit_file(self, filename: str, prompt: str): + pass + + @abstractmethod + async def append_to_file(self, filename: str, content: str): + pass + + @abstractmethod + async def add_file(self, filename: str, content: Union[str, None]): + pass + + @abstractmethod + async def delete_file(self, filename: str): + pass + + @abstractmethod + async def add_directory(self, path: str): + pass + + @abstractmethod + async def delete_directory(self, path: str): + pass + + config: ContinueConfig + + @abstractmethod + def set_loading_message(self, message: str): + pass + + @abstractmethod + async def get_chat_context(self) -> List[ChatMessage]: + pass diff --git a/server/continuedev/core/autopilot.py b/server/continuedev/core/autopilot.py new file mode 100644 index 00000000..11c05378 --- /dev/null +++ b/server/continuedev/core/autopilot.py @@ -0,0 +1,746 @@ +import json +import os +import time +import traceback +import uuid +from functools import cached_property +from typing import Callable, Coroutine, Dict, List, Optional + +import redbaron +from aiohttp import ClientPayloadError +from openai import error as openai_errors +from pydantic import root_validator + +from ..libs.llm.prompts.chat import template_alpaca_messages +from ..libs.util.create_async_task import create_async_task +from ..libs.util.devdata import dev_data_logger +from ..libs.util.edit_config import edit_config_property +from ..libs.util.logging import logger +from ..libs.util.paths import getSavedContextGroupsPath +from ..libs.util.queue import AsyncSubscriptionQueue +from ..libs.util.strings import remove_quotes_and_escapes +from ..libs.util.telemetry import posthog_logger +from ..libs.util.traceback.traceback_parsers import ( + get_javascript_traceback, + get_python_traceback, +) +from ..models.filesystem import RangeInFileWithContents +from ..models.filesystem_edit import FileEditWithFullContents +from ..models.main import ContinueBaseModel +from ..plugins.context_providers.file import FileContextProvider +from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider +from ..plugins.policies.default import DefaultPolicy +from ..plugins.steps.on_traceback import DefaultOnTracebackStep +from ..server.ide_protocol import AbstractIdeProtocolServer +from ..server.meilisearch_server import get_meilisearch_url, stop_meilisearch +from .config import ContinueConfig +from .context import ContextManager +from .main import ( + Context, + ContextItem, + ContinueCustomException, + FullState, + History, + HistoryNode, + Policy, + SessionInfo, + Step, +) +from .observation import InternalErrorObservation, Observation +from .sdk import ContinueSDK +from .steps import DisplayErrorStep, ManualEditStep, ReversibleStep, UserInputStep + + +def get_error_title(e: Exception) -> str: + if isinstance(e, openai_errors.APIError): + return "OpenAI is overloaded with requests. Please try again." + elif isinstance(e, openai_errors.RateLimitError): + return "This OpenAI API key has been rate limited. Please try again." + elif isinstance(e, openai_errors.Timeout): + return "OpenAI timed out. Please try again." + elif ( + isinstance(e, openai_errors.InvalidRequestError) + and e.code == "context_length_exceeded" + ): + return e._message + elif isinstance(e, ClientPayloadError): + return "The request failed. Please try again." + elif isinstance(e, openai_errors.APIConnectionError): + return 'The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""' + elif isinstance(e, openai_errors.InvalidRequestError): + return "Invalid request sent to OpenAI. Please try again." + elif "rate_limit_ip_middleware" in e.__str__(): + return "You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings." + elif e.__str__().startswith("Cannot connect to host"): + return ( + "The request failed. Please check your internet connection and try again." + ) + return e.__str__() or e.__repr__() + + +class Autopilot(ContinueBaseModel): + ide: AbstractIdeProtocolServer + + policy: Policy = DefaultPolicy() + history: History = History.from_empty() + context: Context = Context() + full_state: Optional[FullState] = None + session_info: Optional[SessionInfo] = None + context_manager: ContextManager = ContextManager() + continue_sdk: ContinueSDK = None + + _on_update_callbacks: List[Callable[[FullState], None]] = [] + + _active: bool = False + _should_halt: bool = False + _main_user_input_queue: List[str] = [] + + _user_input_queue = AsyncSubscriptionQueue() + _retry_queue = AsyncSubscriptionQueue() + + started: bool = False + + async def load( + self, config: Optional[ContinueConfig] = None, only_reloading: bool = False + ): + self.continue_sdk = await ContinueSDK.create(self, config=config) + if override_policy := self.continue_sdk.config.policy_override: + self.policy = override_policy + + # Load documents into the search index + logger.debug("Starting context manager") + await self.context_manager.start( + self.continue_sdk.config.context_providers + + [ + HighlightedCodeContextProvider(ide=self.ide), + FileContextProvider(workspace_dir=self.ide.workspace_directory), + ], + self.continue_sdk, + only_reloading=only_reloading, + ) + + async def start( + self, + full_state: Optional[FullState] = None, + config: Optional[ContinueConfig] = None, + ): + await self.load(config=config, only_reloading=False) + + if full_state is not None: + self.history = full_state.history + self.session_info = full_state.session_info + + # Load saved context groups + context_groups_file = getSavedContextGroupsPath() + try: + with open(context_groups_file, "r") as f: + json_ob = json.load(f) + for title, context_group in json_ob.items(): + self._saved_context_groups[title] = [ + ContextItem(**item) for item in context_group + ] + except Exception as e: + logger.warning( + f"Failed to load saved_context_groups.json: {e}. Reverting to empty list." + ) + self._saved_context_groups = {} + + self.started = True + + async def reload_config(self): + await self.load(config=None, only_reloading=True) + await self.update_subscribers() + + async def cleanup(self): + stop_meilisearch() + + class Config: + arbitrary_types_allowed = True + keep_untouched = (cached_property,) + + @root_validator(pre=True) + def fill_in_values(cls, values): + full_state: FullState = values.get("full_state") + if full_state is not None: + values["history"] = full_state.history + return values + + async def get_full_state(self) -> FullState: + full_state = FullState( + history=self.history, + active=self._active, + user_input_queue=self._main_user_input_queue, + slash_commands=self.get_available_slash_commands(), + adding_highlighted_code=self.context_manager.context_providers[ + "code" + ].adding_highlighted_code + if "code" in self.context_manager.context_providers + else False, + selected_context_items=await self.context_manager.get_selected_items() + if self.context_manager is not None + else [], + session_info=self.session_info, + config=self.continue_sdk.config, + saved_context_groups=self._saved_context_groups, + context_providers=self.context_manager.get_provider_descriptions(), + meilisearch_url=get_meilisearch_url(), + ) + self.full_state = full_state + return full_state + + def get_available_slash_commands(self) -> List[Dict]: + custom_commands = ( + list( + map( + lambda x: {"name": x.name, "description": x.description}, + self.continue_sdk.config.custom_commands, + ) + ) + or [] + ) + slash_commands = ( + list( + map( + lambda x: {"name": x.name, "description": x.description}, + self.continue_sdk.config.slash_commands, + ) + ) + or [] + ) + cmds = custom_commands + slash_commands + cmds.sort(key=lambda x: x["name"] == "edit", reverse=True) + return cmds + + async def clear_history(self): + # Reset history + self.history = History.from_empty() + self._main_user_input_queue = [] + self._active = False + + # Clear context + # await self.context_manager.clear_context() + + await self.update_subscribers() + + def on_update(self, callback: Coroutine["FullState", None, None]): + """Subscribe to changes to state""" + self._on_update_callbacks.append(callback) + + async def update_subscribers(self): + full_state = await self.get_full_state() + for callback in self._on_update_callbacks: + await callback(full_state) + + def give_user_input(self, input: str, index: int): + self._user_input_queue.post(str(index), input) + + async def wait_for_user_input(self) -> str: + self._active = False + await self.update_subscribers() + user_input = await self._user_input_queue.get(str(self.history.current_index)) + self._active = True + await self.update_subscribers() + return user_input + + _manual_edits_buffer: List[FileEditWithFullContents] = [] + + async def reverse_to_index(self, index: int): + try: + while self.history.get_current_index() >= index: + current_step = self.history.get_current().step + self.history.step_back() + if issubclass(current_step.__class__, ReversibleStep): + await current_step.reverse(self.continue_sdk) + + await self.update_subscribers() + except Exception as e: + logger.debug(e) + + def handle_manual_edits(self, edits: List[FileEditWithFullContents]): + for edit in edits: + self._manual_edits_buffer.append(edit) + # TODO: You're storing a lot of unnecessary data here. Can compress into EditDiffs on the spot, and merge. + # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) + # Note that this is being overridden to do nothing in DemoAgent + + async def handle_command_output(self, output: str): + get_traceback_funcs = [get_python_traceback, get_javascript_traceback] + for get_tb_func in get_traceback_funcs: + traceback = get_tb_func(output) + if ( + traceback is not None + and self.continue_sdk.config.on_traceback is not None + ): + step = self.continue_sdk.config.on_traceback(output=output) + await self._run_singular_step(step) + + async def handle_debug_terminal(self, content: str): + """Run the debug terminal step""" + # step = self.continue_sdk.config.on_traceback(output=content) + step = DefaultOnTracebackStep(output=content) + await self._run_singular_step(step) + + async def handle_highlighted_code( + self, + range_in_files: List[RangeInFileWithContents], + edit: Optional[bool] = False, + ): + if "code" not in self.context_manager.context_providers: + return + + # Add to context manager + await self.context_manager.context_providers["code"].handle_highlighted_code( + range_in_files, edit + ) + + await self.update_subscribers() + + _step_depth: int = 0 + + async def retry_at_index(self, index: int): + self.history.timeline[index].step.hide = True + self._retry_queue.post(str(index), None) + + async def delete_at_index(self, index: int): + if not self.history.timeline[index].active: + self.history.timeline[index].step.hide = True + + self.history.timeline[index].deleted = True + self.history.timeline[index].active = False + + await self.update_subscribers() + + async def edit_step_at_index(self, user_input: str, index: int): + node_to_rerun = self.history.timeline[index].copy() + step_to_rerun = node_to_rerun.step + step_to_rerun.user_input = user_input + step_to_rerun.description = user_input + + # Halt the agent's currently running jobs (delete them) + while len(self.history.timeline) > index: + # Remove from timeline + node_to_delete = self.history.timeline.pop() + # Delete so it is stopped if in the middle of running + node_to_delete.deleted = True + + self.history.current_index = index - 1 + + # Set the context to the context used by that step + await self.context_manager.clear_context() + for context_item in node_to_rerun.context_used: + await self.context_manager.manually_add_context_item(context_item) + + await self.update_subscribers() + + # Rerun from the current step + await self.run_from_step(step_to_rerun) + + async def delete_context_with_ids( + self, ids: List[str], index: Optional[int] = None + ): + if index is None: + await self.context_manager.delete_context_with_ids(ids) + else: + self.history.timeline[index].context_used = list( + filter( + lambda item: item.description.id.to_string() not in ids, + self.history.timeline[index].context_used, + ) + ) + await self.update_subscribers() + + async def toggle_adding_highlighted_code(self): + if "code" not in self.context_manager.context_providers: + return + + self.context_manager.context_providers[ + "code" + ].adding_highlighted_code = not self.context_manager.context_providers[ + "code" + ].adding_highlighted_code + await self.update_subscribers() + + async def set_editing_at_ids(self, ids: List[str]): + if "code" not in self.context_manager.context_providers: + return + + await self.context_manager.context_providers["code"].set_editing_at_ids(ids) + await self.update_subscribers() + + async def _run_singular_step( + self, step: "Step", is_future_step: bool = False + ) -> Coroutine[Observation, None, None]: + # Allow config to set disallowed steps + if step.__class__.__name__ in self.continue_sdk.config.disallowed_steps: + return None + + # If a parent step is deleted/cancelled, don't run this step + # TODO: This was problematic because when running a step after deleting one, it seemed to think that was the parent + # last_depth = self._step_depth + # i = self.history.current_index + # while i >= 0 and self.history.timeline[i].depth == last_depth - 1: + # if self.history.timeline[i].deleted: + # return None + # last_depth = self.history.timeline[i].depth + # i -= 1 + + # Log the context and step to dev data + context_used = await self.context_manager.get_selected_items() + posthog_logger.capture_event( + "step run", {"step_name": step.name, "params": step.dict()} + ) + step_id = uuid.uuid4().hex + dev_data_logger.capture( + "step_run", + {"step_name": step.name, "params": step.dict(), "step_id": step_id}, + ) + dev_data_logger.capture( + "context_used", + { + "context": list( + map( + lambda item: item.dict(), + context_used, + ) + ), + "step_id": step_id, + }, + ) + + if not is_future_step: + # Check manual edits buffer, clear out if needed by creating a ManualEditStep + if len(self._manual_edits_buffer) > 0: + manualEditsStep = ManualEditStep.from_sequence( + self._manual_edits_buffer + ) + self._manual_edits_buffer = [] + await self._run_singular_step(manualEditsStep) + + # Update history - do this first so we get top-first tree ordering + index_of_history_node = self.history.add_node( + HistoryNode( + step=step, + observation=None, + depth=self._step_depth, + context_used=context_used, + ) + ) + + # Call all subscribed callbacks + await self.update_subscribers() + + # Try to run step and handle errors + self._step_depth += 1 + + caught_error = False + try: + observation = await step(self.continue_sdk) + except Exception as e: + if ( + index_of_history_node >= len(self.history.timeline) + or self.history.timeline[index_of_history_node].deleted + ): + # If step was deleted/cancelled, don't show error or allow retry + return None + + caught_error = True + + is_continue_custom_exception = ( + issubclass(e.__class__, ContinueCustomException) + or e.__class__.__name__ == ContinueCustomException.__name__ + ) + + error_string = ( + e.message + if is_continue_custom_exception + else "\n".join(traceback.format_exception(e)) + ) + error_title = ( + e.title if is_continue_custom_exception else get_error_title(e) + ) + + # Attach an InternalErrorObservation to the step and unhide it. + logger.error(f"Error while running step: \n{error_string}\n{error_title}") + posthog_logger.capture_event( + "step error", + { + "error_message": error_string, + "error_title": error_title, + "step_name": step.name, + "params": step.dict(), + }, + ) + + observation = InternalErrorObservation( + error=error_string, title=error_title + ) + + # Reveal this step, but hide all of the following steps (its substeps) + step_was_hidden = step.hide + + step.hide = False + i = self.history.get_current_index() + while self.history.timeline[i].step.name != step.name: + self.history.timeline[i].step.hide = True + i -= 1 + + # i is now the index of the step that we want to show/rerun + self.history.timeline[i].observation = observation + self.history.timeline[i].active = False + + await self.update_subscribers() + + # ContinueCustomException can optionally specify a step to run on the error + if is_continue_custom_exception and e.with_step is not None: + await self._run_singular_step(e.with_step) + + # Wait for a retry signal and then resume the step + self._active = False + await self._retry_queue.get(str(i)) + self._active = True + # You might consider a "ignore and continue" button + # want it to have same step depth, so have to decrement + self._step_depth -= 1 + copy_step = step.copy() + copy_step.hide = step_was_hidden + observation = await self._run_singular_step(copy_step) + self._step_depth += 1 + + self._step_depth -= 1 + + # Add observation to history, unless already attached error observation + if not caught_error and index_of_history_node < len(self.history.timeline): + self.history.timeline[index_of_history_node].observation = observation + self.history.timeline[index_of_history_node].active = False + await self.update_subscribers() + + # Update its description + async def update_description(): + if self.continue_sdk.config.disable_summaries: + return + + description = await step.describe(self.continue_sdk.models) + if description is not None: + step.description = description + # Update subscribers with new description + await self.update_subscribers() + + create_async_task( + update_description(), + on_error=lambda e: self.continue_sdk.run_step( + DisplayErrorStep.from_exception(e) + ), + ) + + # Create the session title if not done yet + if self.session_info is None or self.session_info.title is None: + visible_nodes = list( + filter(lambda node: not node.step.hide, self.history.timeline) + ) + + user_input = None + should_create_title = False + for visible_node in visible_nodes: + if isinstance(visible_node.step, UserInputStep): + if user_input is None: + user_input = visible_node.step.user_input + else: + # More than one user input, so don't create title + should_create_title = False + break + elif user_input is None: + continue + else: + # Already have user input, now have the next step + should_create_title = True + break + + # Only create the title if the step after the first input is done + if should_create_title: + create_async_task( + self.create_title(backup=user_input), + on_error=lambda e: self.continue_sdk.run_step( + DisplayErrorStep.from_exception(e) + ), + ) + + return observation + + async def run_from_step(self, step: "Step"): + # if self._active: + # raise RuntimeError("Autopilot is already running") + self._active = True + + next_step = step + is_future_step = False + while not (next_step is None or self._should_halt): + if is_future_step: + # If future step, then we are replaying and need to delete the step from history so it can be replaced + self.history.remove_current_and_substeps() + + await self._run_singular_step(next_step, is_future_step) + + if next_step := self.policy.next(self.continue_sdk.config, self.history): + is_future_step = False + elif next_step := self.history.take_next_step(): + is_future_step = True + else: + next_step = None + + self._active = False + + # Doing this so active can make it to the frontend after steps are done. But want better state syncing tools + await self.update_subscribers() + + async def run_from_observation(self, observation: Observation): + next_step = self.policy.next(self.continue_sdk.config, self.history) + await self.run_from_step(next_step) + + async def run_policy(self): + first_step = self.policy.next(self.continue_sdk.config, self.history) + await self.run_from_step(first_step) + + async def _request_halt(self): + if self._active: + self._should_halt = True + while self._active: + time.sleep(0.1) + self._should_halt = False + return None + + def set_current_session_title(self, title: str): + self.session_info = SessionInfo( + title=title, + session_id=self.ide.session_id, + date_created=str(time.time()), + workspace_directory=self.ide.workspace_directory, + ) + + async def create_title(self, backup: str = None): + # Use the first input and first response to create title for session info, and make the session saveable + if self.session_info is not None and self.session_info.title is not None: + return + + if self.continue_sdk.config.disable_summaries: + if backup is not None: + title = backup + else: + title = "New Session" + else: + chat_history = list( + map(lambda x: x.dict(), await self.continue_sdk.get_chat_context()) + ) + chat_history_str = template_alpaca_messages(chat_history) + title = await self.continue_sdk.models.summarize.complete( + f"{chat_history_str}\n\nGive a short title to describe the above chat session. Do not put quotes around the title. Do not use more than 6 words. The title is: ", + max_tokens=20, + log=False, + ) + title = remove_quotes_and_escapes(title) + + self.set_current_session_title(title) + await self.update_subscribers() + dev_data_logger.capture("new_session", self.session_info.dict()) + + async def accept_user_input(self, user_input: str): + self._main_user_input_queue.append(user_input) + # await self.update_subscribers() + + if len(self._main_user_input_queue) > 1: + return + + # await self._request_halt() + # Just run the step that takes user input, and + # then up to the policy to decide how to deal with it. + self._main_user_input_queue.pop(0) + # await self.update_subscribers() + await self.run_from_step(UserInputStep(user_input=user_input)) + + while len(self._main_user_input_queue) > 0: + await self.run_from_step( + UserInputStep(user_input=self._main_user_input_queue.pop(0)) + ) + + async def accept_refinement_input(self, user_input: str, index: int): + await self._request_halt() + await self.reverse_to_index(index) + await self.run_from_step(UserInputStep(user_input=user_input)) + + async def reject_diff(self, step_index: int): + # Hide the edit step and the UserInputStep before it + self.history.timeline[step_index].step.hide = True + for i in range(step_index - 1, -1, -1): + if isinstance(self.history.timeline[i].step, UserInputStep): + self.history.timeline[i].step.hide = True + break + await self.update_subscribers() + + async def select_context_item(self, id: str, query: str): + await self.context_manager.select_context_item(id, query) + await self.update_subscribers() + + async def select_context_item_at_index(self, id: str, query: str, index: int): + # TODO: This is different from how it works for the main input + # Ideally still tracked through the ContextProviders + # so they can watch for duplicates + context_item = await self.context_manager.get_context_item(id, query) + if context_item is None: + return + self.history.timeline[index].context_used.append(context_item) + await self.update_subscribers() + + async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron): + edit_config_property(key_path, value) + await self.update_subscribers() + + _saved_context_groups: Dict[str, List[ContextItem]] = {} + + def _persist_context_groups(self): + context_groups_file = getSavedContextGroupsPath() + if os.path.exists(context_groups_file): + with open(context_groups_file, "w") as f: + dict_to_save = { + title: [item.dict() for item in context_items] + for title, context_items in self._saved_context_groups.items() + } + json.dump(dict_to_save, f) + + async def save_context_group(self, title: str, context_items: List[ContextItem]): + self._saved_context_groups[title] = context_items + await self.update_subscribers() + + # Update saved context groups + self._persist_context_groups() + + posthog_logger.capture_event( + "save_context_group", {"title": title, "length": len(context_items)} + ) + + async def select_context_group(self, id: str): + if id not in self._saved_context_groups: + logger.warning(f"Context group {id} not found") + return + context_group = self._saved_context_groups[id] + await self.context_manager.clear_context() + for item in context_group: + await self.context_manager.manually_add_context_item(item) + await self.update_subscribers() + + posthog_logger.capture_event( + "select_context_group", {"title": id, "length": len(context_group)} + ) + dev_data_logger.capture( + "select_context_group", {"title": id, "items": context_group} + ) + + async def delete_context_group(self, id: str): + if id not in self._saved_context_groups: + logger.warning(f"Context group {id} not found") + return + del self._saved_context_groups[id] + await self.update_subscribers() + + # Update saved context groups + self._persist_context_groups() + + posthog_logger.capture_event("delete_context_group", {"title": id}) diff --git a/server/continuedev/core/config.py b/server/continuedev/core/config.py new file mode 100644 index 00000000..2bbb42cc --- /dev/null +++ b/server/continuedev/core/config.py @@ -0,0 +1,114 @@ +from typing import Dict, List, Optional, Type + +from pydantic import BaseModel, Field, validator + +from ..libs.llm.openai_free_trial import OpenAIFreeTrial +from .context import ContextProvider +from .main import Policy, Step +from .models import Models + + +class SlashCommand(BaseModel): + name: str + description: str + step: Type[Step] + params: Optional[Dict] = {} + + def dict(self, *args, **kwargs): + return { + "name": self.name, + "description": self.description, + "params": self.params, + "step": self.step.__name__, + } + + +class CustomCommand(BaseModel): + name: str + prompt: str + description: str + + +class ContinueConfig(BaseModel): + """ + Continue can be deeply customized by editing the `ContinueConfig` object in `~/.continue/config.py` (`%userprofile%\.continue\config.py` for Windows) on your machine. This class is instantiated from the config file for every new session. + """ + + steps_on_startup: List[Step] = Field( + [], + description="Steps that will be automatically run at the beginning of a new session", + ) + disallowed_steps: Optional[List[str]] = Field( + [], + description="Steps that are not allowed to be run, and will be skipped if attempted", + ) + allow_anonymous_telemetry: Optional[bool] = Field( + True, + description="If this field is set to True, we will collect anonymous telemetry as described in the documentation page on telemetry. If set to False, we will not collect any data.", + ) + models: Models = Field( + Models( + default=OpenAIFreeTrial(model="gpt-4"), + summarize=OpenAIFreeTrial(model="gpt-3.5-turbo"), + ), + description="Configuration for the models used by Continue. Read more about how to configure models in the documentation.", + ) + temperature: Optional[float] = Field( + 0.5, + description="The temperature parameter for sampling from the LLM. Higher temperatures will result in more random output, while lower temperatures will result in more predictable output. This value ranges from 0 to 1.", + ) + custom_commands: Optional[List[CustomCommand]] = Field( + [ + CustomCommand( + name="test", + description="This is an example custom command. Use /config to edit it and create more", + prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.", + ) + ], + description="An array of custom commands that allow you to reuse prompts. Each has name, description, and prompt properties. When you enter / in the text input, it will act as a shortcut to the prompt.", + ) + slash_commands: Optional[List[SlashCommand]] = Field( + [], + description="An array of slash commands that let you map custom Steps to a shortcut.", + ) + on_traceback: Optional[Step] = Field( + None, + description="The step that will be run when a traceback is detected (when you use the shortcut cmd+shift+R)", + ) + system_message: Optional[str] = Field( + None, description="A system message that will always be followed by the LLM" + ) + policy_override: Optional[Policy] = Field( + None, + description="A Policy object that can be used to override the default behavior of Continue, for example in order to build custom agents that take multiple steps at a time.", + ) + context_providers: List[ContextProvider] = Field( + [], + description="A list of ContextProvider objects that can be used to provide context to the LLM by typing '@'. Read more about ContextProviders in the documentation.", + ) + user_token: Optional[str] = Field( + None, description="An optional token to identify the user." + ) + data_server_url: Optional[str] = Field( + "https://us-west1-autodebug.cloudfunctions.net", + description="The URL of the server where development data is sent. No data is sent unless a valid user token is provided.", + ) + disable_summaries: Optional[bool] = Field( + False, + description="If set to `True`, Continue will not generate summaries for each Step. This can be useful if you want to save on compute.", + ) + + @validator("temperature", pre=True) + def temperature_validator(cls, v): + return max(0.0, min(1.0, v)) + + @staticmethod + def from_filepath(filepath: str) -> "ContinueConfig": + # Use importlib to load the config file config.py at the given path + import importlib.util + + spec = importlib.util.spec_from_file_location("config", filepath) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + + return config.config diff --git a/server/continuedev/core/context.py b/server/continuedev/core/context.py new file mode 100644 index 00000000..547a1593 --- /dev/null +++ b/server/continuedev/core/context.py @@ -0,0 +1,516 @@ +import asyncio +import time +from abc import abstractmethod +from typing import Awaitable, Callable, Dict, List, Optional + +from meilisearch_python_async import Client +from pydantic import BaseModel, Field + +from ..libs.util.create_async_task import create_async_task +from ..libs.util.devdata import dev_data_logger +from ..libs.util.logging import logger +from ..libs.util.telemetry import posthog_logger +from ..server.meilisearch_server import ( + check_meilisearch_running, + get_meilisearch_url, + poll_meilisearch_running, + restart_meilisearch, + start_meilisearch, +) +from .main import ( + ChatMessage, + ContextItem, + ContextItemDescription, + ContextItemId, + ContextProviderDescription, +) + + +class ContinueSDK(BaseModel): + """To avoid circular imports""" + + ... + + +SEARCH_INDEX_NAME = "continue_context_items" + + +class ContextProvider(BaseModel): + """ + The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. + When you type '@', the context provider will be asked to populate a list of options. + These options will be updated on each keystroke. + When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object). + """ + + title: str = Field( + ..., + description="The title of the ContextProvider. This is what must be typed in the input to trigger the ContextProvider.", + ) + sdk: ContinueSDK = Field( + None, description="The ContinueSDK instance accessible by the ContextProvider" + ) + delete_documents: Callable[[List[str]], Awaitable] = Field( + None, description="Function to delete documents" + ) + update_documents: Callable[[List[ContextItem], str], Awaitable] = Field( + None, description="Function to update documents" + ) + + display_title: str = Field( + ..., + description="The display title of the ContextProvider shown in the dropdown menu", + ) + description: str = Field( + ..., + description="A description of the ContextProvider displayed in the dropdown menu", + ) + dynamic: bool = Field( + ..., description="Indicates whether the ContextProvider is dynamic" + ) + requires_query: bool = Field( + False, + description="Indicates whether the ContextProvider requires a query. For example, the SearchContextProvider requires you to type '@search '. This will change the behavior of the UI so that it can indicate the expectation for a query.", + ) + + selected_items: List[ContextItem] = Field( + [], description="List of selected items in the ContextProvider" + ) + + def dict(self, *args, **kwargs): + original_dict = super().dict(*args, **kwargs) + original_dict.pop("sdk", None) + original_dict.pop("delete_documents", None) + original_dict.pop("update_documents", None) + return original_dict + + async def start(self, sdk: ContinueSDK, delete_documents, update_documents): + """ + Starts the context provider. + + Default implementation sets the sdk. + """ + self.sdk = sdk + self.delete_documents = delete_documents + self.update_documents = update_documents + + async def get_selected_items(self) -> List[ContextItem]: + """ + Returns all of the selected ContextItems. + + Default implementation simply returns self.selected_items. + + Other implementations may add an async processing step. + """ + return self.selected_items + + @abstractmethod + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + """ + Provide documents for search index. This is run on startup. + + This is the only method that must be implemented. + """ + + async def get_chat_messages(self) -> List[ChatMessage]: + """ + Returns all of the chat messages for the context provider. + + Default implementation has a string template. + """ + return [ + ChatMessage( + role="user", + content=f"{item.description.name}: {item.description.description}\n\n{item.content}", + summary=item.description.description, + ) + for item in await self.get_selected_items() + ] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + """ + Returns the ContextItem with the given id. + + Default implementation uses the search index to get the item. + """ + async with Client(get_meilisearch_url()) as search_client: + try: + result = await search_client.index(SEARCH_INDEX_NAME).get_document( + id.to_string() + ) + return ContextItem( + description=ContextItemDescription( + name=result["name"], description=result["description"], id=id + ), + content=result["content"], + ) + except Exception as e: + logger.warning(f"Error while retrieving document from meilisearch: {e}") + + return None + + async def delete_context_with_ids(self, ids: List[ContextItemId]): + """ + Deletes the ContextItems with the given IDs, lets ContextProviders recalculate. + + Default implementation simply deletes those with the given ids. + """ + id_strings = {id.to_string() for id in ids} + self.selected_items = list( + filter( + lambda item: item.description.id.to_string() not in id_strings, + self.selected_items, + ) + ) + + async def clear_context(self): + """ + Clears all context. + + Default implementation simply clears the selected items. + """ + self.selected_items = [] + + async def add_context_item(self, id: ContextItemId, query: str): + """ + Adds the given ContextItem to the list of ContextItems. + + Default implementation simply appends the item, not allowing duplicates. + + This method also allows you not to have to load all of the information until an item is selected. + """ + + # Don't add duplicate context + for item in self.selected_items: + if item.description.id.item_id == id.item_id: + return + + if new_item := await self.get_item(id, query): + self.selected_items.append(new_item) + + async def manually_add_context_item(self, context_item: ContextItem): + for item in self.selected_items: + if item.description.id.item_id == context_item.description.id.item_id: + return + + self.selected_items.append(context_item) + + +class ContextManager: + """ + The context manager is responsible for storing the context to be passed to the LLM, including + - ContextItems (highlighted code, GitHub Issues, etc.) + - ChatMessages in the history + - System Message + - Functions + + It is responsible for compiling all of this information into a single prompt without exceeding the token limit. + """ + + def get_provider_descriptions(self) -> List[ContextProviderDescription]: + """ + Returns a list of ContextProviderDescriptions for each context provider. + """ + return [ + ContextProviderDescription( + title=provider.title, + display_title=provider.display_title, + description=provider.description, + dynamic=provider.dynamic, + requires_query=provider.requires_query, + ) + for provider in self.context_providers.values() + if provider.title != "code" + ] + + async def get_selected_items(self) -> List[ContextItem]: + """ + Returns all of the selected ContextItems. + """ + return sum( + [ + await provider.get_selected_items() + for provider in self.context_providers.values() + ], + [], + ) + + async def get_chat_messages(self) -> List[ChatMessage]: + """ + Returns chat messages from each provider. + """ + return sum( + [ + await provider.get_chat_messages() + for provider in self.context_providers.values() + ], + [], + ) + + def __init__(self): + self.context_providers = {} + self.provider_titles = set() + + async def start( + self, + context_providers: List[ContextProvider], + sdk: ContinueSDK, + only_reloading: bool = False, + ): + """ + Starts the context manager. + """ + new_context_providers = { + provider.title: provider + for provider in context_providers + if provider.title not in self.provider_titles + } + + self.context_providers = { + provider.title: provider for provider in context_providers + } + self.provider_titles = {provider.title for provider in context_providers} + + for provider in context_providers: + await provider.start( + sdk, + ContextManager.delete_documents, + ContextManager.update_documents, + ) + + async def on_err(e): + logger.warning(f"Error loading meilisearch index: {e}") + + # Start MeiliSearch in the background without blocking + async def load_index(providers_to_load: List[ContextProvider]): + running = await check_meilisearch_running() + if not running: + await start_meilisearch() + try: + await asyncio.wait_for(poll_meilisearch_running(), timeout=20) + except asyncio.TimeoutError: + logger.warning( + "Meilisearch did not start in less than 20 seconds. Stopping polling." + ) + return + + logger.debug("Loading Meilisearch index...") + await self.load_index( + sdk.ide.workspace_directory, providers_to_load=providers_to_load + ) + logger.debug("Loaded Meilisearch index") + + providers_to_load = ( + new_context_providers if only_reloading else context_providers + ) + create_async_task(load_index(providers_to_load), on_err) + + @staticmethod + async def update_documents(context_items: List[ContextItem], workspace_dir: str): + """ + Updates the documents in the search index. + """ + documents = [ + { + "id": item.description.id.to_string(), + "name": item.description.name, + "description": item.description.description, + "content": item.content, + "workspace_dir": workspace_dir, + "provider_name": item.description.id.provider_title, + } + for item in context_items + ] + async with Client(get_meilisearch_url()) as search_client: + + async def add_docs(): + index = await search_client.get_index(SEARCH_INDEX_NAME) + await index.add_documents(documents or []) + + try: + await asyncio.wait_for(add_docs(), timeout=20) + except asyncio.TimeoutError: + logger.warning("Failed to add document to meilisearch in 20 seconds") + except Exception as e: + logger.warning(f"Error adding document to meilisearch: {e}") + + @staticmethod + async def delete_documents(ids): + """ + Deletes the documents in the search index. + """ + async with Client(get_meilisearch_url()) as search_client: + try: + await asyncio.wait_for( + search_client.index(SEARCH_INDEX_NAME).delete_documents(ids), + timeout=20, + ) + except asyncio.TimeoutError: + logger.warning( + "Failed to delete document from meilisearch in 20 seconds" + ) + except Exception as e: + logger.warning(f"Error deleting document from meilisearch: {e}") + + async def load_index( + self, + workspace_dir: str, + should_retry: bool = True, + providers_to_load: Optional[List[ContextProvider]] = None, + ): + try: + async with Client(get_meilisearch_url()) as search_client: + # First, create the index if it doesn't exist + # The index is currently shared by all workspaces + await search_client.create_index(SEARCH_INDEX_NAME) + globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME) + await globalSearchIndex.update_ranking_rules( + ["attribute", "words", "typo", "proximity", "sort", "exactness"] + ) + await globalSearchIndex.update_searchable_attributes( + ["name", "description"] + ) + await globalSearchIndex.update_filterable_attributes( + ["workspace_dir", "provider_name"] + ) + + async def load_context_provider(provider: ContextProvider): + context_items = await provider.provide_context_items(workspace_dir) + documents = [ + { + "id": item.description.id.to_string(), + "name": item.description.name, + "description": item.description.description, + "content": item.content, + "workspace_dir": workspace_dir, + "provider_name": provider.title, + } + for item in context_items + ] + if len(documents) > 0: + await globalSearchIndex.add_documents(documents) + + return len(documents) + + async def safe_load(provider: ContextProvider): + ti = time.time() + try: + num_documents = await asyncio.wait_for( + load_context_provider(provider), timeout=20 + ) + except asyncio.TimeoutError: + logger.warning( + f"Failed to add documents to meilisearch for context provider {provider.__class__.__name__} in 20 seconds" + ) + return + except Exception as e: + logger.warning( + f"Error adding documents to meilisearch for context provider {provider.__class__.__name__}: {e}" + ) + return + + tf = time.time() + logger.debug( + f"Loaded {num_documents} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}" + ) + + tasks = [ + safe_load(provider) + for _, provider in ( + providers_to_load or self.context_providers + ).items() + ] + await asyncio.wait_for(asyncio.gather(*tasks), timeout=20) + + except Exception as e: + logger.debug(f"Error loading meilisearch index: {e}") + if should_retry: + await restart_meilisearch() + try: + await asyncio.wait_for(poll_meilisearch_running(), timeout=20) + except asyncio.TimeoutError: + logger.warning( + "Meilisearch did not restart in less than 20 seconds. Stopping polling." + ) + await self.load_index(workspace_dir, False) + + async def select_context_item(self, id: str, query: str): + """ + Selects the ContextItem with the given id. + """ + id: ContextItemId = ContextItemId.from_string(id) + if id.provider_title not in self.provider_titles: + raise ValueError( + f"Context provider with title {id.provider_title} not found" + ) + + posthog_logger.capture_event( + "select_context_item", + { + "provider_title": id.provider_title, + "item_id": id.item_id, + "query": query, + }, + ) + dev_data_logger.capture( + "select_context_item", + { + "provider_title": id.provider_title, + "item_id": id.item_id, + "query": query, + }, + ) + await self.context_providers[id.provider_title].add_context_item(id, query) + + async def get_context_item(self, id: str, query: str) -> ContextItem: + """ + Returns the ContextItem with the given id. + """ + id: ContextItemId = ContextItemId.from_string(id) + if id.provider_title not in self.provider_titles: + raise ValueError( + f"Context provider with title {id.provider_title} not found" + ) + + return await self.context_providers[id.provider_title].get_item(id, query) + + async def delete_context_with_ids(self, ids: List[str]): + """ + Deletes the ContextItems with the given IDs, lets ContextProviders recalculate. + """ + + # Group by provider title + provider_title_to_ids: Dict[str, List[ContextItemId]] = {} + for id in ids: + id: ContextItemId = ContextItemId.from_string(id) + if id.provider_title not in provider_title_to_ids: + provider_title_to_ids[id.provider_title] = [] + provider_title_to_ids[id.provider_title].append(id) + + # Recalculate context for each updated provider + for provider_title, ids in provider_title_to_ids.items(): + await self.context_providers[provider_title].delete_context_with_ids(ids) + + async def clear_context(self): + """ + Clears all context. + """ + for provider in self.context_providers.values(): + await self.context_providers[provider.title].clear_context() + + async def manually_add_context_item(self, item: ContextItem): + """ + Adds the given ContextItem to the list of ContextItems. + """ + if item.description.id.provider_title not in self.provider_titles: + return + + await self.context_providers[ + item.description.id.provider_title + ].manually_add_context_item(item) + + +""" +Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the +same format of prompts so you don't have to redo all of this logic. +""" diff --git a/server/continuedev/core/env.py b/server/continuedev/core/env.py new file mode 100644 index 00000000..60b86538 --- /dev/null +++ b/server/continuedev/core/env.py @@ -0,0 +1,31 @@ +import os + +from dotenv import load_dotenv + + +def get_env_var(var_name: str): + load_dotenv() + return os.getenv(var_name) + + +def make_sure_env_exists(): + if not os.path.exists(".env"): + with open(".env", "w") as f: + f.write("") + + +def save_env_var(var_name: str, var_value: str): + make_sure_env_exists() + + with open(".env", "r") as f: + lines = f.readlines() + with open(".env", "w") as f: + values = {} + for line in lines: + key, value = line.split("=") + value = value.replace('"', "") + values[key] = value + + values[var_name] = var_value + for key, value in values.items(): + f.write(f'{key}="{value}"\n') diff --git a/server/continuedev/core/lsp.py b/server/continuedev/core/lsp.py new file mode 100644 index 00000000..fc26c85c --- /dev/null +++ b/server/continuedev/core/lsp.py @@ -0,0 +1,416 @@ +import asyncio +import threading +from typing import List, Literal, Optional + +import aiohttp +from pydantic import BaseModel + +from ..models.filesystem import RangeInFile +from ..models.main import Position, Range + + +def filepath_to_uri(filename: str) -> str: + return f"file://{filename}" + + +def uri_to_filepath(uri: str) -> str: + if uri.startswith("file://"): + return uri[7:] + else: + return uri + + +PORT = 8099 + + +class LSPClient: + ready: bool = False + lock: asyncio.Lock = asyncio.Lock() + + def __init__(self, host: str, port: int, workspace_paths: List[str]): + self.host = host + self.port = port + self.session = aiohttp.ClientSession() + self.next_id = 0 + self.workspace_paths = workspace_paths + + async def connect(self): + print("Connecting") + self.ws = await self.session.ws_connect(f"ws://{self.host}:{self.port}/") + print("Connected") + self.ready = True + + async def send(self, data): + await self.ws.send_json(data) + + async def recv(self): + await self.lock.acquire() + + try: + return await self.ws.receive_json() + finally: + self.lock.release() + + async def close(self): + await self.ws.close() + await self.session.close() + + async def call_method(self, method_name, **kwargs): + body = { + "jsonrpc": "2.0", + "id": self.next_id, + "method": method_name, + "params": kwargs, + } + self.next_id += 1 + await self.send(body) + response = await self.recv() + return response + + async def initialize(self): + initialization_args = { + "capabilities": { + "textDocument": { + "codeAction": {"dynamicRegistration": True}, + "codeLens": {"dynamicRegistration": True}, + "colorProvider": {"dynamicRegistration": True}, + "completion": { + "completionItem": { + "commitCharactersSupport": True, + "documentationFormat": ["markdown", "plaintext"], + "snippetSupport": True, + }, + "completionItemKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + ] + }, + "contextSupport": True, + "dynamicRegistration": True, + }, + "definition": {"dynamicRegistration": True}, + "documentHighlight": {"dynamicRegistration": True}, + "documentLink": {"dynamicRegistration": True}, + "documentSymbol": { + "dynamicRegistration": True, + "symbolKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + }, + }, + "formatting": {"dynamicRegistration": True}, + "hover": { + "contentFormat": ["markdown", "plaintext"], + "dynamicRegistration": True, + }, + "implementation": {"dynamicRegistration": True}, + "onTypeFormatting": {"dynamicRegistration": True}, + "publishDiagnostics": {"relatedInformation": True}, + "rangeFormatting": {"dynamicRegistration": True}, + "references": {"dynamicRegistration": True}, + "rename": {"dynamicRegistration": True}, + "signatureHelp": { + "dynamicRegistration": True, + "signatureInformation": { + "documentationFormat": ["markdown", "plaintext"] + }, + }, + "synchronization": { + "didSave": True, + "dynamicRegistration": True, + "willSave": True, + "willSaveWaitUntil": True, + }, + "typeDefinition": {"dynamicRegistration": True}, + }, + "workspace": { + "applyEdit": True, + "configuration": True, + "didChangeConfiguration": {"dynamicRegistration": True}, + "didChangeWatchedFiles": {"dynamicRegistration": True}, + "executeCommand": {"dynamicRegistration": True}, + "symbol": { + "dynamicRegistration": True, + "symbolKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + }, + }, + "workspaceEdit": {"documentChanges": True}, + "workspaceFolders": True, + }, + }, + "processId": 1234, + "rootPath": None, + "rootUri": filepath_to_uri(self.workspace_paths[0]), + "initializationOptions": {}, + "trace": "off", + "workspaceFolders": [ + { + "uri": filepath_to_uri(workspacePath), + "name": workspacePath.split("/")[-1], + } + for workspacePath in self.workspace_paths + ], + } + return await self.call_method("initialize", **initialization_args) + + async def goto_definition(self, filepath: str, position: Position): + return await self.call_method( + "textDocument/definition", + textDocument={"uri": filepath_to_uri(filepath)}, + position=position.dict(), + ) + + async def document_symbol(self, filepath: str): + return await self.call_method( + "textDocument/documentSymbol", + textDocument={"uri": filepath_to_uri(filepath)}, + ) + + async def find_references( + self, filepath: str, position: Position, include_declaration: bool = False + ): + return await self.call_method( + "textDocument/references", + textDocument={"uri": filepath_to_uri(filepath)}, + position=position.dict(), + context={"includeDeclaration": include_declaration}, + ) + + async def folding_range(self, filepath: str): + response = await self.call_method( + "textDocument/foldingRange", + textDocument={"uri": filepath_to_uri(filepath)}, + ) + return response["result"] + + +async def start_language_server() -> threading.Thread: + """Manually start the python language server. Not used currently.""" + raise NotImplementedError() + # try: + # kill_proc(PORT) + # thread = threading.Thread( + # target=start_ws_lang_server, + # args=(PORT, False, PythonLSPServer), + # ) + # thread.daemon = True + # thread.start() + + # except Exception as e: + # logger.warning("Could not start TCP server: %s", e) + + # await asyncio.sleep(2) + + # return thread + + +class DocumentSymbol(BaseModel): + name: str + containerName: Optional[str] = None + kind: int + location: RangeInFile + + +class FoldingRange(BaseModel): + range: Range + kind: Optional[Literal["comment", "imports", "region"]] = None + + +class ContinueLSPClient(BaseModel): + workspace_dir: str + + lsp_client: LSPClient = None + lsp_thread: Optional[threading.Thread] = None + + @property + def ready(self): + if self.lsp_client is None: + return False + return self.lsp_client.ready + + class Config: + arbitrary_types_allowed = True + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("lsp_client", None) + return original_dict + + async def start(self): + self.lsp_thread = await start_language_server() + self.lsp_client = LSPClient("localhost", PORT, [self.workspace_dir]) + await self.lsp_client.connect() + await self.lsp_client.initialize() + + async def stop(self): + await self.lsp_client.close() + if self.lsp_thread: + self.lsp_thread.join() + + def location_to_range_in_file(self, location): + return RangeInFile( + filepath=uri_to_filepath(location["uri"]), + range=Range.from_shorthand( + location["range"]["start"]["line"], + location["range"]["start"]["character"], + location["range"]["end"]["line"], + location["range"]["end"]["character"], + ), + ) + + async def goto_definition( + self, position: Position, filename: str + ) -> List[RangeInFile]: + response = self.lsp_client.goto_definition( + filename, + position, + ) + return [self.location_to_range_in_file(x) for x in response] + + async def find_references( + self, position: Position, filename: str, include_declaration: bool = False + ) -> List[RangeInFile]: + response = await self.lsp_client.find_references( + filename, + position, + include_declaration=include_declaration, + ) + return [self.location_to_range_in_file(x) for x in response["result"]] + + async def document_symbol(self, filepath: str) -> List: + response = await self.lsp_client.document_symbol(filepath) + return [ + DocumentSymbol( + name=x["name"], + containerName=x["containerName"], + kind=x["kind"], + location=self.location_to_range_in_file(x["location"]), + ) + for x in response["result"] + ] + + async def folding_range(self, filepath: str) -> List[FoldingRange]: + response = await self.lsp_client.folding_range(filepath) + + return [ + FoldingRange( + range=Range.from_shorthand( + x["startLine"], + x.get("startCharacter", 0), + x["endLine"] if "endCharacter" in x else x["endLine"] + 1, + x.get("endCharacter", 0), + ), + kind=x.get("kind"), + ) + for x in response + ] + + async def get_enclosing_folding_range_of_position( + self, position: Position, filepath: str + ) -> Optional[FoldingRange]: + ranges = await self.folding_range(filepath) + + max_start_position = Position(line=0, character=0) + max_range = None + for r in ranges: + if r.range.contains(position): + if r.range.start > max_start_position: + max_start_position = r.range.start + max_range = r + + return max_range + + async def get_enclosing_folding_range( + self, range_in_file: RangeInFile + ) -> Optional[FoldingRange]: + ranges = await self.folding_range(range_in_file.filepath) + + max_start_position = Position(line=0, character=0) + max_range = None + for r in ranges: + if r.range.contains(range_in_file.range.start) and r.range.contains( + range_in_file.range.end + ): + if r.range.start > max_start_position: + max_start_position = r.range.start + max_range = r + + return max_range diff --git a/server/continuedev/core/main.py b/server/continuedev/core/main.py new file mode 100644 index 00000000..617a5aaa --- /dev/null +++ b/server/continuedev/core/main.py @@ -0,0 +1,437 @@ +import json +from typing import Any, Coroutine, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, validator +from pydantic.schema import schema + +from ..models.main import ContinueBaseModel +from .observation import Observation + +ChatMessageRole = Literal["assistant", "user", "system", "function"] + + +class FunctionCall(ContinueBaseModel): + name: str + arguments: str + + +class ChatMessage(ContinueBaseModel): + role: ChatMessageRole + content: Union[str, None] = None + name: Union[str, None] = None + # A summary for pruning chat context to fit context window. Often the Step name. + summary: str + function_call: Union[FunctionCall, None] = None + + def to_dict(self, with_functions: bool) -> Dict: + d = self.dict() + del d["summary"] + if d["function_call"] is not None: + d["function_call"]["name"] = d["function_call"]["name"].replace(" ", "") + + if d["content"] is None: + d["content"] = "" + for key, value in list(d.items()): + if value is None: + del d[key] + + if not with_functions: + if d["role"] == "function": + d["role"] = "assistant" + if "name" in d: + del d["name"] + if "function_call" in d: + del d["function_call"] + return d + + +def resolve_refs(schema_data): + def traverse(obj): + if isinstance(obj, dict): + if "$ref" in obj: + ref = obj["$ref"] + parts = ref.split("/") + ref_obj = schema_data + for part in parts[1:]: + ref_obj = ref_obj[part] + return traverse(ref_obj) + else: + for key, value in obj.items(): + obj[key] = traverse(value) + elif isinstance(obj, list): + for i in range(len(obj)): + obj[i] = traverse(obj[i]) + return obj + + return traverse(schema_data) + + +unincluded_parameters = [ + "system_message", + "chat_context", + "manage_own_chat_context", + "hide", + "name", + "description", +] + + +def step_to_json_schema(step) -> str: + pydantic_class = step.__class__ + schema_data = schema([pydantic_class]) + resolved_schema = resolve_refs(schema_data) + parameters = resolved_schema["definitions"][pydantic_class.__name__] + for parameter in unincluded_parameters: + if parameter in parameters["properties"]: + del parameters["properties"][parameter] + return { + "name": step.name.replace(" ", ""), + "description": step.description or "", + "parameters": parameters, + } + + +def step_to_fn_call_arguments(step: "Step") -> str: + args = step.dict() + for parameter in unincluded_parameters: + if parameter in args: + del args[parameter] + return json.dumps(args) + + +class HistoryNode(ContinueBaseModel): + """A point in history, a list of which make up History""" + + step: "Step" + observation: Union[Observation, None] + depth: int + deleted: bool = False + active: bool = True + logs: List[str] = [] + context_used: List["ContextItem"] = [] + + def to_chat_messages(self) -> List[ChatMessage]: + if self.step.description is None or self.step.manage_own_chat_context: + return self.step.chat_context + return self.step.chat_context + [ + ChatMessage( + role="assistant", + name=self.step.__class__.__name__, + content=self.step.description or f"Ran function {self.step.name}", + summary=f"Called function {self.step.name}", + ) + ] + + +class History(ContinueBaseModel): + """A history of steps taken and their results""" + + timeline: List[HistoryNode] + current_index: int + + def to_chat_history(self) -> List[ChatMessage]: + msgs = [] + for node in self.timeline: + if not node.step.hide: + msgs += node.to_chat_messages() + return msgs + + def add_node(self, node: HistoryNode) -> int: + """Add node and return the index where it was added""" + self.timeline.insert(self.current_index + 1, node) + self.current_index += 1 + return self.current_index + + def get_current(self) -> Union[HistoryNode, None]: + if self.current_index < 0: + return None + return self.timeline[self.current_index] + + def get_last_at_depth( + self, depth: int, include_current: bool = False + ) -> Union[HistoryNode, None]: + i = self.current_index if include_current else self.current_index - 1 + while i >= 0: + if ( + self.timeline[i].depth == depth + and type(self.timeline[i].step).__name__ != "ManualEditStep" + ): + return self.timeline[i] + i -= 1 + return None + + def get_last_at_same_depth(self) -> Union[HistoryNode, None]: + return self.get_last_at_depth(self.get_current().depth) + + def remove_current_and_substeps(self): + self.timeline.pop(self.current_index) + while self.get_current() is not None and self.get_current().depth > 0: + self.timeline.pop(self.current_index) + + def take_next_step(self) -> Union["Step", None]: + if self.has_future(): + self.current_index += 1 + current_state = self.get_current() + if current_state is None: + return None + return current_state.step + return None + + def get_current_index(self) -> int: + return self.current_index + + def has_future(self) -> bool: + return self.current_index < len(self.timeline) - 1 + + def step_back(self): + self.current_index -= 1 + + def last_observation(self) -> Union[Observation, None]: + state = self.get_last_at_same_depth() + if state is None: + return None + return state.observation + + def pop_step(self, index: int = None) -> Union[HistoryNode, None]: + index = index if index is not None else self.current_index + if index < 0 or self.current_index < 0: + return None + + node = self.timeline.pop(index) + + if index <= self.current_index: + self.current_index -= 1 + + return node.step + + @classmethod + def from_empty(cls): + return cls(timeline=[], current_index=-1) + + +class SlashCommandDescription(ContinueBaseModel): + name: str + description: str + + +class ContextItemId(BaseModel): + """ + A ContextItemId is a unique identifier for a ContextItem. + """ + + provider_title: str + item_id: str + + @validator("provider_title", "item_id") + def must_be_valid_id(cls, v): + import re + + if not re.match(r"^[0-9a-zA-Z_-]*$", v): + raise ValueError( + "Both provider_title and item_id can only include characters 0-9, a-z, A-Z, -, and _" + ) + return v + + def to_string(self) -> str: + return f"{self.provider_title}-{self.item_id}" + + @staticmethod + def from_string(string: str) -> "ContextItemId": + provider_title, *rest = string.split("-") + item_id = "-".join(rest) + return ContextItemId(provider_title=provider_title, item_id=item_id) + + +class ContextItemDescription(BaseModel): + """ + A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'. + + The id can be used to retrieve the ContextItem from the ContextManager. + """ + + name: str + description: str + id: ContextItemId + + +class ContextItem(BaseModel): + """ + A ContextItem is a single item that is stored in the ContextManager. + """ + + description: ContextItemDescription + content: str + + @validator("content", pre=True) + def content_must_be_string(cls, v): + if v is None: + return "" + return v + + editing: bool = False + editable: bool = False + + +class SessionInfo(ContinueBaseModel): + session_id: str + title: str + date_created: str + workspace_directory: Optional[str] = None + + +class ContinueConfig(ContinueBaseModel): + system_message: Optional[str] + temperature: Optional[float] + + class Config: + extra = "allow" + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("policy", None) + return original_dict + + +class ContextProviderDescription(BaseModel): + title: str + display_title: str + description: str + dynamic: bool + requires_query: bool + + +class FullState(ContinueBaseModel): + """A full state of the program, including the history""" + + history: History + active: bool + user_input_queue: List[str] + slash_commands: List[SlashCommandDescription] + adding_highlighted_code: bool + selected_context_items: List[ContextItem] + session_info: Optional[SessionInfo] = None + config: ContinueConfig + saved_context_groups: Dict[str, List[ContextItem]] = {} + context_providers: List[ContextProviderDescription] = [] + meilisearch_url: Optional[str] = None + + +class ContinueSDK: + ... + + +class Models: + ... + + +class Policy(ContinueBaseModel): + """A rule that determines which step to take next""" + + # Note that history is mutable, kinda sus + def next( + self, config: ContinueConfig, history: History = History.from_empty() + ) -> "Step": + raise NotImplementedError + + +class Step(ContinueBaseModel): + name: str = None + hide: bool = False + description: Union[str, None] = None + + class_name: str = "Step" + + @validator("class_name", pre=True, always=True) + def class_name_is_class_name(cls, class_name): + return cls.__name__ + + system_message: Union[str, None] = None + chat_context: List[ChatMessage] = [] + manage_own_chat_context: bool = False + + class Config: + copy_on_model_validation = False + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + if self.description is not None: + return self.description + return "Running step: " + self.name + + def dict(self, *args, **kwargs): + d = super().dict(*args, **kwargs) + # Make sure description is always a string + d["description"] = self.description or "" + return d + + @validator("name", pre=True, always=True) + def name_is_class_name(cls, name): + if name is None: + return cls.__name__ + return name + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + raise NotImplementedError + + async def __call__(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + return await self.run(sdk) + + def __rshift__(self, other: "Step"): + steps = [] + if isinstance(self, SequentialStep): + steps = self.steps + else: + steps.append(self) + if isinstance(other, SequentialStep): + steps += other.steps + else: + steps.append(other) + return SequentialStep(steps=steps) + + +class SequentialStep(Step): + steps: List[Step] + hide: bool = True + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + for step in self.steps: + observation = await sdk.run_step(step) + return observation + + +class ValidatorObservation(Observation): + passed: bool + observation: Observation + + +class Validator(Step): + def run(self, sdk: ContinueSDK) -> ValidatorObservation: + raise NotImplementedError + + +class Context: + key_value: Dict[str, Any] = {} + + def set(self, key: str, value: Any): + self.key_value[key] = value + + def get(self, key: str) -> Any: + return self.key_value.get(key, None) + + +class ContinueCustomException(Exception): + title: str + message: str + with_step: Union[Step, None] + + def __init__( + self, + message: str, + title: str = "Error while running step:", + with_step: Union[Step, None] = None, + ): + self.message = message + self.title = title + self.with_step = with_step + + +HistoryNode.update_forward_refs() diff --git a/server/continuedev/core/models.py b/server/continuedev/core/models.py new file mode 100644 index 00000000..21ebd8f6 --- /dev/null +++ b/server/continuedev/core/models.py @@ -0,0 +1,113 @@ +from typing import List, Optional + +from pydantic import BaseModel + +from ..libs.llm.anthropic import AnthropicLLM +from ..libs.llm.base import LLM +from ..libs.llm.ggml import GGML +from ..libs.llm.google_palm_api import GooglePaLMAPI +from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI +from ..libs.llm.hf_tgi import HuggingFaceTGI +from ..libs.llm.llamacpp import LlamaCpp +from ..libs.llm.ollama import Ollama +from ..libs.llm.openai import OpenAI +from ..libs.llm.openai_free_trial import OpenAIFreeTrial +from ..libs.llm.replicate import ReplicateLLM +from ..libs.llm.together import TogetherLLM + + +class ContinueSDK(BaseModel): + pass + + +ALL_MODEL_ROLES = [ + "default", + "summarize", + "edit", + "chat", +] + +MODEL_CLASSES = { + cls.__name__: cls + for cls in [ + OpenAI, + OpenAIFreeTrial, + GGML, + TogetherLLM, + AnthropicLLM, + ReplicateLLM, + Ollama, + LlamaCpp, + HuggingFaceInferenceAPI, + HuggingFaceTGI, + GooglePaLMAPI, + ] +} + +MODEL_MODULE_NAMES = { + "OpenAI": "openai", + "OpenAIFreeTrial": "openai_free_trial", + "GGML": "ggml", + "TogetherLLM": "together", + "AnthropicLLM": "anthropic", + "ReplicateLLM": "replicate", + "Ollama": "ollama", + "LlamaCpp": "llamacpp", + "HuggingFaceInferenceAPI": "hf_inference_api", + "HuggingFaceTGI": "hf_tgi", + "GooglePaLMAPI": "google_palm_api", +} + + +class Models(BaseModel): + """Main class that holds the current model configuration""" + + default: LLM + summarize: Optional[LLM] = None + edit: Optional[LLM] = None + chat: Optional[LLM] = None + + saved: List[LLM] = [] + + # TODO namespace these away to not confuse readers, + # or split Models into ModelsConfig, which gets turned into Models + sdk: ContinueSDK = None + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("sdk", None) + return original_dict + + @property + def all_models(self): + models = [getattr(self, role) for role in ALL_MODEL_ROLES] + return [model for model in models if model is not None] + + @property + def system_message(self) -> Optional[str]: + if self.sdk: + return self.sdk.config.system_message + return None + + def set_system_message(self, msg: str): + for model in self.all_models: + if model.system_message is None: + model.system_message = msg + + async def start(self, sdk: "ContinueSDK"): + """Start each of the LLMs, or fall back to default""" + self.sdk = sdk + + for role in ALL_MODEL_ROLES: + model = getattr(self, role) + if model is None: + setattr(self, role, self.default) + else: + await sdk.start_model(model) + + self.set_system_message(self.system_message) + + async def stop(self, sdk: "ContinueSDK"): + """Stop each LLM (if it's not the default, which is shared)""" + for model in self.all_models: + await model.stop() diff --git a/server/continuedev/core/observation.py b/server/continuedev/core/observation.py new file mode 100644 index 00000000..8a5e454e --- /dev/null +++ b/server/continuedev/core/observation.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, validator + +from ..models.main import Traceback + + +class Observation(BaseModel): + pass + + +class TracebackObservation(Observation): + traceback: Traceback + + +class ValidatorObservation(Observation): + passed: bool + + +class UserInputObservation(Observation): + user_input: str + + +class DictObservation(Observation): + values: dict + + def __getitem__(self, key): + return self.values[key] + + +class TextObservation(Observation): + text: str + + @validator("text", pre=True, always=True) + def text_not_none(cls, v): + if v is None: + return "" + return v + + +class InternalErrorObservation(Observation): + title: str + error: str diff --git a/server/continuedev/core/sdk.py b/server/continuedev/core/sdk.py new file mode 100644 index 00000000..408168f6 --- /dev/null +++ b/server/continuedev/core/sdk.py @@ -0,0 +1,309 @@ +import os +import traceback +from typing import Coroutine, List, Optional, Union + +from ..libs.llm.base import LLM +from ..libs.util.devdata import dev_data_logger +from ..libs.util.logging import logger +from ..libs.util.paths import ( + convertConfigImports, + getConfigFilePath, + getDiffsFolderPath, +) +from ..libs.util.telemetry import posthog_logger +from ..models.filesystem import RangeInFile +from ..models.filesystem_edit import ( + AddDirectory, + AddFile, + DeleteDirectory, + DeleteFile, + FileEdit, + FileSystemEdit, +) +from ..models.main import Range +from ..server.ide_protocol import AbstractIdeProtocolServer +from .abstract_sdk import AbstractContinueSDK +from .config import ContinueConfig +from .lsp import ContinueLSPClient +from .main import ( + ChatMessage, + Context, + ContinueCustomException, + History, + HistoryNode, + Step, +) +from .models import Models +from .observation import Observation +from .steps import ( + DefaultModelEditCodeStep, + FileSystemEditStep, + MessageStep, + RangeInFileWithContents, + ShellCommandsStep, + WaitForUserConfirmationStep, +) + + +class Autopilot: + pass + + +class ContinueSDK(AbstractContinueSDK): + """The SDK provided as parameters to a step""" + + ide: AbstractIdeProtocolServer + models: Models + lsp: Optional[ContinueLSPClient] = None + context: Context + config: ContinueConfig + __autopilot: Autopilot + + def __init__(self, autopilot: Autopilot): + self.ide = autopilot.ide + self.__autopilot = autopilot + self.context = autopilot.context + + async def load(self, config: Optional[ContinueConfig] = None): + # Create necessary directories + getDiffsFolderPath() + + try: + self.config = config or self._load_config_dot_py() + except Exception as e: + logger.error(f"Failed to load config.py: {traceback.format_exception(e)}") + + self.config = ( + ContinueConfig() + if self._last_valid_config is None + else self._last_valid_config + ) + + formatted_err = "\n".join(traceback.format_exception(e)) + msg_step = MessageStep( + name="Invalid Continue Config File", message=formatted_err + ) + msg_step.description = f"Falling back to default config settings due to the following error in `~/.continue/config.py`.\n```\n{formatted_err}\n```\n\nIt's possible this was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py).\n\nIf the error is related to OpenAIServerInfo, see the updated way of using these parameters [here](https://continue.dev/docs/customization#azure-openai-service)." + self.history.add_node( + HistoryNode(step=msg_step, observation=None, depth=0, active=False) + ) + await self.ide.setFileOpen(getConfigFilePath()) + + # Start models + self.models = self.config.models + await self.models.start(self) + + # Start LSP + # async def start_lsp(): + # try: + # sdk.lsp = ContinueLSPClient( + # workspace_dir=sdk.ide.workspace_directory, + # ) + # await sdk.lsp.start() + # except Exception as e: + # logger.warning(f"Failed to start LSP client: {e}", exc_info=False) + # sdk.lsp = None + + # create_async_task( + # start_lsp(), on_error=lambda e: logger.error("Failed to setup LSP: %s", e) + # ) + + # When the config is loaded, setup posthog logger + posthog_logger.setup( + self.ide.unique_id, self.config.allow_anonymous_telemetry, self.ide.ide_info + ) + dev_data_logger.setup(self.config.user_token, self.config.data_server_url) + + @classmethod + async def create( + cls, autopilot: Autopilot, config: Optional[ContinueConfig] = None + ) -> "ContinueSDK": + sdk = ContinueSDK(autopilot) + autopilot.continue_sdk = sdk + + await sdk.load(config=config) + + return sdk + + @property + def history(self) -> History: + return self.__autopilot.history + + def write_log(self, message: str): + self.history.timeline[self.history.current_index].logs.append(message) + + async def start_model(self, llm: LLM): + await llm.start(unique_id=self.ide.unique_id, write_log=self.write_log) + + async def _ensure_absolute_path(self, path: str) -> str: + if os.path.isabs(path): + return path + + # Else if in workspace + workspace_path = os.path.join(self.ide.workspace_directory, path) + if os.path.exists(workspace_path): + return workspace_path + else: + # Check if it matches any of the open files, then use that absolute path + open_files = await self.ide.getOpenFiles() + for open_file in open_files: + if os.path.basename(open_file) == os.path.basename(path): + return open_file + raise Exception(f"Path {path} does not exist") + + async def run_step(self, step: Step) -> Coroutine[Observation, None, None]: + return await self.__autopilot._run_singular_step(step) + + async def apply_filesystem_edit( + self, edit: FileSystemEdit, name: str = None, description: str = None + ): + return await self.run_step( + FileSystemEditStep( + edit=edit, description=description, **({"name": name} if name else {}) + ) + ) + + async def wait_for_user_input(self) -> str: + return await self.__autopilot.wait_for_user_input() + + async def wait_for_user_confirmation(self, prompt: str): + return await self.run_step(WaitForUserConfirmationStep(prompt=prompt)) + + async def run( + self, + commands: Union[List[str], str], + cwd: str = None, + name: str = None, + description: str = None, + handle_error: bool = True, + ) -> Coroutine[str, None, None]: + commands = commands if isinstance(commands, List) else [commands] + return ( + await self.run_step( + ShellCommandsStep( + cmds=commands, + cwd=cwd, + description=description, + handle_error=handle_error, + **({"name": name} if name else {}), + ) + ) + ).text + + async def edit_file( + self, + filename: str, + prompt: str, + name: str = None, + description: str = "", + range: Range = None, + ): + filepath = await self._ensure_absolute_path(filename) + + await self.ide.setFileOpen(filepath) + contents = await self.ide.readFile(filepath) + await self.run_step( + DefaultModelEditCodeStep( + range_in_files=[ + RangeInFile(filepath=filepath, range=range) + if range is not None + else RangeInFile.from_entire_file(filepath, contents) + ], + user_input=prompt, + description=description, + **({"name": name} if name else {}), + ) + ) + + async def append_to_file(self, filename: str, content: str): + filepath = await self._ensure_absolute_path(filename) + previous_content = await self.ide.readFile(filepath) + file_edit = FileEdit.from_append(filepath, previous_content, content) + await self.ide.applyFileSystemEdit(file_edit) + + async def add_file(self, filename: str, content: Union[str, None]): + filepath = await self._ensure_absolute_path(filename) + dir_name = os.path.dirname(filepath) + os.makedirs(dir_name, exist_ok=True) + return await self.run_step( + FileSystemEditStep(edit=AddFile(filepath=filepath, content=content)) + ) + + async def delete_file(self, filename: str): + filename = await self._ensure_absolute_path(filename) + return await self.run_step( + FileSystemEditStep(edit=DeleteFile(filepath=filename)) + ) + + async def add_directory(self, path: str): + path = await self._ensure_absolute_path(path) + return await self.run_step(FileSystemEditStep(edit=AddDirectory(path=path))) + + async def delete_directory(self, path: str): + path = await self._ensure_absolute_path(path) + return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) + + _last_valid_config: ContinueConfig = None + + def _load_config_dot_py(self, retry: bool = True) -> ContinueConfig: + try: + path = getConfigFilePath() + config = ContinueConfig.from_filepath(path) + self._last_valid_config = config + + logger.debug("Loaded Continue config file from %s", path) + + return config + except ModuleNotFoundError as e: + if not retry: + raise e + # Check if the module was "continuedev.src" + if e.name == "continuedev.src": + convertConfigImports(shorten=True) + return self._load_config_dot_py(retry=False) + else: + raise e + + def get_code_context( + self, only_editing: bool = False + ) -> List[RangeInFileWithContents]: + highlighted_ranges = self.__autopilot.context_manager.context_providers[ + "code" + ].highlighted_ranges + context = ( + list(filter(lambda x: x.item.editing, highlighted_ranges)) + if only_editing + else highlighted_ranges + ) + return [c.rif for c in context] + + def set_loading_message(self, message: str): + # self.__autopilot.set_loading_message(message) + raise NotImplementedError() + + def raise_exception( + self, message: str, title: str, with_step: Union[Step, None] = None + ): + raise ContinueCustomException(message, title, with_step) + + async def get_chat_context(self) -> List[ChatMessage]: + history_context = self.history.to_chat_history() + + context_messages: List[ + ChatMessage + ] = await self.__autopilot.context_manager.get_chat_messages() + + # Insert at the end, but don't insert after latest user message or function call + for msg in context_messages: + history_context.insert(-1, msg) + + return history_context + + async def update_ui(self): + await self.__autopilot.update_subscribers() + + async def clear_history(self): + await self.__autopilot.clear_history() + + def current_step_was_deleted(self): + return self.history.timeline[self.history.current_index].deleted diff --git a/server/continuedev/core/steps.py b/server/continuedev/core/steps.py new file mode 100644 index 00000000..5c20dd15 --- /dev/null +++ b/server/continuedev/core/steps.py @@ -0,0 +1,963 @@ +# These steps are depended upon by ContinueSDK +import difflib +import subprocess +from textwrap import dedent +from typing import Coroutine, List, Optional, Union + +from ..libs.llm.base import LLM +from ..libs.llm.openai_free_trial import OpenAIFreeTrial +from ..libs.util.count_tokens import DEFAULT_MAX_TOKENS +from ..libs.util.devdata import dev_data_logger +from ..libs.util.strings import ( + dedent_and_get_common_whitespace, + remove_quotes_and_escapes, +) +from ..libs.util.telemetry import posthog_logger +from ..libs.util.templating import render_prompt_template +from ..models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents +from ..models.filesystem_edit import ( + EditDiff, + FileEdit, + FileEditWithFullContents, + FileSystemEdit, +) + +# from ....libs.llm.replicate import ReplicateLLM +from ..models.main import Range +from .main import ChatMessage, ContinueCustomException, Step +from .observation import Observation, TextObservation, UserInputObservation + + +class ContinueSDK: + pass + + +class Models: + pass + + +class ReversibleStep(Step): + async def reverse(self, sdk: ContinueSDK): + raise NotImplementedError + + +class MessageStep(Step): + name: str = "Message" + message: str + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return self.message + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + return TextObservation(text=self.message) + + +class DisplayErrorStep(Step): + name: str = "Error in the Continue server" + + title: str = "Error in the Continue server" + message: str = "There was an error in the Continue server." + + @staticmethod + def from_exception(e: Exception) -> "DisplayErrorStep": + if isinstance(e, ContinueCustomException): + return DisplayErrorStep(title=e.title, message=e.message, name=e.title) + + return DisplayErrorStep(message=str(e)) + + class Config: + arbitrary_types_allowed = True + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return self.message + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + raise ContinueCustomException(message=self.message, title=self.title) + + +class FileSystemEditStep(ReversibleStep): + edit: FileSystemEdit + _diff: Union[EditDiff, None] = None + + hide: bool = True + + async def run(self, sdk: "ContinueSDK") -> Coroutine[Observation, None, None]: + self._diff = await sdk.ide.applyFileSystemEdit(self.edit) + return None + + async def reverse(self, sdk: "ContinueSDK"): + await sdk.ide.applyFileSystemEdit(self._diff.backward) + # Where and when should file saves happen? + + +def output_contains_error(output: str) -> bool: + return "Traceback" in output or "SyntaxError" in output + + +AI_ASSISTED_STRING = "(โœจ AI-Assisted โœจ)" + + +class ShellCommandsStep(Step): + cmds: List[str] + cwd: Union[str, None] = None + name: str = "Run Shell Commands" + handle_error: bool = True + + _err_text: Union[str, None] = None + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + if self._err_text is not None: + return f"Error when running shell commands:\n```\n{self._err_text}\n```" + + cmds_str = "\n".join(self.cmds) + return await models.summarize.complete( + f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:" + ) + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + process = subprocess.Popen( + "/bin/bash", + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + cwd=self.cwd or sdk.ide.workspace_directory, + ) + + stdin_input = "\n".join(self.cmds) + out, err = process.communicate(stdin_input.encode()) + + # If it fails, return the error + if err is not None and err != "": + self._err_text = err + return TextObservation(text=err) + + return None + + +class DefaultModelEditCodeStep(Step): + user_input: str + model: Optional[LLM] = None + range_in_files: List[RangeInFile] + name: str = "Editing Code" + hide = False + description: str = "" + _prompt: str = dedent( + """\ + Take the file prefix and suffix into account, but only rewrite the code_to_edit as specified in the user_request. The code you write in modified_code_to_edit will replace the code between the code_to_edit tags. Do NOT preface your answer or write anything other than code. The tag should be written to indicate the end of the modified code section. Do not ever use nested tags. + + Example: + + + class Database: + def __init__(self): + self._data = {{}} + + def get(self, key): + return self._data[key] + + + + def set(self, key, value): + self._data[key] = value + + + + def clear_all(): + self._data = {{}} + + + Raise an error if the key already exists. + + + def set(self, key, value): + if key in self._data: + raise KeyError(f"Key {{key}} already exists") + self._data[key] = value + + + Main task: + """ + ) + _previous_contents: str = "" + _new_contents: str = "" + _prompt_and_completion: str = "" + + summary_prompt: str = "Please briefly explain the changes made to the code above. Give no more than 2-3 sentences, and use markdown bullet points:" + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + name = await models.summarize.complete( + f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:" + ) + self.name = remove_quotes_and_escapes(name) + + if self._previous_contents.strip() == self._new_contents.strip(): + return "No edits were made" + else: + return None + + async def get_prompt_parts( + self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str + ): + # We don't know here all of the functions being passed in. + # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. + # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. + if self.model is not None: + await sdk.start_model(self.model) + + model_to_use = self.model or sdk.models.edit + max_tokens = int(model_to_use.context_length / 2) + + TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 + if ( + model_to_use.count_tokens(rif.contents) + > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE + ): + self.description += "\n\n**It looks like you've selected a large range to edit, which may take a while to complete. If you'd like to cancel, click the 'X' button above. If you highlight a more specific range, Continue will only edit within it.**" + + # At this point, we also increase the max_tokens parameter so it doesn't stop in the middle of generation + # Increase max_tokens to be double the size of the range + # But don't exceed twice default max tokens + max_tokens = int( + min(model_to_use.count_tokens(rif.contents), DEFAULT_MAX_TOKENS) * 2.5 + ) + + BUFFER_FOR_FUNCTIONS = 400 + total_tokens = ( + model_to_use.count_tokens( + full_file_contents + self._prompt + self.user_input + ) + + BUFFER_FOR_FUNCTIONS + + max_tokens + ) + + # If using 3.5 and overflows, upgrade to 3.5.16k + if model_to_use.model == "gpt-3.5-turbo": + if total_tokens > model_to_use.context_length: + model_to_use = OpenAIFreeTrial(model="gpt-3.5-turbo-0613") + await sdk.start_model(model_to_use) + + # Remove tokens from the end first, and then the start to clear space + # This part finds the start and end lines + full_file_contents_lst = full_file_contents.split("\n") + max_start_line = rif.range.start.line + min_end_line = rif.range.end.line + cur_start_line = 0 + cur_end_line = len(full_file_contents_lst) - 1 + + if total_tokens > model_to_use.context_length: + while cur_end_line > min_end_line: + total_tokens -= model_to_use.count_tokens( + full_file_contents_lst[cur_end_line] + ) + cur_end_line -= 1 + if total_tokens < model_to_use.context_length: + break + + if total_tokens > model_to_use.context_length: + while cur_start_line < max_start_line: + cur_start_line += 1 + total_tokens -= model_to_use.count_tokens( + full_file_contents_lst[cur_start_line] + ) + if total_tokens < model_to_use.context_length: + break + + # Now use the found start/end lines to get the prefix and suffix strings + file_prefix = "\n".join(full_file_contents_lst[cur_start_line:max_start_line]) + file_suffix = "\n".join(full_file_contents_lst[min_end_line : cur_end_line - 1]) + + # Move any surrounding blank line in rif.contents to the prefix/suffix + # TODO: Keep track of start line of the range, because it's needed below for offset stuff + if len(rif.contents) > 0: + lines = rif.contents.splitlines(keepends=True) + first_line = lines[0] if lines else None + while first_line and first_line.strip() == "": + file_prefix += first_line + rif.contents = rif.contents[len(first_line) :] + lines = rif.contents.splitlines(keepends=True) + first_line = lines[0] if lines else None + + last_line = lines[-1] if lines else None + while last_line and last_line.strip() == "": + file_suffix = last_line + file_suffix + rif.contents = rif.contents[: len(rif.contents) - len(last_line)] + lines = rif.contents.splitlines(keepends=True) + last_line = lines[-1] if lines else None + + while rif.contents.startswith("\n"): + file_prefix += "\n" + rif.contents = rif.contents[1:] + while rif.contents.endswith("\n"): + file_suffix = "\n" + file_suffix + rif.contents = rif.contents[:-1] + + return file_prefix, rif.contents, file_suffix, model_to_use, max_tokens + + def compile_prompt( + self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK + ) -> str: + if contents.strip() == "": + # Separate prompt for insertion at the cursor, the other tends to cause it to repeat whole file + prompt = dedent( + f"""\ + +{file_prefix} + + + +{file_suffix} + + +{self.user_input} + + +Please output the code to be inserted at the cursor in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code. Make sure to correctly indent the code:""" + ) + return prompt + + prompt = self._prompt + if file_prefix.strip() != "": + prompt += dedent( + f""" + +{file_prefix} +""" + ) + prompt += dedent( + f""" + +{contents} +""" + ) + if file_suffix.strip() != "": + prompt += dedent( + f""" + +{file_suffix} +""" + ) + prompt += dedent( + f""" + +{self.user_input} + + +""" + ) + + return prompt + + def is_end_line(self, line: str) -> bool: + return "" in line or "" in line + + def line_to_be_ignored(self, line: str, is_first_line: bool = False) -> bool: + return ( + "```" in line + or "" in line + or "" in line + or "" in line + or "" in line + or "" in line + or "" in line + or "" in line + or "" in line + ) + + async def stream_rif(self, rif: RangeInFileWithContents, sdk: ContinueSDK): + await sdk.ide.saveFile(rif.filepath) + full_file_contents = await sdk.ide.readFile(rif.filepath) + + ( + file_prefix, + contents, + file_suffix, + model_to_use, + max_tokens, + ) = await self.get_prompt_parts(rif, sdk, full_file_contents) + contents, common_whitespace = dedent_and_get_common_whitespace(contents) + prompt = self.compile_prompt(file_prefix, contents, file_suffix, sdk) + full_file_contents_lines = full_file_contents.split("\n") + + lines_to_display = [] + + async def sendDiffUpdate( + lines: List[str], sdk: ContinueSDK, final: bool = False + ): + nonlocal full_file_contents_lines, rif, lines_to_display + + completion = "\n".join(lines) + + full_prefix_lines = full_file_contents_lines[: rif.range.start.line] + full_suffix_lines = full_file_contents_lines[rif.range.end.line :] + + # Don't do this at the very end, just show the inserted code + if final: + lines_to_display = [] + # Only recalculate at every new-line, because this is sort of expensive + elif completion.endswith("\n"): + contents_lines = rif.contents.split("\n") + rewritten_lines = 0 + for line in lines: + for i in range(rewritten_lines, len(contents_lines)): + if ( + difflib.SequenceMatcher( + None, line, contents_lines[i] + ).ratio() + > 0.7 + and contents_lines[i].strip() != "" + ): + rewritten_lines = i + 1 + break + lines_to_display = contents_lines[rewritten_lines:] + + new_file_contents = ( + "\n".join(full_prefix_lines) + + "\n" + + completion + + "\n" + + ( + "\n".join(lines_to_display) + "\n" + if len(lines_to_display) > 0 + else "" + ) + + "\n".join(full_suffix_lines) + ) + + step_index = sdk.history.current_index + + await sdk.ide.showDiff(rif.filepath, new_file_contents, step_index) + + # Important state variables + # ------------------------- + original_lines = [] if rif.contents == "" else rif.contents.split("\n") + # In the actual file, taking into account block offset + current_line_in_file = rif.range.start.line + current_block_lines = [] + original_lines_below_previous_blocks = original_lines + # The start of the current block in file, taking into account block offset + current_block_start = -1 + offset_from_blocks = 0 + + # Don't end the block until you've matched N simultaneous lines + # This helps avoid many tiny blocks + LINES_TO_MATCH_BEFORE_ENDING_BLOCK = 2 + # If a line has been matched at the end of the block, this is its index within original_lines_below_previous_blocks + # Except we are keeping track of multiple potentialities, so it's a list + # We always check the lines following each of these leads, but if multiple make it out at the end, we use the first one + # This is a tuple of (index_of_last_matched_line, number_of_lines_matched) + indices_of_last_matched_lines = [] + + async def handle_generated_line(line: str): + nonlocal current_block_start, current_line_in_file, original_lines, original_lines_below_previous_blocks, current_block_lines, indices_of_last_matched_lines, LINES_TO_MATCH_BEFORE_ENDING_BLOCK, offset_from_blocks + + # Highlight the line to show progress + line_to_highlight = current_line_in_file - len(current_block_lines) + if False: + await sdk.ide.highlightCode( + RangeInFile( + filepath=rif.filepath, + range=Range.from_shorthand( + line_to_highlight, 0, line_to_highlight, 0 + ), + ), + "#FFFFFF22" if len(current_block_lines) == 0 else "#00FF0022", + ) + + if len(current_block_lines) == 0: + # Set this as the start of the next block + current_block_start = ( + rif.range.start.line + + len(original_lines) + - len(original_lines_below_previous_blocks) + + offset_from_blocks + ) + if ( + len(original_lines_below_previous_blocks) > 0 + and line == original_lines_below_previous_blocks[0] + ): + # Line is equal to the next line in file, move past this line + original_lines_below_previous_blocks = ( + original_lines_below_previous_blocks[1:] + ) + return + + # In a block, and have already matched at least one line + # Check if the next line matches, for each of the candidates + matches_found = [] + first_valid_match = None + for ( + index_of_last_matched_line, + num_lines_matched, + ) in indices_of_last_matched_lines: + if ( + index_of_last_matched_line + 1 + < len(original_lines_below_previous_blocks) + and line + == original_lines_below_previous_blocks[ + index_of_last_matched_line + 1 + ] + ): + matches_found.append( + (index_of_last_matched_line + 1, num_lines_matched + 1) + ) + if ( + first_valid_match is None + and num_lines_matched + 1 >= LINES_TO_MATCH_BEFORE_ENDING_BLOCK + ): + first_valid_match = ( + index_of_last_matched_line + 1, + num_lines_matched + 1, + ) + indices_of_last_matched_lines = matches_found + + if first_valid_match is not None: + # We've matched the required number of lines, insert suggestion! + + # We added some lines to the block that were matched (including maybe some blank lines) + # So here we will strip all matching lines from the end of current_block_lines + lines_stripped = [] + index_of_last_line_in_block = first_valid_match[0] + while ( + len(current_block_lines) > 0 + and current_block_lines[-1] + == original_lines_below_previous_blocks[ + index_of_last_line_in_block - 1 + ] + ): + lines_stripped.append(current_block_lines.pop()) + index_of_last_line_in_block -= 1 + + # It's also possible that some lines match at the beginning of the block + # lines_stripped_at_beginning = [] + # j = 0 + # while len(current_block_lines) > 0 and current_block_lines[0] == original_lines_below_previous_blocks[first_valid_match[0] - first_valid_match[1] + j]: + # lines_stripped_at_beginning.append( + # current_block_lines.pop(0)) + # j += 1 + # # current_block_start += 1 + + # Insert the suggestion + replacement = "\n".join(current_block_lines) + start_line = current_block_start + end_line = current_block_start + index_of_last_line_in_block + + if False: + await sdk.ide.showSuggestion( + FileEdit( + filepath=rif.filepath, + range=Range.from_shorthand(start_line, 0, end_line, 0), + replacement=replacement, + ) + ) + + # Reset current block / update variables + current_line_in_file += 1 + offset_from_blocks += len(current_block_lines) + original_lines_below_previous_blocks = ( + original_lines_below_previous_blocks[ + index_of_last_line_in_block + 1 : + ] + ) + current_block_lines = [] + current_block_start = -1 + indices_of_last_matched_lines = [] + + return + + # Always look for new matching candidates + new_matches = [] + for i in range(len(original_lines_below_previous_blocks)): + og_line = original_lines_below_previous_blocks[i] + # TODO: It's a bit sus to be disqualifying empty lines. + # What you ideally do is find ALL matches, and then throw them out as you check the following lines + if og_line == line: # and og_line.strip() != "": + new_matches.append((i, 1)) + indices_of_last_matched_lines += new_matches + + # Make sure they are sorted by index + indices_of_last_matched_lines = sorted( + indices_of_last_matched_lines, key=lambda x: x[0] + ) + + current_block_lines.append(line) + + messages = await sdk.get_chat_context() + # Delete the last user and assistant messages + i = len(messages) - 1 + deleted = 0 + while i >= 0 and deleted < 2: + if messages[i].role == "user" or messages[i].role == "assistant": + messages.pop(i) + deleted += 1 + i -= 1 + messages.append( + ChatMessage(role="user", content=prompt, summary=self.user_input) + ) + + lines_of_prefix_copied = 0 + lines = [] + unfinished_line = "" + completion_lines_covered = 0 + repeating_file_suffix = False + line_below_highlighted_range = file_suffix.lstrip().split("\n")[0] + + # Use custom templates defined by the model + if template := model_to_use.prompt_templates.get("edit"): + rendered = render_prompt_template( + template, + messages[:-1], + { + "code_to_edit": rif.contents, + "user_input": self.user_input, + "file_prefix": file_prefix, + "file_suffix": file_suffix, + }, + ) + if isinstance(rendered, str): + messages = [ + ChatMessage( + role="user", + content=rendered, + summary=self.user_input, + ) + ] + else: + messages = rendered + + generator = model_to_use.stream_complete( + rendered, + temperature=sdk.config.temperature, + max_tokens=min(max_tokens, model_to_use.context_length // 2), + ) + + else: + + async def gen(): + async for chunk in model_to_use.stream_chat( + messages, + temperature=sdk.config.temperature, + max_tokens=min(max_tokens, model_to_use.context_length // 2), + ): + if "content" in chunk: + yield chunk["content"] + + generator = gen() + + posthog_logger.capture_event( + "model_use", + {"model": model_to_use.model, "provider": model_to_use.__class__.__name__}, + ) + dev_data_logger.capture( + "model_use", + {"model": model_to_use.model, "provider": model_to_use.__class__.__name__}, + ) + + try: + async for chunk in generator: + # Stop early if it is repeating the file_suffix or the step was deleted + if repeating_file_suffix: + break + if sdk.current_step_was_deleted(): + return + + # Accumulate lines + chunk_lines = chunk.split("\n") + chunk_lines[0] = unfinished_line + chunk_lines[0] + if chunk.endswith("\n"): + unfinished_line = "" + chunk_lines.pop() # because this will be an empty string + else: + unfinished_line = chunk_lines.pop() + + # Deal with newly accumulated lines + for i in range(len(chunk_lines)): + # Trailing whitespace doesn't matter + chunk_lines[i] = chunk_lines[i].rstrip() + chunk_lines[i] = common_whitespace + chunk_lines[i] + + # Lines that should signify the end of generation + if self.is_end_line(chunk_lines[i]): + break + # Lines that should be ignored, like the <> tags + elif self.line_to_be_ignored( + chunk_lines[i], completion_lines_covered == 0 + ): + continue # noice + # Check if we are currently just copying the prefix + elif ( + (lines_of_prefix_copied > 0 or completion_lines_covered == 0) + and lines_of_prefix_copied < len(file_prefix.splitlines()) + and chunk_lines[i] + == full_file_contents_lines[lines_of_prefix_copied] + ): + # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line + lines_of_prefix_copied += 1 + continue # also nice + # Because really short lines might be expected to be repeated, this is only a !heuristic! + # Stop when it starts copying the file_suffix + elif ( + chunk_lines[i].strip() == line_below_highlighted_range.strip() + and len(chunk_lines[i].strip()) > 4 + and not ( + len(original_lines_below_previous_blocks) > 0 + and chunk_lines[i].strip() + == original_lines_below_previous_blocks[0].strip() + ) + ): + repeating_file_suffix = True + break + + # If none of the above, insert the line! + if False: + await handle_generated_line(chunk_lines[i]) + + lines.append(chunk_lines[i]) + completion_lines_covered += 1 + current_line_in_file += 1 + + await sendDiffUpdate( + lines + + [ + common_whitespace + if unfinished_line.startswith("<") + else (common_whitespace + unfinished_line) + ], + sdk, + ) + finally: + await generator.aclose() + # Add the unfinished line + if ( + unfinished_line != "" + and not self.line_to_be_ignored( + unfinished_line, completion_lines_covered == 0 + ) + and not self.is_end_line(unfinished_line) + ): + unfinished_line = common_whitespace + unfinished_line + lines.append(unfinished_line) + await handle_generated_line(unfinished_line) + completion_lines_covered += 1 + current_line_in_file += 1 + + await sendDiffUpdate(lines, sdk, final=True) + + if False: + # If the current block isn't empty, add that suggestion + if len(current_block_lines) > 0: + # We have a chance to back-track here for blank lines that are repeats of the end of the original + # Don't want to have the same ending in both the original and the generated, can just leave it there + num_to_remove = 0 + for i in range(-1, -len(current_block_lines) - 1, -1): + if len(original_lines_below_previous_blocks) == 0: + break + if ( + current_block_lines[i] + == original_lines_below_previous_blocks[-1] + ): + num_to_remove += 1 + original_lines_below_previous_blocks.pop() + else: + break + current_block_lines = ( + current_block_lines[:-num_to_remove] + if num_to_remove > 0 + else current_block_lines + ) + + # It's also possible that some lines match at the beginning of the block + # while len(current_block_lines) > 0 and len(original_lines_below_previous_blocks) > 0 and current_block_lines[0] == original_lines_below_previous_blocks[0]: + # current_block_lines.pop(0) + # original_lines_below_previous_blocks.pop(0) + # current_block_start += 1 + + await sdk.ide.showSuggestion( + FileEdit( + filepath=rif.filepath, + range=Range.from_shorthand( + current_block_start, + 0, + current_block_start + + len(original_lines_below_previous_blocks), + 0, + ), + replacement="\n".join(current_block_lines), + ) + ) + + # Record the completion + completion = "\n".join(lines) + self._previous_contents = "\n".join(original_lines) + self._new_contents = completion + self._prompt_and_completion += prompt + completion + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + await sdk.update_ui() + + rif_with_contents = [] + for range_in_file in map( + lambda x: RangeInFile( + filepath=x.filepath, + # Only consider the range line-by-line. Maybe later don't if it's only a single line. + range=x.range.to_full_lines(), + ), + self.range_in_files, + ): + file_contents = await sdk.ide.readRangeInFile(range_in_file) + rif_with_contents.append( + RangeInFileWithContents.from_range_in_file(range_in_file, file_contents) + ) + + rif_dict = {} + for rif in rif_with_contents: + rif_dict[rif.filepath] = rif.contents + + for rif in rif_with_contents: + await sdk.ide.setSuggestionsLocked(rif.filepath, True) + await self.stream_rif(rif, sdk) + await sdk.ide.setSuggestionsLocked(rif.filepath, False) + + changes = "\n".join( + difflib.ndiff( + self._previous_contents.splitlines(), + self._new_contents.splitlines(), + ) + ) + + if sdk.config.disable_summaries: + self.name = "" + self.description = f"Edited {len(self.range_in_files)} files" + await sdk.update_ui() + else: + self.name = "Generating summary" + self.description = "" + async for chunk in sdk.models.summarize.stream_complete( + dedent( + f"""\ + Diff summary: "{self.user_input}" + + ```diff + {changes} + ``` + + {self.summary_prompt}""" + ) + ): + self.description += chunk + await sdk.update_ui() + + sdk.context.set("last_edit_user_input", self.user_input) + sdk.context.set("last_edit_diff", changes) + sdk.context.set("last_edit_range", self.range_in_files[-1].range) + + +class EditFileStep(Step): + filepath: str + prompt: str + hide: bool = True + model: Optional[LLM] = None + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return "Editing file: " + self.filepath + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + file_contents = await sdk.ide.readFile(self.filepath) + await sdk.run_step( + DefaultModelEditCodeStep( + range_in_files=[ + RangeInFile.from_entire_file(self.filepath, file_contents) + ], + user_input=self.prompt, + model=self.model, + ) + ) + + +class ManualEditStep(ReversibleStep): + edit_diff: EditDiff + hide: bool = True + + hide: bool = True + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return "Manual edit step" + # TODO - only handling FileEdit here, but need all other types of FileSystemEdits + # Also requires the merge_file_edit function + # return llm.complete(dedent(f"""This code was replaced: + + # {self.edit_diff.backward.replacement} + + # With this code: + + # {self.edit_diff.forward.replacement} + + # Maximally concise summary of changes in bullet points (can use markdown): + # """)) + + @classmethod + def from_sequence(cls, edits: List[FileEditWithFullContents]) -> "ManualEditStep": + diffs = [] + for edit in edits: + _, diff = FileSystem.apply_edit_to_str(edit.fileContents, edit.fileEdit) + diffs.append(diff) + return cls(edit_diff=EditDiff.from_sequence(diffs)) + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + return None + + async def reverse(self, sdk: ContinueSDK): + await sdk.ide.applyFileSystemEdit(self.edit_diff.backward) + + +class UserInputStep(Step): + user_input: str + name: str = "User Input" + hide: bool = False + + manage_own_chat_context: bool = True + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + if self.description is not None: + return self.description + return self.user_input + + async def run( + self, sdk: ContinueSDK + ) -> Coroutine[UserInputObservation, None, None]: + self.chat_context.append( + ChatMessage(role="user", content=self.user_input, summary=self.user_input) + ) + self.description = self.user_input + return UserInputObservation(user_input=self.user_input) + + +class WaitForUserInputStep(Step): + prompt: str + name: str = "Waiting for user input" + + _description: Union[str, None] = None + _response: Union[str, None] = None + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + if self._response is None: + return self.prompt + else: + return f"{self.prompt}\n\n`{self._response}`" + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + self.description = self.prompt + resp = await sdk.wait_for_user_input() + self.description = f"{self.prompt}\n\n`{resp}`" + return TextObservation(text=resp) + + +class WaitForUserConfirmationStep(Step): + prompt: str + name: str = "Waiting for user confirmation" + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return self.prompt + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + self.description = self.prompt + resp = await sdk.wait_for_user_input() + return TextObservation(text=resp) diff --git a/server/continuedev/headless/__init__.py b/server/continuedev/headless/__init__.py new file mode 100644 index 00000000..2ecdcce6 --- /dev/null +++ b/server/continuedev/headless/__init__.py @@ -0,0 +1,20 @@ +from typing import Optional, Union + +import typer + +from ..core.config import ContinueConfig +from ..server.session_manager import Session, session_manager +from .headless_ide import LocalIdeProtocol + +app = typer.Typer() + + +async def start_headless_session( + config: Optional[Union[str, ContinueConfig]] = None +) -> Session: + if config is not None: + if isinstance(config, str): + config: ContinueConfig = ContinueConfig.from_filepath(config) + + ide = LocalIdeProtocol() + return await session_manager.new_session(ide, config=config) diff --git a/server/continuedev/headless/headless_ide.py b/server/continuedev/headless/headless_ide.py new file mode 100644 index 00000000..088da2c9 --- /dev/null +++ b/server/continuedev/headless/headless_ide.py @@ -0,0 +1,181 @@ +import os +import subprocess +import uuid +from typing import Any, Callable, Coroutine, List, Optional + +from dotenv import load_dotenv +from fastapi import WebSocket + +from ..models.filesystem import ( + FileSystem, + RangeInFile, + RangeInFileWithContents, + RealFileSystem, +) +from ..models.filesystem_edit import EditDiff, FileEdit, FileSystemEdit +from ..server.ide_protocol import AbstractIdeProtocolServer + +load_dotenv() + + +def get_mac_address(): + mac_num = hex(uuid.getnode()).replace("0x", "").upper() + mac = "-".join(mac_num[i : i + 2] for i in range(0, 11, 2)) + return mac + + +class LocalIdeProtocol(AbstractIdeProtocolServer): + websocket: WebSocket = None + session_id: Optional[str] + workspace_directory: str = os.getcwd() + unique_id: str = get_mac_address() + + filesystem: FileSystem = RealFileSystem() + + async def handle_json(self, data: Any): + """Handle a json message""" + pass + + def showSuggestion(self, file_edit: FileEdit): + """Show a suggestion to the user""" + pass + + async def setFileOpen(self, filepath: str, open: bool = True): + """Set whether a file is open""" + pass + + async def showMessage(self, message: str): + """Show a message to the user""" + print(message) + + async def showVirtualFile(self, name: str, contents: str): + """Show a virtual file""" + pass + + async def setSuggestionsLocked(self, filepath: str, locked: bool = True): + """Set whether suggestions are locked""" + pass + + async def getSessionId(self): + """Get a new session ID""" + pass + + async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: + """Show suggestions to the user and wait for a response""" + pass + + def onAcceptRejectSuggestion(self, accepted: bool): + """Called when the user accepts or rejects a suggestion""" + pass + + def onFileSystemUpdate(self, update: FileSystemEdit): + """Called when a file system update is received""" + pass + + def onCloseGUI(self, session_id: str): + """Called when a GUI is closed""" + pass + + def onOpenGUIRequest(self): + """Called when a GUI is requested to be opened""" + pass + + async def getOpenFiles(self) -> List[str]: + """Get a list of open files""" + pass + + async def getVisibleFiles(self) -> List[str]: + """Get a list of visible files""" + pass + + async def getHighlightedCode(self) -> List[RangeInFile]: + """Get a list of highlighted code""" + pass + + async def readFile(self, filepath: str) -> str: + """Read a file""" + return self.filesystem.read(filepath) + + async def readRangeInFile(self, range_in_file: RangeInFile) -> str: + """Read a range in a file""" + return self.filesystem.read_range_in_file(range_in_file) + + async def editFile(self, edit: FileEdit): + """Edit a file""" + self.filesystem.apply_file_edit(edit) + + async def applyFileSystemEdit(self, edit: FileSystemEdit) -> EditDiff: + """Apply a file edit""" + return self.filesystem.apply_edit(edit) + + async def saveFile(self, filepath: str): + """Save a file""" + pass + + async def getUserSecret(self, key: str): + """Get a user secret""" + return os.environ.get(key) + + async def highlightCode(self, range_in_file: RangeInFile, color: str): + """Highlight code""" + pass + + async def runCommand(self, command: str) -> str: + """Run a command using subprocess (don't pass, actually implement)""" + return subprocess.check_output(command, shell=True).decode("utf-8") + + def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): + """Called when highlighted code is updated""" + pass + + def onDeleteAtIndex(self, index: int): + """Called when a step is deleted at a given index""" + pass + + async def showDiff(self, filepath: str, replacement: str, step_index: int): + """Show a diff""" + pass + + def subscribeToFilesCreated(self, callback: Callable[[List[str]], None]): + """Subscribe to files created event""" + pass + + def subscribeToFilesDeleted(self, callback: Callable[[List[str]], None]): + """Subscribe to files deleted event""" + pass + + def subscribeToFilesRenamed(self, callback: Callable[[List[str], List[str]], None]): + """Subscribe to files renamed event""" + pass + + def subscribeToFileSaved(self, callback: Callable[[str, str], None]): + """Subscribe to file saved event""" + pass + + def onFilesCreated(self, filepaths: List[str]): + """Called when files are created""" + pass + + def onFilesDeleted(self, filepaths: List[str]): + """Called when files are deleted""" + pass + + def onFilesRenamed(self, old_filepaths: List[str], new_filepaths: List[str]): + """Called when files are renamed""" + pass + + def onFileSaved(self, filepath: str, contents: str): + """Called when a file is saved""" + pass + + async def fileExists(self, filepath: str) -> Coroutine[Any, Any, str]: + """Check if a file exists""" + return self.filesystem.exists(filepath) + + async def getTerminalContents(self) -> Coroutine[Any, Any, str]: + return "" + + async def listDirectoryContents( + self, directory: str, recursive: bool = False + ) -> List[str]: + return self.filesystem.list_directory_contents(directory, recursive=recursive) diff --git a/server/continuedev/libs/__init__.py b/server/continuedev/libs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/continuedev/libs/chroma/.gitignore b/server/continuedev/libs/chroma/.gitignore new file mode 100644 index 00000000..6320cd24 --- /dev/null +++ b/server/continuedev/libs/chroma/.gitignore @@ -0,0 +1 @@ +data \ No newline at end of file diff --git a/server/continuedev/libs/chroma/query.py b/server/continuedev/libs/chroma/query.py new file mode 100644 index 00000000..d77cce49 --- /dev/null +++ b/server/continuedev/libs/chroma/query.py @@ -0,0 +1,218 @@ +import json +import os +import subprocess +from functools import cached_property +from typing import List, Tuple + +from llama_index import ( + Document, + GPTVectorStoreIndex, + StorageContext, + load_index_from_storage, +) +from llama_index.langchain_helpers.text_splitter import TokenTextSplitter + +from ..util.logging import logger +from .update import filter_ignored_files, load_gpt_index_documents + + +class ChromaIndexManager: + workspace_dir: str + + def __init__(self, workspace_dir: str): + self.workspace_dir = workspace_dir + + @cached_property + def current_commit(self) -> str: + """Get the current commit.""" + return ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], cwd=self.workspace_dir + ) + .decode("utf-8") + .strip() + ) + + @cached_property + def current_branch(self) -> str: + """Get the current branch.""" + return ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir + ) + .decode("utf-8") + .strip() + ) + + @cached_property + def index_dir(self) -> str: + return os.path.join( + self.workspace_dir, ".continue", "chroma", self.current_branch + ) + + @cached_property + def git_root_dir(self): + """Get the root directory of a Git repository.""" + try: + return ( + subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], cwd=self.workspace_dir + ) + .strip() + .decode() + ) + except subprocess.CalledProcessError: + return None + + def check_index_exists(self): + return os.path.exists(os.path.join(self.index_dir, "metadata.json")) + + def create_codebase_index(self): + """Create a new index for the current branch.""" + if not self.check_index_exists(): + os.makedirs(self.index_dir) + else: + return + + documents = load_gpt_index_documents(self.workspace_dir) + + chunks = {} + doc_chunks = [] + for doc in documents: + text_splitter = TokenTextSplitter() + try: + text_chunks = text_splitter.split_text(doc.text) + except: + logger.warning(f"ERROR (probably found special token): {doc.text}") + continue # lol + filename = doc.extra_info["filename"] + chunks[filename] = len(text_chunks) + for i, text in enumerate(text_chunks): + doc_chunks.append(Document(text, doc_id=f"{filename}::{i}")) + + with open(f"{self.index_dir}/metadata.json", "w") as f: + json.dump({"commit": self.current_commit, "chunks": chunks}, f, indent=4) + + index = GPTVectorStoreIndex([]) + + for chunk in doc_chunks: + index.insert(chunk) + + # d = 1536 # Dimension of text-ada-embedding-002 + # faiss_index = faiss.IndexFlatL2(d) + # index = GPTFaissIndex(documents, faiss_index=faiss_index) + # index.save_to_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") + + index.storage_context.persist(persist_dir=self.index_dir) + + logger.debug("Codebase index created") + + def get_modified_deleted_files(self) -> Tuple[List[str], List[str]]: + """Get a list of all files that have been modified since the last commit.""" + metadata = f"{self.index_dir}/metadata.json" + with open(metadata, "r") as f: + previous_commit = json.load(f)["commit"] + + modified_deleted_files = ( + subprocess.check_output( + ["git", "diff", "--name-only", previous_commit, self.current_commit] + ) + .decode("utf-8") + .strip() + ) + modified_deleted_files = modified_deleted_files.split("\n") + modified_deleted_files = [f for f in modified_deleted_files if f] + + deleted_files = [ + f + for f in modified_deleted_files + if not os.path.exists(os.path.join(self.workspace_dir, f)) + ] + modified_files = [ + f + for f in modified_deleted_files + if os.path.exists(os.path.join(self.workspace_dir, f)) + ] + + return filter_ignored_files( + modified_files, self.index_dir + ), filter_ignored_files(deleted_files, self.index_dir) + + def update_codebase_index(self): + """Update the index with a list of files.""" + + if not self.check_index_exists(): + self.create_codebase_index() + else: + # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") + index = GPTVectorStoreIndex.load_from_disk(f"{self.index_dir}/index.json") + modified_files, deleted_files = self.get_modified_deleted_files() + + with open(f"{self.index_dir}/metadata.json", "r") as f: + metadata = json.load(f) + + for file in deleted_files: + num_chunks = metadata["chunks"][file] + for i in range(num_chunks): + index.delete(f"{file}::{i}") + + del metadata["chunks"][file] + + logger.debug(f"Deleted {file}") + + for file in modified_files: + if file in metadata["chunks"]: + num_chunks = metadata["chunks"][file] + + for i in range(num_chunks): + index.delete(f"{file}::{i}") + + logger.debug(f"Deleted old version of {file}") + + with open(file, "r") as f: + text = f.read() + + text_splitter = TokenTextSplitter() + text_chunks = text_splitter.split_text(text) + + for i, text in enumerate(text_chunks): + index.insert(Document(text, doc_id=f"{file}::{i}")) + + metadata["chunks"][file] = len(text_chunks) + + logger.debug(f"Inserted new version of {file}") + + metadata["commit"] = self.current_commit + + with open(f"{self.index_dir}/metadata.json", "w") as f: + json.dump(metadata, f, indent=4) + + logger.debug("Codebase index updated") + + def query_codebase_index(self, query: str) -> str: + """Query the codebase index.""" + if not self.check_index_exists(): + logger.debug(f"No index found for the codebase at {self.index_dir}") + return "" + + storage_context = StorageContext.from_defaults(persist_dir=self.index_dir) + index = load_index_from_storage(storage_context) + # index = GPTVectorStoreIndex.load_from_disk(path) + engine = index.as_query_engine() + return engine.query(query) + + def query_additional_index(self, query: str) -> str: + """Query the additional index.""" + index = GPTVectorStoreIndex.load_from_disk( + os.path.join(self.index_dir, "additional_index.json") + ) + return index.query(query) + + def replace_additional_index(self, info: str): + """Replace the additional index with the given info.""" + with open(f"{self.index_dir}/additional_context.txt", "w") as f: + f.write(info) + documents = [Document(info)] + index = GPTVectorStoreIndex(documents) + index.save_to_disk(f"{self.index_dir}/additional_index.json") + logger.debug("Additional index replaced") diff --git a/server/continuedev/libs/chroma/update.py b/server/continuedev/libs/chroma/update.py new file mode 100644 index 00000000..7a1217f9 --- /dev/null +++ b/server/continuedev/libs/chroma/update.py @@ -0,0 +1,66 @@ +# import faiss +import os +import subprocess +from typing import List + +from dotenv import load_dotenv +from llama_index import Document, SimpleDirectoryReader + +load_dotenv() + +FILE_TYPES_TO_IGNORE = [".pyc", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico"] + + +def filter_ignored_files(files: List[str], root_dir: str): + """Further filter files before indexing.""" + for file in files: + if ( + file.endswith(tuple(FILE_TYPES_TO_IGNORE)) + or file.startswith(".git") + or file.startswith("archive") + ): + continue # nice + yield root_dir + "/" + file + + +def get_git_ignored_files(root_dir: str): + """Get the list of ignored files in a Git repository.""" + try: + output = ( + subprocess.check_output( + ["git", "ls-files", "--ignored", "--others", "--exclude-standard"], + cwd=root_dir, + ) + .strip() + .decode() + ) + return output.split("\n") + except subprocess.CalledProcessError: + return [] + + +def get_all_files(root_dir: str): + """Get a list of all files in a directory.""" + for dir_path, _, file_names in os.walk(root_dir): + for file_name in file_names: + yield os.path.join(os.path.relpath(dir_path, root_dir), file_name) + + +def get_input_files(root_dir: str): + """Get a list of all files in a Git repository that are not ignored.""" + ignored_files = set(get_git_ignored_files(root_dir)) + all_files = set(get_all_files(root_dir)) + nonignored_files = all_files - ignored_files + return filter_ignored_files(nonignored_files, root_dir) + + +def load_gpt_index_documents(root: str) -> List[Document]: + """Loads a list of GPTIndex Documents, respecting .gitignore files.""" + # Get input files + input_files = get_input_files(root) + # Use SimpleDirectoryReader to load the files into Documents + return SimpleDirectoryReader( + root, + input_files=input_files, + file_metadata=lambda filename: {"filename": filename}, + ).load_data() diff --git a/server/continuedev/libs/constants/default_config.py b/server/continuedev/libs/constants/default_config.py new file mode 100644 index 00000000..a007eef1 --- /dev/null +++ b/server/continuedev/libs/constants/default_config.py @@ -0,0 +1,88 @@ +default_config = """\ +\"\"\" +This is the Continue configuration file. + +See https://continue.dev/docs/customization to for documentation of the available options. +\"\"\" + +from continuedev.core.models import Models +from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig +from continuedev.libs.llm import OpenAIFreeTrial + +from continuedev.plugins.context_providers import ( + DiffContextProvider, + TerminalContextProvider, + URLContextProvider, + GitHubIssuesContextProvider +) +from continuedev.plugins.steps import ( + ClearHistoryStep, + CommentCodeStep, + EditHighlightedCodeStep, + GenerateShellCommandStep, + OpenConfigStep, +) +from continuedev.plugins.steps.share_session import ShareSessionStep + +config = ContinueConfig( + allow_anonymous_telemetry=True, + models=Models( + default=OpenAIFreeTrial(api_key="", model="gpt-4"), + summarize=OpenAIFreeTrial(api_key="", model="gpt-3.5-turbo") + ), + system_message=None, + temperature=0.5, + custom_commands=[ + CustomCommand( + name="test", + description="Write unit tests for highlighted code", + prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.", + ) + ], + slash_commands=[ + SlashCommand( + name="edit", + description="Edit highlighted code", + step=EditHighlightedCodeStep, + ), + SlashCommand( + name="config", + description="Customize Continue", + step=OpenConfigStep, + ), + SlashCommand( + name="comment", + description="Write comments for the highlighted code", + step=CommentCodeStep, + ), + SlashCommand( + name="clear", + description="Clear step history", + step=ClearHistoryStep, + ), + SlashCommand( + name="share", + description="Download and share this session", + step=ShareSessionStep, + ), + SlashCommand( + name="cmd", + description="Generate a shell command", + step=GenerateShellCommandStep, + ), + ], + context_providers=[ + # GitHubIssuesContextProvider( + # repo_name="/", + # auth_token="" + # ), + DiffContextProvider(), + URLContextProvider( + preset_urls = [ + # Add any common urls you reference here so they appear in autocomplete + ] + ), + TerminalContextProvider(), + ], +) +""" diff --git a/server/continuedev/libs/constants/main.py b/server/continuedev/libs/constants/main.py new file mode 100644 index 00000000..f5964df6 --- /dev/null +++ b/server/continuedev/libs/constants/main.py @@ -0,0 +1,6 @@ +## PATHS ## + +CONTINUE_GLOBAL_FOLDER = ".continue" +CONTINUE_SESSIONS_FOLDER = "sessions" +CONTINUE_SERVER_FOLDER = "server" +CONTINUE_SERVER_VERSION_FILE = "server_version.txt" diff --git a/server/continuedev/libs/llm/__init__.py b/server/continuedev/libs/llm/__init__.py new file mode 100644 index 00000000..829ffede --- /dev/null +++ b/server/continuedev/libs/llm/__init__.py @@ -0,0 +1,14 @@ +from .anthropic import AnthropicLLM # noqa: F401 +from .ggml import GGML # noqa: F401 +from .google_palm_api import GooglePaLMAPI # noqa: F401 +from .hf_inference_api import HuggingFaceInferenceAPI # noqa: F401 +from .hf_tgi import HuggingFaceTGI # noqa: F401 +from .llamacpp import LlamaCpp # noqa: F401 +from .ollama import Ollama # noqa: F401 +from .openai import OpenAI # noqa: F401 +from .openai_free_trial import OpenAIFreeTrial # noqa: F401 +from .proxy_server import ProxyServer # noqa: F401 +from .queued import QueuedLLM # noqa: F401 +from .replicate import ReplicateLLM # noqa: F401 +from .text_gen_interface import TextGenUI # noqa: F401 +from .together import TogetherLLM # noqa: F401 diff --git a/server/continuedev/libs/llm/anthropic.py b/server/continuedev/libs/llm/anthropic.py new file mode 100644 index 00000000..7d0708f1 --- /dev/null +++ b/server/continuedev/libs/llm/anthropic.py @@ -0,0 +1,74 @@ +from typing import Any, Callable, Coroutine + +from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic + +from .base import LLM, CompletionOptions +from .prompts.chat import anthropic_template_messages + + +class AnthropicLLM(LLM): + """ + Import the `AnthropicLLM` class and set it as the default model: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.anthropic import AnthropicLLM + + config = ContinueConfig( + ... + models=Models( + default=AnthropicLLM(api_key="", model="claude-2") + ) + ) + ``` + + Claude 2 is not yet publicly released. You can request early access [here](https://www.anthropic.com/earlyaccess). + + """ + + api_key: str + "Anthropic API key" + + model: str = "claude-2" + + _async_client: AsyncAnthropic = None + + template_messages: Callable = anthropic_template_messages + + class Config: + arbitrary_types_allowed = True + + async def start(self, **kwargs): + await super().start(**kwargs) + self._async_client = AsyncAnthropic(api_key=self.api_key) + + if self.model == "claude-2": + self.context_length = 100_000 + + def collect_args(self, options: CompletionOptions): + options.stop = None + args = super().collect_args(options) + + if "max_tokens" in args: + args["max_tokens_to_sample"] = args["max_tokens"] + del args["max_tokens"] + if "frequency_penalty" in args: + del args["frequency_penalty"] + if "presence_penalty" in args: + del args["presence_penalty"] + return args + + async def _stream_complete(self, prompt: str, options): + args = self.collect_args(options) + prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" + + async for chunk in await self._async_client.completions.create( + prompt=prompt, stream=True, **args + ): + yield chunk.completion + + async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: + args = self.collect_args(options) + prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" + return ( + await self._async_client.completions.create(prompt=prompt, **args) + ).completion diff --git a/server/continuedev/libs/llm/base.py b/server/continuedev/libs/llm/base.py new file mode 100644 index 00000000..d77cb9fc --- /dev/null +++ b/server/continuedev/libs/llm/base.py @@ -0,0 +1,458 @@ +import ssl +from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union + +import aiohttp +import certifi +from pydantic import Field, validator + +from ...core.main import ChatMessage +from ...models.main import ContinueBaseModel +from ..util.count_tokens import ( + DEFAULT_ARGS, + DEFAULT_MAX_TOKENS, + compile_chat_messages, + count_tokens, + format_chat_messages, + prune_raw_prompt_from_top, +) +from ..util.devdata import dev_data_logger +from ..util.telemetry import posthog_logger + + +class CompletionOptions(ContinueBaseModel): + """Options for the completion.""" + + @validator( + "*", + pre=True, + always=True, + ) + def ignore_none_and_set_default(cls, value, field): + return value if value is not None else field.default + + model: Optional[str] = Field(None, description="The model name") + temperature: Optional[float] = Field( + None, description="The temperature of the completion." + ) + top_p: Optional[float] = Field(None, description="The top_p of the completion.") + top_k: Optional[int] = Field(None, description="The top_k of the completion.") + presence_penalty: Optional[float] = Field( + None, description="The presence penalty Aof the completion." + ) + frequency_penalty: Optional[float] = Field( + None, description="The frequency penalty of the completion." + ) + stop: Optional[List[str]] = Field( + None, description="The stop tokens of the completion." + ) + max_tokens: int = Field( + DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate." + ) + functions: Optional[List[Any]] = Field( + None, description="The functions/tools to make available to the model." + ) + + +class LLM(ContinueBaseModel): + title: Optional[str] = Field( + None, + description="A title that will identify this model in the model selection dropdown", + ) + + unique_id: Optional[str] = Field(None, description="The unique ID of the user.") + model: str = Field( + ..., description="The name of the model to be used (e.g. gpt-4, codellama)" + ) + + system_message: Optional[str] = Field( + None, description="A system message that will always be followed by the LLM" + ) + + context_length: int = Field( + 2048, + description="The maximum context length of the LLM in tokens, as counted by count_tokens.", + ) + + stop_tokens: Optional[List[str]] = Field( + None, description="Tokens that will stop the completion." + ) + temperature: Optional[float] = Field( + None, description="The temperature of the completion." + ) + top_p: Optional[float] = Field(None, description="The top_p of the completion.") + top_k: Optional[int] = Field(None, description="The top_k of the completion.") + presence_penalty: Optional[float] = Field( + None, description="The presence penalty Aof the completion." + ) + frequency_penalty: Optional[float] = Field( + None, description="The frequency penalty of the completion." + ) + + timeout: Optional[int] = Field( + 300, + description="Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts.", + ) + verify_ssl: Optional[bool] = Field( + None, description="Whether to verify SSL certificates for requests." + ) + ca_bundle_path: str = Field( + None, + description="Path to a custom CA bundle to use when making the HTTP request", + ) + proxy: Optional[str] = Field( + None, + description="Proxy URL to use when making the HTTP request", + ) + headers: Optional[Dict[str, str]] = Field( + None, + description="Headers to use when making the HTTP request", + ) + prompt_templates: dict = Field( + {}, + description='A dictionary of prompt templates that can be used to customize the behavior of the LLM in certain situations. For example, set the "edit" key in order to change the prompt that is used for the /edit slash command. Each value in the dictionary is a string templated in mustache syntax, and filled in at runtime with the variables specific to the situation. See the documentation for more information.', + ) + + template_messages: Optional[Callable[[List[Dict[str, str]]], str]] = Field( + None, + description="A function that takes a list of messages and returns a prompt. This ensures that models like llama2, which are trained on specific chat formats, will always receive input in that format.", + ) + write_log: Optional[Callable[[str], None]] = Field( + None, + description="A function that is called upon every prompt and completion, by default to log to the file which can be viewed by clicking on the magnifying glass.", + ) + + api_key: Optional[str] = Field( + None, description="The API key for the LLM provider." + ) + + class Config: + arbitrary_types_allowed = True + extra = "allow" + fields = { + "title": { + "description": "A title that will identify this model in the model selection dropdown" + }, + "system_message": { + "description": "A system message that will always be followed by the LLM" + }, + "context_length": { + "description": "The maximum context length of the LLM in tokens, as counted by count_tokens." + }, + "unique_id": {"description": "The unique ID of the user."}, + "model": { + "description": "The name of the model to be used (e.g. gpt-4, codellama)" + }, + "timeout": { + "description": "Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts." + }, + "prompt_templates": { + "description": 'A dictionary of prompt templates that can be used to customize the behavior of the LLM in certain situations. For example, set the "edit" key in order to change the prompt that is used for the /edit slash command. Each value in the dictionary is a string templated in mustache syntax, and filled in at runtime with the variables specific to the situation. See the documentation for more information.' + }, + "template_messages": { + "description": "A function that takes a list of messages and returns a prompt. This ensures that models like llama2, which are trained on specific chat formats, will always receive input in that format." + }, + "write_log": { + "description": "A function that is called upon every prompt and completion, by default to log to the file which can be viewed by clicking on the magnifying glass." + }, + "api_key": {"description": "The API key for the LLM provider."}, + "verify_ssl": { + "description": "Whether to verify SSL certificates for requests." + }, + "ca_bundle_path": { + "description": "Path to a custom CA bundle to use when making the HTTP request" + }, + "headers": { + "description": "Headers to use when making the HTTP request" + }, + "proxy": {"description": "Proxy URL to use when making the HTTP request"}, + "stop_tokens": {"description": "Tokens that will stop the completion."}, + "temperature": { + "description": "The sampling temperature used for generation." + }, + "top_p": { + "description": "The top_p sampling parameter used for generation." + }, + "top_k": { + "description": "The top_k sampling parameter used for generation." + }, + "presence_penalty": { + "description": "The presence penalty used for completions." + }, + "frequency_penalty": { + "description": "The frequency penalty used for completions." + }, + } + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("write_log") + if self.template_messages is not None: + original_dict["template_messages"] = self.template_messages.__name__ + original_dict.pop("unique_id") + original_dict["class_name"] = self.__class__.__name__ + return original_dict + + async def start( + self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None + ): + """Start the connection to the LLM.""" + self.write_log = write_log + self.unique_id = unique_id + + async def stop(self): + """Stop the connection to the LLM.""" + pass + + def create_client_session(self): + if self.verify_ssl is False: + return aiohttp.ClientSession( + connector=aiohttp.TCPConnector(verify_ssl=False), + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers + ) + else: + ca_bundle_path = ( + certifi.where() if self.ca_bundle_path is None else self.ca_bundle_path + ) + ssl_context = ssl.create_default_context(cafile=ca_bundle_path) + return aiohttp.ClientSession( + connector=aiohttp.TCPConnector(ssl_context=ssl_context), + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) + + def collect_args(self, options: CompletionOptions) -> Dict[str, Any]: + """Collect the arguments for the LLM.""" + args = {**DEFAULT_ARGS.copy(), "model": self.model} + args.update(options.dict(exclude_unset=True, exclude_none=True)) + return args + + def compile_chat_messages( + self, + options: CompletionOptions, + msgs: List[ChatMessage], + functions: Optional[List[Any]] = None, + ) -> List[Dict]: + return compile_chat_messages( + model_name=options.model, + msgs=msgs, + context_length=self.context_length, + max_tokens=options.max_tokens, + functions=functions, + system_message=self.system_message, + ) + + def template_prompt_like_messages(self, prompt: str) -> str: + if self.template_messages is None: + return prompt + + msgs = [{"role": "user", "content": prompt}] + if self.system_message is not None: + msgs.insert(0, {"role": "system", "content": self.system_message}) + + return self.template_messages(msgs) + + async def stream_complete( + self, + prompt: str, + raw: bool = False, + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, + log: bool = True, + ) -> Generator[Union[Any, List, Dict], None, None]: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature or self.temperature, + top_p=top_p or self.top_p, + top_k=top_k or self.top_k, + presence_penalty=presence_penalty or self.presence_penalty, + frequency_penalty=frequency_penalty or self.frequency_penalty, + stop=stop or self.stop_tokens, + max_tokens=max_tokens, + functions=functions, + ) + + prompt = prune_raw_prompt_from_top( + self.model, self.context_length, prompt, options.max_tokens + ) + + if not raw: + prompt = self.template_prompt_like_messages(prompt) + + if log: + self.write_log(prompt) + + completion = "" + async for chunk in self._stream_complete(prompt=prompt, options=options): + yield chunk + completion += chunk + + # if log: + # self.write_log(f"Completion: \n\n{completion}") + + dev_data_logger.capture( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + posthog_logger.capture_event( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + + async def complete( + self, + prompt: str, + raw: bool = False, + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, + log: bool = True, + ) -> str: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature or self.temperature, + top_p=top_p or self.top_p, + top_k=top_k or self.top_k, + presence_penalty=presence_penalty or self.presence_penalty, + frequency_penalty=frequency_penalty or self.frequency_penalty, + stop=stop or self.stop_tokens, + max_tokens=max_tokens, + functions=functions, + ) + + prompt = prune_raw_prompt_from_top( + self.model, self.context_length, prompt, options.max_tokens + ) + + if not raw: + prompt = self.template_prompt_like_messages(prompt) + + if log: + self.write_log(prompt) + + completion = await self._complete(prompt=prompt, options=options) + + # if log: + # self.write_log(f"Completion: \n\n{completion}") + + dev_data_logger.capture( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + posthog_logger.capture_event( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + + return completion + + async def stream_chat( + self, + messages: List[ChatMessage], + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, + log: bool = True, + ) -> Generator[Union[Any, List, Dict], None, None]: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature or self.temperature, + top_p=top_p or self.top_p, + top_k=top_k or self.top_k, + presence_penalty=presence_penalty or self.presence_penalty, + frequency_penalty=frequency_penalty or self.frequency_penalty, + stop=stop or self.stop_tokens, + max_tokens=max_tokens, + functions=functions, + ) + + messages = self.compile_chat_messages( + options=options, msgs=messages, functions=functions + ) + if self.template_messages is not None: + prompt = self.template_messages(messages) + else: + prompt = format_chat_messages(messages) + + if log: + self.write_log(prompt) + + completion = "" + + # Use the template_messages function if it exists and do a raw completion + if self.template_messages is None: + async for chunk in self._stream_chat(messages=messages, options=options): + yield chunk + if "content" in chunk: + completion += chunk["content"] + else: + async for chunk in self._stream_complete(prompt=prompt, options=options): + yield {"role": "assistant", "content": chunk} + completion += chunk + + # if log: + # self.write_log(f"Completion: \n\n{completion}") + + dev_data_logger.capture( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + posthog_logger.capture_event( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + + def _stream_complete( + self, prompt, options: CompletionOptions + ) -> Generator[str, None, None]: + """Stream the completion through generator.""" + raise NotImplementedError + + async def _complete( + self, prompt: str, options: CompletionOptions + ) -> Coroutine[Any, Any, str]: + """Return the completion of the text with the given temperature.""" + completion = "" + async for chunk in self._stream_complete(prompt=prompt, options=options): + completion += chunk + return completion + + async def _stream_chat( + self, messages: List[ChatMessage], options: CompletionOptions + ) -> Generator[Union[Any, List, Dict], None, None]: + """Stream the chat through generator.""" + if self.template_messages is None: + raise NotImplementedError( + "You must either implement template_messages or _stream_chat" + ) + + async for chunk in self._stream_complete( + prompt=self.template_messages(messages), options=options + ): + yield {"role": "assistant", "content": chunk} + + def count_tokens(self, text: str): + """Return the number of tokens in the given text.""" + return count_tokens(self.model, text) diff --git a/server/continuedev/libs/llm/ggml.py b/server/continuedev/libs/llm/ggml.py new file mode 100644 index 00000000..55d580a8 --- /dev/null +++ b/server/continuedev/libs/llm/ggml.py @@ -0,0 +1,226 @@ +import json +from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional + +from pydantic import Field + +from ...core.main import ChatMessage +from ..util.logging import logger +from .base import LLM, CompletionOptions +from .openai import CHAT_MODELS +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class GGML(LLM): + """ + See our [5 minute quickstart](https://github.com/continuedev/ggml-server-example) to run any model locally with ggml. While these models don't yet perform as well, they are free, entirely private, and run offline. + + Once the model is running on localhost:8000, change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.ggml import GGML + + config = ContinueConfig( + ... + models=Models( + default=GGML( + max_context_length=2048, + server_url="http://localhost:8000") + ) + ) + ``` + """ + + server_url: str = Field( + "http://localhost:8000", + description="URL of the OpenAI-compatible server where the model is being served", + ) + model: str = Field( + "ggml", description="The name of the model to use (optional for the GGML class)" + ) + + api_base: Optional[str] = Field(None, description="OpenAI API base URL.") + + api_type: Optional[Literal["azure", "openai"]] = Field( + None, description="OpenAI API type." + ) + + api_version: Optional[str] = Field( + None, description="OpenAI API version. For use with Azure OpenAI Service." + ) + + engine: Optional[str] = Field( + None, description="OpenAI engine. For use with Azure OpenAI Service." + ) + + template_messages: Optional[ + Callable[[List[Dict[str, str]]], str] + ] = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def get_headers(self): + headers = { + "Content-Type": "application/json", + } + if self.api_key is not None: + if self.api_type == "azure": + headers["api-key"] = self.api_key + else: + headers["Authorization"] = f"Bearer {self.api_key}" + + return headers + + def get_full_server_url(self, endpoint: str): + endpoint = endpoint.lstrip("/").rstrip("/") + + if self.api_type == "azure": + if self.engine is None or self.api_version is None or self.api_base is None: + raise Exception( + "For Azure OpenAI Service, you must specify engine, api_version, and api_base." + ) + + return f"{self.api_base}/openai/deployments/{self.engine}/{endpoint}?api-version={self.api_version}" + else: + return f"{self.server_url}/v1/{endpoint}" + + async def _raw_stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self.create_client_session() as client_session: + async with client_session.post( + self.get_full_server_url(endpoint="completions"), + json={ + "prompt": prompt, + "stream": True, + **args, + }, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception( + f"Error calling /chat/completions endpoint: {resp.status}" + ) + + async for line in resp.content.iter_any(): + if line: + chunks = line.decode("utf-8") + for chunk in chunks.split("\n"): + if ( + chunk.startswith(": ping - ") + or chunk.startswith("data: [DONE]") + or chunk.strip() == "" + ): + continue + elif chunk.startswith("data: "): + chunk = chunk[6:] + try: + j = json.loads(chunk) + except Exception: + continue + if ( + "choices" in j + and len(j["choices"]) > 0 + and "text" in j["choices"][0] + ): + yield j["choices"][0]["text"] + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + + async def generator(): + async with self.create_client_session() as client_session: + async with client_session.post( + self.get_full_server_url(endpoint="chat/completions"), + json={"messages": messages, "stream": True, **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception( + f"Error calling /chat/completions endpoint: {resp.status}" + ) + + async for line, end in resp.content.iter_chunks(): + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if ( + chunk.strip() == "" + or json_chunk.startswith(": ping - ") + or json_chunk.startswith("data: [DONE]") + ): + continue + try: + yield json.loads(chunk[6:])["choices"][0]["delta"] + except: + pass + + # Because quite often the first attempt fails, and it works thereafter + try: + async for chunk in generator(): + yield chunk + except Exception as e: + logger.warning(f"Error calling /chat/completions endpoint: {e}") + async for chunk in generator(): + yield chunk + + async def _raw_complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: + args = self.collect_args(options) + + async with self.create_client_session() as client_session: + async with client_session.post( + self.get_full_server_url(endpoint="completions"), + json={ + "prompt": prompt, + **args, + }, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception( + f"Error calling /chat/completions endpoint: {resp.status}" + ) + + text = await resp.text() + try: + completion = json.loads(text)["choices"][0]["text"] + return completion + except Exception as e: + raise Exception( + f"Error calling /completion endpoint: {e}\n\nResponse text: {text}" + ) + + async def _complete(self, prompt: str, options: CompletionOptions): + completion = "" + if self.model in CHAT_MODELS: + async for chunk in self._stream_chat( + [{"role": "user", "content": prompt}], options + ): + if "content" in chunk: + completion += chunk["content"] + + else: + async for chunk in self._raw_stream_complete(prompt, options): + completion += chunk + + return completion + + async def _stream_complete(self, prompt, options: CompletionOptions): + if self.model in CHAT_MODELS: + async for chunk in self._stream_chat( + [{"role": "user", "content": prompt}], options + ): + if "content" in chunk: + yield chunk["content"] + + else: + async for chunk in self._raw_stream_complete(prompt, options): + yield chunk diff --git a/server/continuedev/libs/llm/google_palm_api.py b/server/continuedev/libs/llm/google_palm_api.py new file mode 100644 index 00000000..3379fefe --- /dev/null +++ b/server/continuedev/libs/llm/google_palm_api.py @@ -0,0 +1,50 @@ +from typing import List + +import requests +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM + + +class GooglePaLMAPI(LLM): + """ + The Google PaLM API is currently in public preview, so production applications are not supported yet. However, you can [create an API key in Google MakerSuite](https://makersuite.google.com/u/2/app/apikey) and begin trying out the `chat-bison-001` model. Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.hf_inference_api import GooglePaLMAPI + + config = ContinueConfig( + ... + models=Models( + default=GooglePaLMAPI( + model="chat-bison-001" + api_key="", + ) + ) + ``` + """ + + api_key: str = Field(..., description="Google PaLM API key") + + model: str = "chat-bison-001" + + async def _stream_complete(self, prompt, options): + api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}" + body = {"prompt": {"messages": [{"content": prompt}]}} + response = requests.post(api_url, json=body) + yield response.json()["candidates"][0]["content"] + + async def _stream_chat(self, messages: List[ChatMessage], options): + msg_lst = [] + for message in messages: + msg_lst.append({"content": message["content"]}) + + api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}" + body = {"prompt": {"messages": msg_lst}} + response = requests.post(api_url, json=body) + yield { + "content": response.json()["candidates"][0]["content"], + "role": "assistant", + } diff --git a/server/continuedev/libs/llm/hf_inference_api.py b/server/continuedev/libs/llm/hf_inference_api.py new file mode 100644 index 00000000..990ec7c8 --- /dev/null +++ b/server/continuedev/libs/llm/hf_inference_api.py @@ -0,0 +1,78 @@ +from typing import Callable, Dict, List, Union + +from huggingface_hub import InferenceClient +from pydantic import Field + +from .base import LLM, CompletionOptions +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class HuggingFaceInferenceAPI(LLM): + """ + Hugging Face Inference API is a great option for newly released language models. Sign up for an account and add billing [here](https://huggingface.co/settings/billing), access the Inference Endpoints [here](https://ui.endpoints.huggingface.co), click on โ€œNew endpointโ€, and fill out the form (e.g. select a model like [WizardCoder-Python-34B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0)), and then deploy your model by clicking โ€œCreate Endpointโ€. Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI + + config = ContinueConfig( + ... + models=Models( + default=HuggingFaceInferenceAPI( + endpoint_url="", + hf_token="", + ) + ) + ``` + """ + + model: str = Field( + "Hugging Face Inference API", + description="The name of the model to use (optional for the HuggingFaceInferenceAPI class)", + ) + hf_token: str = Field(..., description="Your Hugging Face API token") + endpoint_url: str = Field( + None, description="Your Hugging Face Inference API endpoint URL" + ) + + template_messages: Union[ + Callable[[List[Dict[str, str]]], str], None + ] = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options: CompletionOptions): + options.stop = None + args = super().collect_args(options) + + if "max_tokens" in args: + args["max_new_tokens"] = args["max_tokens"] + del args["max_tokens"] + if "stop" in args: + args["stop_sequences"] = args["stop"] + del args["stop"] + + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + client = InferenceClient(self.endpoint_url, token=self.hf_token) + + stream = client.text_generation(prompt, stream=True, details=True, **args) + + for r in stream: + # skip special tokens + if r.token.special: + continue + # stop if we encounter a stop sequence + if options.stop is not None: + if r.token.text in options.stop: + break + yield r.token.text diff --git a/server/continuedev/libs/llm/hf_tgi.py b/server/continuedev/libs/llm/hf_tgi.py new file mode 100644 index 00000000..62458db4 --- /dev/null +++ b/server/continuedev/libs/llm/hf_tgi.py @@ -0,0 +1,65 @@ +import json +from typing import Any, Callable, List + +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM, CompletionOptions +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class HuggingFaceTGI(LLM): + model: str = "huggingface-tgi" + server_url: str = Field( + "http://localhost:8080", description="URL of your TGI server" + ) + + template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options: CompletionOptions) -> Any: + args = super().collect_args(options) + args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1} + args.pop("max_tokens", None) + args.pop("model", None) + args.pop("functions", None) + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self.create_client_session() as client_session: + async with client_session.post( + f"{self.server_url}/generate_stream", + json={"inputs": prompt, "parameters": args}, + headers={"Content-Type": "application/json"}, + proxy=self.proxy, + ) as resp: + async for line in resp.content.iter_any(): + if line: + text = line.decode("utf-8") + chunks = text.split("\n") + + for chunk in chunks: + if chunk.startswith("data: "): + chunk = chunk[len("data: ") :] + elif chunk.startswith("data:"): + chunk = chunk[len("data:") :] + + if chunk.strip() == "": + continue + + try: + json_chunk = json.loads(chunk) + except Exception as e: + print(f"Error parsing JSON: {e}") + continue + + yield json_chunk["token"]["text"] diff --git a/server/continuedev/libs/llm/hugging_face.py b/server/continuedev/libs/llm/hugging_face.py new file mode 100644 index 00000000..c2e934c0 --- /dev/null +++ b/server/continuedev/libs/llm/hugging_face.py @@ -0,0 +1,19 @@ +# TODO: This class is far out of date + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .llm import LLM + + +class HuggingFace(LLM): + def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"): + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = AutoModelForCausalLM.from_pretrained(model_path) + + def complete(self, prompt: str, **kwargs): + args = {"max_tokens": 100} + args.update(kwargs) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"]) + return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) diff --git a/server/continuedev/libs/llm/llamacpp.py b/server/continuedev/libs/llm/llamacpp.py new file mode 100644 index 00000000..bc856a52 --- /dev/null +++ b/server/continuedev/libs/llm/llamacpp.py @@ -0,0 +1,86 @@ +import json +from typing import Any, Callable, Dict + +from pydantic import Field + +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class LlamaCpp(LLM): + """ + Run the llama.cpp server binary to start the API server. If running on a remote server, be sure to set host to 0.0.0.0: + + ```shell + .\server.exe -c 4096 --host 0.0.0.0 -t 16 --mlock -m models\meta\llama\codellama-7b-instruct.Q8_0.gguf + ``` + + After it's up and running, change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.llamacpp import LlamaCpp + + config = ContinueConfig( + ... + models=Models( + default=LlamaCpp( + max_context_length=4096, + server_url="http://localhost:8080") + ) + ) + ``` + """ + + model: str = "llamacpp" + server_url: str = Field("http://localhost:8080", description="URL of the server") + + llama_cpp_args: Dict[str, Any] = Field( + {"stop": ["[INST]"]}, + description="A list of additional arguments to pass to llama.cpp. See [here](https://github.com/ggerganov/llama.cpp/tree/master/examples/server#api-endpoints) for the complete catalog of options.", + ) + + template_messages: Callable = llama2_template_messages + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options) -> Any: + args = super().collect_args(options) + if "max_tokens" in args: + args["n_predict"] = args["max_tokens"] + del args["max_tokens"] + if "frequency_penalty" in args: + del args["frequency_penalty"] + if "presence_penalty" in args: + del args["presence_penalty"] + + for k, v in self.llama_cpp_args.items(): + if k not in args: + args[k] = v + + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + headers = {"Content-Type": "application/json"} + + async def server_generator(): + async with self.create_client_session() as client_session: + async with client_session.post( + f"{self.server_url}/completion", + json={"prompt": prompt, "stream": True, **args}, + headers=headers, + proxy=self.proxy, + ) as resp: + async for line in resp.content: + content = line.decode("utf-8") + if content.strip() == "": + continue + yield json.loads(content[6:])["content"] + + async for chunk in server_generator(): + yield chunk diff --git a/server/continuedev/libs/llm/ollama.py b/server/continuedev/libs/llm/ollama.py new file mode 100644 index 00000000..82cbc852 --- /dev/null +++ b/server/continuedev/libs/llm/ollama.py @@ -0,0 +1,106 @@ +import json +from typing import Callable + +import aiohttp +from pydantic import Field + +from ...core.main import ContinueCustomException +from ..util.logging import logger +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class Ollama(LLM): + """ + [Ollama](https://ollama.ai/) is an application for Mac and Linux that makes it easy to locally run open-source models, including Llama-2. Download the app from the website, and it will walk you through setup in a couple of minutes. You can also read more in their [README](https://github.com/jmorganca/ollama). Continue can then be configured to use the `Ollama` LLM class: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.ollama import Ollama + + config = ContinueConfig( + ... + models=Models( + default=Ollama(model="llama2") + ) + ) + ``` + """ + + model: str = "llama2" + server_url: str = Field( + "http://localhost:11434", description="URL of the Ollama server" + ) + + _client_session: aiohttp.ClientSession = None + + template_messages: Callable = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + async def start(self, **kwargs): + await super().start(**kwargs) + self._client_session = self.create_client_session() + try: + async with self._client_session.post( + f"{self.server_url}/api/generate", + proxy=self.proxy, + json={ + "prompt": "", + "model": self.model, + }, + ) as _: + pass + except Exception as e: + logger.warning(f"Error pre-loading Ollama model: {e}") + + async def stop(self): + await self._client_session.close() + + async def get_downloaded_models(self): + async with self._client_session.get( + f"{self.server_url}/api/tags", + proxy=self.proxy, + ) as resp: + js_data = await resp.json() + return list(map(lambda x: x["name"], js_data["models"])) + + async def _stream_complete(self, prompt, options): + async with self._client_session.post( + f"{self.server_url}/api/generate", + json={ + "template": prompt, + "model": self.model, + "system": self.system_message, + "options": {"temperature": options.temperature}, + }, + proxy=self.proxy, + ) as resp: + if resp.status == 400: + txt = await resp.text() + extra_msg = "" + if "no such file" in txt: + extra_msg = f"\n\nThis means that the model '{self.model}' is not downloaded.\n\nYou have the following models downloaded: {', '.join(await self.get_downloaded_models())}.\n\nTo download this model, run `ollama run {self.model}` in your terminal." + raise ContinueCustomException( + f"Ollama returned an error: {txt}{extra_msg}", + "Invalid request to Ollama", + ) + elif resp.status != 200: + raise ContinueCustomException( + f"Ollama returned an error: {await resp.text()}", + "Invalid request to Ollama", + ) + async for line in resp.content.iter_any(): + if line: + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + yield j["response"] diff --git a/server/continuedev/libs/llm/openai.py b/server/continuedev/libs/llm/openai.py new file mode 100644 index 00000000..ba29279b --- /dev/null +++ b/server/continuedev/libs/llm/openai.py @@ -0,0 +1,156 @@ +from typing import Callable, List, Literal, Optional + +import certifi +import openai +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM + +CHAT_MODELS = { + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-4", + "gpt-3.5-turbo-0613", + "gpt-4-32k", +} +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16_384, + "gpt-4": 8192, + "gpt-35-turbo-16k": 16_384, + "gpt-35-turbo-0613": 4096, + "gpt-35-turbo": 4096, + "gpt-4-32k": 32_768, +} + + +class OpenAI(LLM): + """ + The OpenAI class can be used to access OpenAI models like gpt-4 and gpt-3.5-turbo. + + If you are locally serving a model that uses an OpenAI-compatible server, you can simply change the `api_base` in the `OpenAI` class like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.openai import OpenAI + + config = ContinueConfig( + ... + models=Models( + default=OpenAI( + api_key="EMPTY", + model="", + api_base="http://localhost:8000", # change to your server + ) + ) + ) + ``` + + Options for serving models locally with an OpenAI-compatible server include: + + - [text-gen-webui](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai#setup--installation) + - [FastChat](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) + - [LocalAI](https://localai.io/basics/getting_started/) + - [llama-cpp-python](https://github.com/abetlen/llama-cpp-python#web-server) + """ + + api_key: str = Field( + ..., + description="OpenAI API key", + ) + + proxy: Optional[str] = Field(None, description="Proxy URL to use for requests.") + + api_base: Optional[str] = Field(None, description="OpenAI API base URL.") + + api_type: Optional[Literal["azure", "openai"]] = Field( + None, description="OpenAI API type." + ) + + api_version: Optional[str] = Field( + None, description="OpenAI API version. For use with Azure OpenAI Service." + ) + + engine: Optional[str] = Field( + None, description="OpenAI engine. For use with Azure OpenAI Service." + ) + + async def start( + self, unique_id: Optional[str] = None, write_log: Callable[[str], None] = None + ): + await super().start(write_log=write_log, unique_id=unique_id) + + if self.context_length is None: + self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096) + + openai.api_key = self.api_key + if self.api_type is not None: + openai.api_type = self.api_type + if self.api_base is not None: + openai.api_base = self.api_base + if self.api_version is not None: + openai.api_version = self.api_version + + if self.verify_ssl is not None and self.verify_ssl is False: + openai.verify_ssl_certs = False + + if self.proxy is not None: + openai.proxy = self.proxy + + openai.ca_bundle_path = self.ca_bundle_path or certifi.where() + + def collect_args(self, options): + args = super().collect_args(options) + if self.engine is not None: + args["engine"] = self.engine + + if not args["model"].endswith("0613") and "functions" in args: + del args["functions"] + + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + args["stream"] = True + + if args["model"] in CHAT_MODELS: + async for chunk in await openai.ChatCompletion.acreate( + messages=[{"role": "user", "content": prompt}], + **args, + headers=self.headers, + ): + if len(chunk.choices) > 0 and "content" in chunk.choices[0].delta: + yield chunk.choices[0].delta.content + else: + async for chunk in await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers): + if len(chunk.choices) > 0: + yield chunk.choices[0].text + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + + async for chunk in await openai.ChatCompletion.acreate( + messages=messages, + stream=True, + **args, + headers=self.headers, + ): + if not hasattr(chunk, "choices") or len(chunk.choices) == 0: + continue + yield chunk.choices[0].delta + + async def _complete(self, prompt: str, options): + args = self.collect_args(options) + + if args["model"] in CHAT_MODELS: + resp = await openai.ChatCompletion.acreate( + messages=[{"role": "user", "content": prompt}], + **args, + headers=self.headers, + ) + return resp.choices[0].message.content + else: + return ( + (await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers)).choices[0].text + ) diff --git a/server/continuedev/libs/llm/openai_free_trial.py b/server/continuedev/libs/llm/openai_free_trial.py new file mode 100644 index 00000000..b6e707f9 --- /dev/null +++ b/server/continuedev/libs/llm/openai_free_trial.py @@ -0,0 +1,83 @@ +from typing import Callable, List, Optional + +from ...core.main import ChatMessage +from .base import LLM +from .openai import OpenAI +from .proxy_server import ProxyServer + + +class OpenAIFreeTrial(LLM): + """ + With the `OpenAIFreeTrial` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code. + + Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps: + + 1. Copy your API key from https://platform.openai.com/account/api-keys + 2. Open `~/.continue/config.py`. You can do this by using the '/config' command in Continue + 3. Change the default LLMs to look like this: + + ```python title="~/.continue/config.py" + API_KEY = "" + config = ContinueConfig( + ... + models=Models( + default=OpenAIFreeTrial(model="gpt-4", api_key=API_KEY), + summarize=OpenAIFreeTrial(model="gpt-3.5-turbo", api_key=API_KEY) + ) + ) + ``` + + The `OpenAIFreeTrial` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead. + + These classes support any models available through the OpenAI API, assuming your API key has access, including "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", and "gpt-4-32k". + """ + + api_key: Optional[str] = None + + llm: Optional[LLM] = None + + def update_llm_properties(self): + if self.llm is not None: + self.llm.system_message = self.system_message + + async def start( + self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None + ): + await super().start(write_log=write_log, unique_id=unique_id) + if self.api_key is None or self.api_key.strip() == "": + self.llm = ProxyServer( + model=self.model, + verify_ssl=self.verify_ssl, + ca_bundle_path=self.ca_bundle_path, + ) + else: + self.llm = OpenAI( + api_key=self.api_key, + model=self.model, + verify_ssl=self.verify_ssl, + ca_bundle_path=self.ca_bundle_path, + ) + + await self.llm.start(write_log=write_log, unique_id=unique_id) + + async def stop(self): + await self.llm.stop() + + async def _complete(self, prompt: str, options): + self.update_llm_properties() + return await self.llm._complete(prompt, options) + + async def _stream_complete(self, prompt, options): + self.update_llm_properties() + resp = self.llm._stream_complete(prompt, options) + async for item in resp: + yield item + + async def _stream_chat(self, messages: List[ChatMessage], options): + self.update_llm_properties() + resp = self.llm._stream_chat(messages=messages, options=options) + async for item in resp: + yield item + + def count_tokens(self, text: str): + return self.llm.count_tokens(text) diff --git a/server/continuedev/libs/llm/prompt_utils.py b/server/continuedev/libs/llm/prompt_utils.py new file mode 100644 index 00000000..930b5220 --- /dev/null +++ b/server/continuedev/libs/llm/prompt_utils.py @@ -0,0 +1,76 @@ +from typing import Dict, List, Union + +from ...models.filesystem import RangeInFileWithContents +from ...models.filesystem_edit import FileEdit + + +class MarkdownStyleEncoderDecoder: + # Filename -> the part of the file you care about + range_in_files: List[RangeInFileWithContents] + + def __init__(self, range_in_files: List[RangeInFileWithContents]): + self.range_in_files = range_in_files + + def encode(self) -> str: + return "\n\n".join( + [ + f"File ({rif.filepath})\n```\n{rif.contents}\n```" + for rif in self.range_in_files + ] + ) + + def _suggestions_to_file_edits(self, suggestions: Dict[str, str]) -> List[FileEdit]: + file_edits: List[FileEdit] = [] + for suggestion_filepath, suggestion in suggestions.items(): + matching_rifs = list( + filter(lambda r: r.filepath == suggestion_filepath, self.range_in_files) + ) + if len(matching_rifs) > 0: + range_in_file = matching_rifs[0] + file_edits.append( + FileEdit( + range=range_in_file.range, + filepath=range_in_file.filepath, + replacement=suggestion, + ) + ) + + return file_edits + + def _decode_to_suggestions(self, completion: str) -> Dict[str, str]: + if len(self.range_in_files) == 0: + return {} + + if "```" not in completion: + completion = "```\n" + completion + "\n```" + if completion.strip().splitlines()[0].strip() == "```": + first_filepath = self.range_in_files[0].filepath + completion = f"File ({first_filepath})\n" + completion + + suggestions: Dict[str, str] = {} + current_file_lines: List[str] = [] + current_filepath: Union[str, None] = None + last_was_file = False + inside_file = False + for line in completion.splitlines(): + if line.strip().startswith("File ("): + last_was_file = True + current_filepath = line.strip()[6:-1] + elif last_was_file and line.startswith("```"): + last_was_file = False + inside_file = True + elif inside_file: + if line.startswith("```"): + inside_file = False + suggestions[current_filepath] = "\n".join(current_file_lines) + current_file_lines = [] + current_filepath = None + else: + current_file_lines.append(line) + + return suggestions + + def decode(self, completion: str) -> List[FileEdit]: + suggestions = self._decode_to_suggestions(completion) + file_edits = self._suggestions_to_file_edits(suggestions) + return file_edits diff --git a/server/continuedev/libs/llm/prompts/chat.py b/server/continuedev/libs/llm/prompts/chat.py new file mode 100644 index 00000000..036f1b1a --- /dev/null +++ b/server/continuedev/libs/llm/prompts/chat.py @@ -0,0 +1,174 @@ +from textwrap import dedent +from typing import Dict, List + +from anthropic import AI_PROMPT, HUMAN_PROMPT + + +def anthropic_template_messages(messages: List[Dict[str, str]]) -> str: + prompt = "" + + # Anthropic prompt must start with a Human turn + if ( + len(messages) > 0 + and messages[0]["role"] != "user" + and messages[0]["role"] != "system" + ): + prompt += f"{HUMAN_PROMPT} Hello." + for msg in messages: + prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} " + + prompt += AI_PROMPT + return prompt + + +def template_alpaca_messages(msgs: List[Dict[str, str]]) -> str: + prompt = "" + + if msgs[0]["role"] == "system": + prompt += f"{msgs[0]['content']}\n" + msgs.pop(0) + + for msg in msgs: + prompt += "### Instruction:\n" if msg["role"] == "user" else "### Response:\n" + prompt += f"{msg['content']}\n" + + prompt += "### Response:\n" + + return prompt + + +def raw_input_template(msgs: List[Dict[str, str]]) -> str: + return msgs[-1]["content"] + + +SQL_CODER_DEFAULT_SCHEMA = """\ +CREATE TABLE products ( + product_id INTEGER PRIMARY KEY, -- Unique ID for each product + name VARCHAR(50), -- Name of the product + price DECIMAL(10,2), -- Price of each unit of the product + quantity INTEGER -- Current quantity in stock +); + +CREATE TABLE customers ( + customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer + name VARCHAR(50), -- Name of the customer + address VARCHAR(100) -- Mailing address of the customer +); + +CREATE TABLE salespeople ( + salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson + name VARCHAR(50), -- Name of the salesperson + region VARCHAR(50) -- Geographic sales region +); + +CREATE TABLE sales ( + sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale + product_id INTEGER, -- ID of product sold + customer_id INTEGER, -- ID of customer who made purchase + salesperson_id INTEGER, -- ID of salesperson who made the sale + sale_date DATE, -- Date the sale occurred + quantity INTEGER -- Quantity of product sold +); + +CREATE TABLE product_suppliers ( + supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier + product_id INTEGER, -- Product ID supplied + supply_price DECIMAL(10,2) -- Unit price charged by supplier +); + +-- sales.product_id can be joined with products.product_id +-- sales.customer_id can be joined with customers.customer_id +-- sales.salesperson_id can be joined with salespeople.salesperson_id +-- product_suppliers.product_id can be joined with products.product_id +""" + + +def _sqlcoder_template_messages( + msgs: List[Dict[str, str]], schema: str = SQL_CODER_DEFAULT_SCHEMA +) -> str: + question = msgs[-1]["content"] + return f"""\ +Your task is to convert a question into a SQL query, given a Postgres database schema. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. +- When creating a ratio, always cast the numerator as float + +### Input: +Generate a SQL query that answers the question `{question}`. +This query will run on a database whose schema is represented in this string: +{schema} + +### Response: +Based on your instructions, here is the SQL query I have generated to answer the question `{question}`: +```sql +""" + + +def sqlcoder_template_messages(schema: str = SQL_CODER_DEFAULT_SCHEMA): + if schema == "" or schema == "": + schema = SQL_CODER_DEFAULT_SCHEMA + + def fn(msgs): + return _sqlcoder_template_messages(msgs, schema=schema) + + fn.__name__ = "sqlcoder_template_messages" + return fn + + +def llama2_template_messages(msgs: List[Dict[str, str]]) -> str: + if len(msgs) == 0: + return "" + + if msgs[0]["role"] == "assistant": + # These models aren't trained to handle assistant message coming first, + # and typically these are just introduction messages from Continue + msgs.pop(0) + + prompt = "" + has_system = msgs[0]["role"] == "system" + + if has_system and msgs[0]["content"].strip() == "": + has_system = False + msgs = msgs[1:] + + if has_system: + system_message = dedent( + f"""\ + <> + {msgs[0]["content"]} + <> + + """ + ) + if len(msgs) > 1: + prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]" + else: + prompt += f"[INST] {system_message} [/INST]" + return + + for i in range(2 if has_system else 0, len(msgs)): + if msgs[i]["role"] == "user": + prompt += f"[INST] {msgs[i]['content']} [/INST]" + else: + prompt += msgs[i]["content"] + " " + + return prompt + + +def code_llama_template_messages(msgs: List[Dict[str, str]]) -> str: + return f"[INST] {msgs[-1]['content']}\n[/INST]" + + +def extra_space_template_messages(msgs: List[Dict[str, str]]) -> str: + return f" {msgs[-1]['content']}" + + +def code_llama_python_template_messages(msgs: List[Dict[str, str]]) -> str: + return dedent( + f"""\ + [INST] + You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']} + Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag. + [/INST]""" + ) diff --git a/server/continuedev/libs/llm/prompts/edit.py b/server/continuedev/libs/llm/prompts/edit.py new file mode 100644 index 00000000..eaa694c5 --- /dev/null +++ b/server/continuedev/libs/llm/prompts/edit.py @@ -0,0 +1,27 @@ +from textwrap import dedent + +simplified_edit_prompt = dedent( + """\ + Consider the following code: + ``` + {{{code_to_edit}}} + ``` + Edit the code to perfectly satisfy the following user request: + {{{user_input}}} + Output nothing except for the code. No code block, no English explanation, no start/end tags.""" +) + +simplest_edit_prompt = dedent( + """\ + Here is the code before editing: + ``` + {{{code_to_edit}}} + ``` + + Here is the edit requested: + "{{{user_input}}}" + + Here is the code after editing:""" +) + +codellama_infill_edit_prompt = "{{file_prefix}}{{file_suffix}}" diff --git a/server/continuedev/libs/llm/proxy_server.py b/server/continuedev/libs/llm/proxy_server.py new file mode 100644 index 00000000..7c3462eb --- /dev/null +++ b/server/continuedev/libs/llm/proxy_server.py @@ -0,0 +1,108 @@ +import json +import traceback +from typing import List + +import aiohttp + +from ...core.main import ChatMessage +from ..util.telemetry import posthog_logger +from .base import LLM + +# SERVER_URL = "http://127.0.0.1:8080" +SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" + +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-4": 8192, +} + + +class ProxyServer(LLM): + _client_session: aiohttp.ClientSession + + class Config: + arbitrary_types_allowed = True + + async def start( + self, + **kwargs, + ): + await super().start(**kwargs) + self._client_session = self.create_client_session() + + self.context_length = MAX_TOKENS_FOR_MODEL[self.model] + + async def stop(self): + await self._client_session.close() + + def get_headers(self): + return {"unique_id": self.unique_id} + + async def _complete(self, prompt: str, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{SERVER_URL}/complete", + json={"messages": [{"role": "user", "content": prompt}], **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + resp_text = await resp.text() + if resp.status != 200: + raise Exception(resp_text) + + return resp_text + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + async with self._client_session.post( + f"{SERVER_URL}/stream_chat", + json={"messages": messages, **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception(await resp.text()) + + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + json_chunk = "{}" if json_chunk == "" else json_chunk + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + loaded_chunk = json.loads(chunk) + yield loaded_chunk + + except Exception as e: + posthog_logger.capture_event( + "proxy_server_parse_error", + { + "error_title": "Proxy server stream_chat parsing failed", + "error_message": "\n".join( + traceback.format_exception(e) + ), + }, + ) + else: + break + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{SERVER_URL}/stream_complete", + json={"messages": [{"role": "user", "content": prompt}], **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception(await resp.text()) + + async for line in resp.content.iter_any(): + if line: + decoded_line = line.decode("utf-8") + yield decoded_line diff --git a/server/continuedev/libs/llm/queued.py b/server/continuedev/libs/llm/queued.py new file mode 100644 index 00000000..2db749eb --- /dev/null +++ b/server/continuedev/libs/llm/queued.py @@ -0,0 +1,77 @@ +import asyncio +from typing import Any, List, Union + +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM, CompletionOptions + + +class QueuedLLM(LLM): + """ + QueuedLLM exists to make up for LLM servers that cannot handle multiple requests at once. It uses a lock to ensure that only one request is being processed at a time. + + If you are already using another LLM class and are experiencing this problem, you can just wrap it with the QueuedLLM class like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.queued import QueuedLLM + + config = ContinueConfig( + ... + models=Models( + default=QueuedLLM(llm=) + ) + ) + ``` + """ + + llm: LLM = Field(..., description="The LLM to wrap with a lock") + _lock: asyncio.Lock + + model: str = "queued" + + def dict(self, **kwargs): + return self.llm.dict(**kwargs) + + async def start(self, *args, **kwargs): + await super().start(*args, **kwargs) + await self.llm.start(*args, **kwargs) + self._lock = asyncio.Lock() + self.model = self.llm.model + self.template_messages = self.llm.template_messages + self.prompt_templates = self.llm.prompt_templates + self.context_length = self.llm.context_length + + async def stop(self): + await self.llm.stop() + + def collect_args(self, options: CompletionOptions): + return self.llm.collect_args(options) + + def compile_chat_messages( + self, + options: CompletionOptions, + msgs: List[ChatMessage], + functions: Union[List[Any], None] = None, + ): + return self.llm.compile_chat_messages(options, msgs, functions) + + def template_prompt_like_messages(self, prompt: str) -> str: + return self.llm.template_prompt_like_messages(prompt) + + async def _complete(self, prompt: str, options: CompletionOptions): + async with self._lock: + resp = await self.llm._complete(prompt, options) + return resp + + async def _stream_complete(self, prompt: str, options: CompletionOptions): + async with self._lock: + async for chunk in self.llm._stream_complete(prompt, options): + yield chunk + + async def _stream_chat( + self, messages: List[ChatMessage], options: CompletionOptions + ): + async with self._lock: + async for chunk in self.llm._stream_chat(messages, options): + yield chunk diff --git a/server/continuedev/libs/llm/replicate.py b/server/continuedev/libs/llm/replicate.py new file mode 100644 index 00000000..3423193b --- /dev/null +++ b/server/continuedev/libs/llm/replicate.py @@ -0,0 +1,78 @@ +import concurrent.futures +from typing import List + +import replicate +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM +from .prompts.edit import simplified_edit_prompt + + +class ReplicateLLM(LLM): + """ + Replicate is a great option for newly released language models or models that you've deployed through their platform. Sign up for an account [here](https://replicate.ai/), copy your API key, and then select any model from the [Replicate Streaming List](https://replicate.com/collections/streaming-language-models). Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.replicate import ReplicateLLM + + config = ContinueConfig( + ... + models=Models( + default=ReplicateLLM( + model="replicate/codellama-13b-instruct:da5676342de1a5a335b848383af297f592b816b950a43d251a0a9edd0113604b", + api_key="my-replicate-api-key") + ) + ) + ``` + + If you don't specify the `model` parameter, it will default to `replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781`. + """ + + api_key: str = Field(..., description="Replicate API key") + + model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781" + + _client: replicate.Client = None + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + async def start(self, **kwargs): + await super().start(**kwargs) + self._client = replicate.Client(api_token=self.api_key) + + async def _complete(self, prompt: str, options): + def helper(): + output = self._client.run( + self.model, input={"message": prompt, "prompt": prompt} + ) + completion = "" + for item in output: + completion += item + + return completion + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(helper) + completion = future.result() + + return completion + + async def _stream_complete(self, prompt, options): + for item in self._client.run( + self.model, input={"message": prompt, "prompt": prompt} + ): + yield item + + async def _stream_chat(self, messages: List[ChatMessage], options): + for item in self._client.run( + self.model, + input={ + "message": messages[-1]["content"], + "prompt": messages[-1]["content"], + }, + ): + yield {"content": item, "role": "assistant"} diff --git a/server/continuedev/libs/llm/text_gen_interface.py b/server/continuedev/libs/llm/text_gen_interface.py new file mode 100644 index 00000000..225fd3b6 --- /dev/null +++ b/server/continuedev/libs/llm/text_gen_interface.py @@ -0,0 +1,114 @@ +import json +from typing import Any, Callable, Dict, List, Union + +import websockets +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplest_edit_prompt + + +class TextGenUI(LLM): + """ + TextGenUI is a comprehensive, open-source language model UI and local server. You can set it up with an OpenAI-compatible server plugin, but if for some reason that doesn't work, you can use this class like so: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.text_gen_interface import TextGenUI + + config = ContinueConfig( + ... + models=Models( + default=TextGenUI( + model="", + ) + ) + ) + ``` + """ + + model: str = "text-gen-ui" + server_url: str = Field( + "http://localhost:5000", description="URL of your TextGenUI server" + ) + streaming_url: str = Field( + "http://localhost:5005", + description="URL of your TextGenUI streaming server (separate from main server URL)", + ) + + prompt_templates = { + "edit": simplest_edit_prompt, + } + + template_messages: Union[ + Callable[[List[Dict[str, str]]], str], None + ] = llama2_template_messages + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options) -> Any: + args = super().collect_args(options) + args = {**args, "max_new_tokens": options.max_tokens} + args.pop("max_tokens", None) + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" + payload = json.dumps({"prompt": prompt, "stream": True, **args}) + async with websockets.connect( + f"{ws_url}/api/v1/stream", ping_interval=None + ) as websocket: + await websocket.send(payload) + + while True: + incoming_data = await websocket.recv() + incoming_data = json.loads(incoming_data) + + match incoming_data["event"]: + case "text_stream": + yield incoming_data["text"] + case "stream_end": + break + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + + async def generator(): + ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" + history = list(map(lambda x: x["content"], messages)) + payload = json.dumps( + { + "user_input": messages[-1]["content"], + "history": {"internal": [history], "visible": [history]}, + "stream": True, + **args, + } + ) + async with websockets.connect( + f"{ws_url}/api/v1/chat-stream", ping_interval=None + ) as websocket: + await websocket.send(payload) + + prev = "" + while True: + incoming_data = await websocket.recv() + incoming_data = json.loads(incoming_data) + + match incoming_data["event"]: + case "text_stream": + visible = incoming_data["history"]["visible"][-1] + if len(visible) > 0: + yield { + "role": "assistant", + "content": visible[-1].replace(prev, ""), + } + prev = visible[-1] + case "stream_end": + break + + async for chunk in generator(): + yield chunk diff --git a/server/continuedev/libs/llm/together.py b/server/continuedev/libs/llm/together.py new file mode 100644 index 00000000..35b3a424 --- /dev/null +++ b/server/continuedev/libs/llm/together.py @@ -0,0 +1,125 @@ +import json +from typing import Callable + +import aiohttp +from pydantic import Field + +from ...core.main import ContinueCustomException +from ..util.logging import logger +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class TogetherLLM(LLM): + """ + The Together API is a cloud platform for running large AI models. You can sign up [here](https://api.together.xyz/signup), copy your API key on the initial welcome screen, and then hit the play button on any model from the [Together Models list](https://docs.together.ai/docs/models-inference). Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.together import TogetherLLM + + config = ContinueConfig( + ... + models=Models( + default=TogetherLLM( + api_key="", + model="togethercomputer/llama-2-13b-chat" + ) + ) + ) + ``` + """ + + api_key: str = Field(..., description="Together API key") + + model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct" + base_url: str = Field( + "https://api.together.xyz", + description="The base URL for your Together API instance", + ) + + _client_session: aiohttp.ClientSession = None + + template_messages: Callable = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + async def start(self, **kwargs): + await super().start(**kwargs) + self._client_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) + + async def stop(self): + await self._client_session.close() + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{self.base_url}/inference", + json={ + "prompt": prompt, + "stream_tokens": True, + **args, + }, + headers={"Authorization": f"Bearer {self.api_key}"}, + proxy=self.proxy, + ) as resp: + async for line in resp.content.iter_chunks(): + if line[1]: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith( + "data: [DONE]" + ): + continue + + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + if chunk.startswith("data: "): + chunk = chunk[6:] + if chunk == "[DONE]": + break + try: + json_chunk = json.loads(chunk) + except Exception as e: + logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}") + continue + if "choices" in json_chunk: + yield json_chunk["choices"][0]["text"] + + async def _complete(self, prompt: str, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{self.base_url}/inference", + json={"prompt": prompt, **args}, + headers={"Authorization": f"Bearer {self.api_key}"}, + proxy=self.proxy, + ) as resp: + text = await resp.text() + j = json.loads(text) + try: + if "choices" not in j["output"]: + raise Exception(text) + if "output" in j: + return j["output"]["choices"][0]["text"] + except Exception as e: + j = await resp.json() + if "error" in j: + if j["error"].startswith("invalid hexlify value"): + raise ContinueCustomException( + message=f"Invalid Together API key:\n\n{j['error']}", + title="Together API Error", + ) + else: + raise ContinueCustomException( + message=j["error"], title="Together API Error" + ) + + raise e diff --git a/server/continuedev/libs/util/calculate_diff.py b/server/continuedev/libs/util/calculate_diff.py new file mode 100644 index 00000000..99301ae7 --- /dev/null +++ b/server/continuedev/libs/util/calculate_diff.py @@ -0,0 +1,154 @@ +import difflib +from typing import List + +from ...models.filesystem import FileEdit +from ...models.main import Position, Range + + +def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]: + s = difflib.SequenceMatcher(None, original, updated) + offset = 0 # The indices are offset by previous deletions/insertions + edits = [] + for tag, i1, i2, j1, j2 in s.get_opcodes(): + i1, i2, j1, j2 = i1 + offset, i2 + offset, j1 + offset, j2 + offset + replacement = updated[j1:j2] + if tag == "equal": + pass + elif tag == "delete": + edits.append( + FileEdit.from_deletion(filepath, Range.from_indices(original, i1, i2)) + ) + offset -= i2 - i1 + elif tag == "insert": + edits.append( + FileEdit.from_insertion( + filepath, Position.from_index(original, i1), replacement + ) + ) + offset += j2 - j1 + elif tag == "replace": + edits.append( + FileEdit( + filepath=filepath, + range=Range.from_indices(original, i1, i2), + replacement=replacement, + ) + ) + offset += (j2 - j1) - (i2 - i1) + else: + raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag) + + return edits + + +def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit]: + # original_lines = original.splitlines() + # updated_lines = updated.splitlines() + # offset = 0 + # while len(original_lines) and len(updated_lines) and original_lines[0] == updated_lines[0]: + # original_lines = original_lines[1:] + # updated_lines = updated_lines[1:] + + # while len(original_lines) and len(updated_lines) and original_lines[-1] == updated_lines[-1]: + # original_lines = original_lines[:-1] + # updated_lines = updated_lines[:-1] + + # original = "\n".join(original_lines) + # updated = "\n".join(updated_lines) + + edits = [] + max_iterations = 1000 + i = 0 + while not original == updated: + # TODO - For some reason it can't handle a single newline at the end of the file? + s = difflib.SequenceMatcher(None, original, updated) + opcodes = s.get_opcodes() + for edit_index in range(len(opcodes)): + tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index] + replacement = updated[j1:j2] + if tag == "equal": + continue # ;) + elif tag == "delete": + edits.append( + FileEdit.from_deletion( + filepath, Range.from_indices(original, i1, i2) + ) + ) + elif tag == "insert": + edits.append( + FileEdit.from_insertion( + filepath, Position.from_index(original, i1), replacement + ) + ) + elif tag == "replace": + edits.append( + FileEdit( + filepath=filepath, + range=Range.from_indices(original, i1, i2), + replacement=replacement, + ) + ) + else: + raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag) + break + + original = apply_edit_to_str(original, edits[-1]) + + i += 1 + if i > max_iterations: + raise Exception("Max iterations reached") + + return edits + + +def read_range_in_str(s: str, r: Range) -> str: + lines = s.splitlines()[r.start.line : r.end.line + 1] + if len(lines) == 0: + return "" + + lines[0] = lines[0][r.start.character :] + lines[-1] = lines[-1][: r.end.character + 1] + return "\n".join(lines) + + +def apply_edit_to_str(s: str, edit: FileEdit) -> str: + read_range_in_str(s, edit.range) + + # Split lines and deal with some edge cases (could obviously be nicer) + lines = s.splitlines() + if s.startswith("\n"): + lines.insert(0, "") + if s.endswith("\n"): + lines.append("") + + if len(lines) == 0: + lines = [""] + + end = Position(line=edit.range.end.line, character=edit.range.end.character) + if edit.range.end.line == len(lines) and edit.range.end.character == 0: + end = Position( + line=edit.range.end.line - 1, + character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)]), + ) + + before_lines = lines[: edit.range.start.line] + after_lines = lines[end.line + 1 :] + between_str = ( + lines[min(len(lines) - 1, edit.range.start.line)][: edit.range.start.character] + + edit.replacement + + lines[min(len(lines) - 1, end.line)][end.character + 1 :] + ) + + Range( + start=edit.range.start, + end=Position( + line=edit.range.start.line + len(edit.replacement.splitlines()) - 1, + character=edit.range.start.character + + len(edit.replacement.splitlines()[-1]) + if edit.replacement != "" + else 0, + ), + ) + + lines = before_lines + between_str.splitlines() + after_lines + return "\n".join(lines) diff --git a/server/continuedev/libs/util/commonregex.py b/server/continuedev/libs/util/commonregex.py new file mode 100644 index 00000000..c2f6bb82 --- /dev/null +++ b/server/continuedev/libs/util/commonregex.py @@ -0,0 +1,144 @@ +# coding: utf-8 +import re +from typing import Any + +date = re.compile( + "(?:(?]+[^\s`!()\[\]{};:'\".,<>?\xab\xbb\u201c\u201d\u2018\u2019])?)", + re.IGNORECASE, +) +email = re.compile( + "([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)", + re.IGNORECASE, +) +ip = re.compile( + "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)", + re.IGNORECASE, +) +ipv6 = re.compile( + "\s*(?!.*::.*::)(?:(?!:)|:(?=:))(?:[0-9a-f]{0,4}(?:(?<=::)|(?", + "times": "