Diffstat (limited to 'server/continuedev/libs')
45 files changed, 4091 insertions, 0 deletions
diff --git a/server/continuedev/libs/__init__.py b/server/continuedev/libs/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/server/continuedev/libs/__init__.py
diff --git a/server/continuedev/libs/chroma/.gitignore b/server/continuedev/libs/chroma/.gitignore
new file mode 100644
index 00000000..6320cd24
--- /dev/null
+++ b/server/continuedev/libs/chroma/.gitignore
@@ -0,0 +1 @@
+data
\ No newline at end of file diff --git a/server/continuedev/libs/chroma/query.py b/server/continuedev/libs/chroma/query.py new file mode 100644 index 00000000..d77cce49 --- /dev/null +++ b/server/continuedev/libs/chroma/query.py @@ -0,0 +1,218 @@ +import json +import os +import subprocess +from functools import cached_property +from typing import List, Tuple + +from llama_index import ( + Document, + GPTVectorStoreIndex, + StorageContext, + load_index_from_storage, +) +from llama_index.langchain_helpers.text_splitter import TokenTextSplitter + +from ..util.logging import logger +from .update import filter_ignored_files, load_gpt_index_documents + + +class ChromaIndexManager: + workspace_dir: str + + def __init__(self, workspace_dir: str): + self.workspace_dir = workspace_dir + + @cached_property + def current_commit(self) -> str: + """Get the current commit.""" + return ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], cwd=self.workspace_dir + ) + .decode("utf-8") + .strip() + ) + + @cached_property + def current_branch(self) -> str: + """Get the current branch.""" + return ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir + ) + .decode("utf-8") + .strip() + ) + + @cached_property + def index_dir(self) -> str: + return os.path.join( + self.workspace_dir, ".continue", "chroma", self.current_branch + ) + + @cached_property + def git_root_dir(self): + """Get the root directory of a Git repository.""" + try: + return ( + subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], cwd=self.workspace_dir + ) + .strip() + .decode() + ) + except subprocess.CalledProcessError: + return None + + def check_index_exists(self): + return os.path.exists(os.path.join(self.index_dir, "metadata.json")) + + def create_codebase_index(self): + """Create a new index for the current branch.""" + if not self.check_index_exists(): + os.makedirs(self.index_dir) + else: + return + + documents = load_gpt_index_documents(self.workspace_dir) + + chunks = {} + doc_chunks = [] + for doc in documents: + text_splitter = TokenTextSplitter() + try: + text_chunks = text_splitter.split_text(doc.text) + except: + logger.warning(f"ERROR (probably found special token): {doc.text}") + continue # lol + filename = doc.extra_info["filename"] + chunks[filename] = len(text_chunks) + for i, text in enumerate(text_chunks): + doc_chunks.append(Document(text, doc_id=f"{filename}::{i}")) + + with open(f"{self.index_dir}/metadata.json", "w") as f: + json.dump({"commit": self.current_commit, "chunks": chunks}, f, indent=4) + + index = GPTVectorStoreIndex([]) + + for chunk in doc_chunks: + index.insert(chunk) + + # d = 1536 # Dimension of text-ada-embedding-002 + # faiss_index = faiss.IndexFlatL2(d) + # index = GPTFaissIndex(documents, faiss_index=faiss_index) + # index.save_to_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") + + index.storage_context.persist(persist_dir=self.index_dir) + + logger.debug("Codebase index created") + + def get_modified_deleted_files(self) -> Tuple[List[str], List[str]]: + """Get a list of all files that have been modified since the last commit.""" + metadata = f"{self.index_dir}/metadata.json" + with open(metadata, "r") as f: + previous_commit = json.load(f)["commit"] + + modified_deleted_files = ( + subprocess.check_output( + ["git", "diff", "--name-only", previous_commit, self.current_commit] + ) + .decode("utf-8") + .strip() + ) + modified_deleted_files = 
modified_deleted_files.split("\n") + modified_deleted_files = [f for f in modified_deleted_files if f] + + deleted_files = [ + f + for f in modified_deleted_files + if not os.path.exists(os.path.join(self.workspace_dir, f)) + ] + modified_files = [ + f + for f in modified_deleted_files + if os.path.exists(os.path.join(self.workspace_dir, f)) + ] + + return filter_ignored_files( + modified_files, self.index_dir + ), filter_ignored_files(deleted_files, self.index_dir) + + def update_codebase_index(self): + """Update the index with a list of files.""" + + if not self.check_index_exists(): + self.create_codebase_index() + else: + # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") + index = GPTVectorStoreIndex.load_from_disk(f"{self.index_dir}/index.json") + modified_files, deleted_files = self.get_modified_deleted_files() + + with open(f"{self.index_dir}/metadata.json", "r") as f: + metadata = json.load(f) + + for file in deleted_files: + num_chunks = metadata["chunks"][file] + for i in range(num_chunks): + index.delete(f"{file}::{i}") + + del metadata["chunks"][file] + + logger.debug(f"Deleted {file}") + + for file in modified_files: + if file in metadata["chunks"]: + num_chunks = metadata["chunks"][file] + + for i in range(num_chunks): + index.delete(f"{file}::{i}") + + logger.debug(f"Deleted old version of {file}") + + with open(file, "r") as f: + text = f.read() + + text_splitter = TokenTextSplitter() + text_chunks = text_splitter.split_text(text) + + for i, text in enumerate(text_chunks): + index.insert(Document(text, doc_id=f"{file}::{i}")) + + metadata["chunks"][file] = len(text_chunks) + + logger.debug(f"Inserted new version of {file}") + + metadata["commit"] = self.current_commit + + with open(f"{self.index_dir}/metadata.json", "w") as f: + json.dump(metadata, f, indent=4) + + logger.debug("Codebase index updated") + + def query_codebase_index(self, query: str) -> str: + """Query the codebase index.""" + if not self.check_index_exists(): + logger.debug(f"No index found for the codebase at {self.index_dir}") + return "" + + storage_context = StorageContext.from_defaults(persist_dir=self.index_dir) + index = load_index_from_storage(storage_context) + # index = GPTVectorStoreIndex.load_from_disk(path) + engine = index.as_query_engine() + return engine.query(query) + + def query_additional_index(self, query: str) -> str: + """Query the additional index.""" + index = GPTVectorStoreIndex.load_from_disk( + os.path.join(self.index_dir, "additional_index.json") + ) + return index.query(query) + + def replace_additional_index(self, info: str): + """Replace the additional index with the given info.""" + with open(f"{self.index_dir}/additional_context.txt", "w") as f: + f.write(info) + documents = [Document(info)] + index = GPTVectorStoreIndex(documents) + index.save_to_disk(f"{self.index_dir}/additional_index.json") + logger.debug("Additional index replaced") diff --git a/server/continuedev/libs/chroma/update.py b/server/continuedev/libs/chroma/update.py new file mode 100644 index 00000000..7a1217f9 --- /dev/null +++ b/server/continuedev/libs/chroma/update.py @@ -0,0 +1,66 @@ +# import faiss +import os +import subprocess +from typing import List + +from dotenv import load_dotenv +from llama_index import Document, SimpleDirectoryReader + +load_dotenv() + +FILE_TYPES_TO_IGNORE = [".pyc", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico"] + + +def filter_ignored_files(files: List[str], root_dir: str): 
+ """Further filter files before indexing.""" + for file in files: + if ( + file.endswith(tuple(FILE_TYPES_TO_IGNORE)) + or file.startswith(".git") + or file.startswith("archive") + ): + continue # nice + yield root_dir + "/" + file + + +def get_git_ignored_files(root_dir: str): + """Get the list of ignored files in a Git repository.""" + try: + output = ( + subprocess.check_output( + ["git", "ls-files", "--ignored", "--others", "--exclude-standard"], + cwd=root_dir, + ) + .strip() + .decode() + ) + return output.split("\n") + except subprocess.CalledProcessError: + return [] + + +def get_all_files(root_dir: str): + """Get a list of all files in a directory.""" + for dir_path, _, file_names in os.walk(root_dir): + for file_name in file_names: + yield os.path.join(os.path.relpath(dir_path, root_dir), file_name) + + +def get_input_files(root_dir: str): + """Get a list of all files in a Git repository that are not ignored.""" + ignored_files = set(get_git_ignored_files(root_dir)) + all_files = set(get_all_files(root_dir)) + nonignored_files = all_files - ignored_files + return filter_ignored_files(nonignored_files, root_dir) + + +def load_gpt_index_documents(root: str) -> List[Document]: + """Loads a list of GPTIndex Documents, respecting .gitignore files.""" + # Get input files + input_files = get_input_files(root) + # Use SimpleDirectoryReader to load the files into Documents + return SimpleDirectoryReader( + root, + input_files=input_files, + file_metadata=lambda filename: {"filename": filename}, + ).load_data() diff --git a/server/continuedev/libs/constants/default_config.py b/server/continuedev/libs/constants/default_config.py new file mode 100644 index 00000000..a007eef1 --- /dev/null +++ b/server/continuedev/libs/constants/default_config.py @@ -0,0 +1,88 @@ +default_config = """\ +\"\"\" +This is the Continue configuration file. + +See https://continue.dev/docs/customization to for documentation of the available options. +\"\"\" + +from continuedev.core.models import Models +from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig +from continuedev.libs.llm import OpenAIFreeTrial + +from continuedev.plugins.context_providers import ( + DiffContextProvider, + TerminalContextProvider, + URLContextProvider, + GitHubIssuesContextProvider +) +from continuedev.plugins.steps import ( + ClearHistoryStep, + CommentCodeStep, + EditHighlightedCodeStep, + GenerateShellCommandStep, + OpenConfigStep, +) +from continuedev.plugins.steps.share_session import ShareSessionStep + +config = ContinueConfig( + allow_anonymous_telemetry=True, + models=Models( + default=OpenAIFreeTrial(api_key="", model="gpt-4"), + summarize=OpenAIFreeTrial(api_key="", model="gpt-3.5-turbo") + ), + system_message=None, + temperature=0.5, + custom_commands=[ + CustomCommand( + name="test", + description="Write unit tests for highlighted code", + prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. 
Give the tests just as chat output, don't edit any file.", + ) + ], + slash_commands=[ + SlashCommand( + name="edit", + description="Edit highlighted code", + step=EditHighlightedCodeStep, + ), + SlashCommand( + name="config", + description="Customize Continue", + step=OpenConfigStep, + ), + SlashCommand( + name="comment", + description="Write comments for the highlighted code", + step=CommentCodeStep, + ), + SlashCommand( + name="clear", + description="Clear step history", + step=ClearHistoryStep, + ), + SlashCommand( + name="share", + description="Download and share this session", + step=ShareSessionStep, + ), + SlashCommand( + name="cmd", + description="Generate a shell command", + step=GenerateShellCommandStep, + ), + ], + context_providers=[ + # GitHubIssuesContextProvider( + # repo_name="<your github username or organization>/<your repo name>", + # auth_token="<your github auth token>" + # ), + DiffContextProvider(), + URLContextProvider( + preset_urls = [ + # Add any common urls you reference here so they appear in autocomplete + ] + ), + TerminalContextProvider(), + ], +) +""" diff --git a/server/continuedev/libs/constants/main.py b/server/continuedev/libs/constants/main.py new file mode 100644 index 00000000..f5964df6 --- /dev/null +++ b/server/continuedev/libs/constants/main.py @@ -0,0 +1,6 @@ +## PATHS ## + +CONTINUE_GLOBAL_FOLDER = ".continue" +CONTINUE_SESSIONS_FOLDER = "sessions" +CONTINUE_SERVER_FOLDER = "server" +CONTINUE_SERVER_VERSION_FILE = "server_version.txt" diff --git a/server/continuedev/libs/llm/__init__.py b/server/continuedev/libs/llm/__init__.py new file mode 100644 index 00000000..829ffede --- /dev/null +++ b/server/continuedev/libs/llm/__init__.py @@ -0,0 +1,14 @@ +from .anthropic import AnthropicLLM # noqa: F401 +from .ggml import GGML # noqa: F401 +from .google_palm_api import GooglePaLMAPI # noqa: F401 +from .hf_inference_api import HuggingFaceInferenceAPI # noqa: F401 +from .hf_tgi import HuggingFaceTGI # noqa: F401 +from .llamacpp import LlamaCpp # noqa: F401 +from .ollama import Ollama # noqa: F401 +from .openai import OpenAI # noqa: F401 +from .openai_free_trial import OpenAIFreeTrial # noqa: F401 +from .proxy_server import ProxyServer # noqa: F401 +from .queued import QueuedLLM # noqa: F401 +from .replicate import ReplicateLLM # noqa: F401 +from .text_gen_interface import TextGenUI # noqa: F401 +from .together import TogetherLLM # noqa: F401 diff --git a/server/continuedev/libs/llm/anthropic.py b/server/continuedev/libs/llm/anthropic.py new file mode 100644 index 00000000..7d0708f1 --- /dev/null +++ b/server/continuedev/libs/llm/anthropic.py @@ -0,0 +1,74 @@ +from typing import Any, Callable, Coroutine + +from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic + +from .base import LLM, CompletionOptions +from .prompts.chat import anthropic_template_messages + + +class AnthropicLLM(LLM): + """ + Import the `AnthropicLLM` class and set it as the default model: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.anthropic import AnthropicLLM + + config = ContinueConfig( + ... + models=Models( + default=AnthropicLLM(api_key="<API_KEY>", model="claude-2") + ) + ) + ``` + + Claude 2 is not yet publicly released. You can request early access [here](https://www.anthropic.com/earlyaccess). 
+ + """ + + api_key: str + "Anthropic API key" + + model: str = "claude-2" + + _async_client: AsyncAnthropic = None + + template_messages: Callable = anthropic_template_messages + + class Config: + arbitrary_types_allowed = True + + async def start(self, **kwargs): + await super().start(**kwargs) + self._async_client = AsyncAnthropic(api_key=self.api_key) + + if self.model == "claude-2": + self.context_length = 100_000 + + def collect_args(self, options: CompletionOptions): + options.stop = None + args = super().collect_args(options) + + if "max_tokens" in args: + args["max_tokens_to_sample"] = args["max_tokens"] + del args["max_tokens"] + if "frequency_penalty" in args: + del args["frequency_penalty"] + if "presence_penalty" in args: + del args["presence_penalty"] + return args + + async def _stream_complete(self, prompt: str, options): + args = self.collect_args(options) + prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" + + async for chunk in await self._async_client.completions.create( + prompt=prompt, stream=True, **args + ): + yield chunk.completion + + async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: + args = self.collect_args(options) + prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" + return ( + await self._async_client.completions.create(prompt=prompt, **args) + ).completion diff --git a/server/continuedev/libs/llm/base.py b/server/continuedev/libs/llm/base.py new file mode 100644 index 00000000..d77cb9fc --- /dev/null +++ b/server/continuedev/libs/llm/base.py @@ -0,0 +1,458 @@ +import ssl +from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union + +import aiohttp +import certifi +from pydantic import Field, validator + +from ...core.main import ChatMessage +from ...models.main import ContinueBaseModel +from ..util.count_tokens import ( + DEFAULT_ARGS, + DEFAULT_MAX_TOKENS, + compile_chat_messages, + count_tokens, + format_chat_messages, + prune_raw_prompt_from_top, +) +from ..util.devdata import dev_data_logger +from ..util.telemetry import posthog_logger + + +class CompletionOptions(ContinueBaseModel): + """Options for the completion.""" + + @validator( + "*", + pre=True, + always=True, + ) + def ignore_none_and_set_default(cls, value, field): + return value if value is not None else field.default + + model: Optional[str] = Field(None, description="The model name") + temperature: Optional[float] = Field( + None, description="The temperature of the completion." + ) + top_p: Optional[float] = Field(None, description="The top_p of the completion.") + top_k: Optional[int] = Field(None, description="The top_k of the completion.") + presence_penalty: Optional[float] = Field( + None, description="The presence penalty Aof the completion." + ) + frequency_penalty: Optional[float] = Field( + None, description="The frequency penalty of the completion." + ) + stop: Optional[List[str]] = Field( + None, description="The stop tokens of the completion." + ) + max_tokens: int = Field( + DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate." + ) + functions: Optional[List[Any]] = Field( + None, description="The functions/tools to make available to the model." + ) + + +class LLM(ContinueBaseModel): + title: Optional[str] = Field( + None, + description="A title that will identify this model in the model selection dropdown", + ) + + unique_id: Optional[str] = Field(None, description="The unique ID of the user.") + model: str = Field( + ..., description="The name of the model to be used (e.g. 
gpt-4, codellama)" + ) + + system_message: Optional[str] = Field( + None, description="A system message that will always be followed by the LLM" + ) + + context_length: int = Field( + 2048, + description="The maximum context length of the LLM in tokens, as counted by count_tokens.", + ) + + stop_tokens: Optional[List[str]] = Field( + None, description="Tokens that will stop the completion." + ) + temperature: Optional[float] = Field( + None, description="The temperature of the completion." + ) + top_p: Optional[float] = Field(None, description="The top_p of the completion.") + top_k: Optional[int] = Field(None, description="The top_k of the completion.") + presence_penalty: Optional[float] = Field( + None, description="The presence penalty Aof the completion." + ) + frequency_penalty: Optional[float] = Field( + None, description="The frequency penalty of the completion." + ) + + timeout: Optional[int] = Field( + 300, + description="Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts.", + ) + verify_ssl: Optional[bool] = Field( + None, description="Whether to verify SSL certificates for requests." + ) + ca_bundle_path: str = Field( + None, + description="Path to a custom CA bundle to use when making the HTTP request", + ) + proxy: Optional[str] = Field( + None, + description="Proxy URL to use when making the HTTP request", + ) + headers: Optional[Dict[str, str]] = Field( + None, + description="Headers to use when making the HTTP request", + ) + prompt_templates: dict = Field( + {}, + description='A dictionary of prompt templates that can be used to customize the behavior of the LLM in certain situations. For example, set the "edit" key in order to change the prompt that is used for the /edit slash command. Each value in the dictionary is a string templated in mustache syntax, and filled in at runtime with the variables specific to the situation. See the documentation for more information.', + ) + + template_messages: Optional[Callable[[List[Dict[str, str]]], str]] = Field( + None, + description="A function that takes a list of messages and returns a prompt. This ensures that models like llama2, which are trained on specific chat formats, will always receive input in that format.", + ) + write_log: Optional[Callable[[str], None]] = Field( + None, + description="A function that is called upon every prompt and completion, by default to log to the file which can be viewed by clicking on the magnifying glass.", + ) + + api_key: Optional[str] = Field( + None, description="The API key for the LLM provider." + ) + + class Config: + arbitrary_types_allowed = True + extra = "allow" + fields = { + "title": { + "description": "A title that will identify this model in the model selection dropdown" + }, + "system_message": { + "description": "A system message that will always be followed by the LLM" + }, + "context_length": { + "description": "The maximum context length of the LLM in tokens, as counted by count_tokens." + }, + "unique_id": {"description": "The unique ID of the user."}, + "model": { + "description": "The name of the model to be used (e.g. gpt-4, codellama)" + }, + "timeout": { + "description": "Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts." 
+ }, + "prompt_templates": { + "description": 'A dictionary of prompt templates that can be used to customize the behavior of the LLM in certain situations. For example, set the "edit" key in order to change the prompt that is used for the /edit slash command. Each value in the dictionary is a string templated in mustache syntax, and filled in at runtime with the variables specific to the situation. See the documentation for more information.' + }, + "template_messages": { + "description": "A function that takes a list of messages and returns a prompt. This ensures that models like llama2, which are trained on specific chat formats, will always receive input in that format." + }, + "write_log": { + "description": "A function that is called upon every prompt and completion, by default to log to the file which can be viewed by clicking on the magnifying glass." + }, + "api_key": {"description": "The API key for the LLM provider."}, + "verify_ssl": { + "description": "Whether to verify SSL certificates for requests." + }, + "ca_bundle_path": { + "description": "Path to a custom CA bundle to use when making the HTTP request" + }, + "headers": { + "description": "Headers to use when making the HTTP request" + }, + "proxy": {"description": "Proxy URL to use when making the HTTP request"}, + "stop_tokens": {"description": "Tokens that will stop the completion."}, + "temperature": { + "description": "The sampling temperature used for generation." + }, + "top_p": { + "description": "The top_p sampling parameter used for generation." + }, + "top_k": { + "description": "The top_k sampling parameter used for generation." + }, + "presence_penalty": { + "description": "The presence penalty used for completions." + }, + "frequency_penalty": { + "description": "The frequency penalty used for completions." 
+ }, + } + + def dict(self, **kwargs): + original_dict = super().dict(**kwargs) + original_dict.pop("write_log") + if self.template_messages is not None: + original_dict["template_messages"] = self.template_messages.__name__ + original_dict.pop("unique_id") + original_dict["class_name"] = self.__class__.__name__ + return original_dict + + async def start( + self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None + ): + """Start the connection to the LLM.""" + self.write_log = write_log + self.unique_id = unique_id + + async def stop(self): + """Stop the connection to the LLM.""" + pass + + def create_client_session(self): + if self.verify_ssl is False: + return aiohttp.ClientSession( + connector=aiohttp.TCPConnector(verify_ssl=False), + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers + ) + else: + ca_bundle_path = ( + certifi.where() if self.ca_bundle_path is None else self.ca_bundle_path + ) + ssl_context = ssl.create_default_context(cafile=ca_bundle_path) + return aiohttp.ClientSession( + connector=aiohttp.TCPConnector(ssl_context=ssl_context), + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) + + def collect_args(self, options: CompletionOptions) -> Dict[str, Any]: + """Collect the arguments for the LLM.""" + args = {**DEFAULT_ARGS.copy(), "model": self.model} + args.update(options.dict(exclude_unset=True, exclude_none=True)) + return args + + def compile_chat_messages( + self, + options: CompletionOptions, + msgs: List[ChatMessage], + functions: Optional[List[Any]] = None, + ) -> List[Dict]: + return compile_chat_messages( + model_name=options.model, + msgs=msgs, + context_length=self.context_length, + max_tokens=options.max_tokens, + functions=functions, + system_message=self.system_message, + ) + + def template_prompt_like_messages(self, prompt: str) -> str: + if self.template_messages is None: + return prompt + + msgs = [{"role": "user", "content": prompt}] + if self.system_message is not None: + msgs.insert(0, {"role": "system", "content": self.system_message}) + + return self.template_messages(msgs) + + async def stream_complete( + self, + prompt: str, + raw: bool = False, + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, + log: bool = True, + ) -> Generator[Union[Any, List, Dict], None, None]: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature or self.temperature, + top_p=top_p or self.top_p, + top_k=top_k or self.top_k, + presence_penalty=presence_penalty or self.presence_penalty, + frequency_penalty=frequency_penalty or self.frequency_penalty, + stop=stop or self.stop_tokens, + max_tokens=max_tokens, + functions=functions, + ) + + prompt = prune_raw_prompt_from_top( + self.model, self.context_length, prompt, options.max_tokens + ) + + if not raw: + prompt = self.template_prompt_like_messages(prompt) + + if log: + self.write_log(prompt) + + completion = "" + async for chunk in self._stream_complete(prompt=prompt, options=options): + yield chunk + completion += chunk + + # if log: + # self.write_log(f"Completion: \n\n{completion}") + + dev_data_logger.capture( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + posthog_logger.capture_event( + 
"tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + + async def complete( + self, + prompt: str, + raw: bool = False, + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, + log: bool = True, + ) -> str: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature or self.temperature, + top_p=top_p or self.top_p, + top_k=top_k or self.top_k, + presence_penalty=presence_penalty or self.presence_penalty, + frequency_penalty=frequency_penalty or self.frequency_penalty, + stop=stop or self.stop_tokens, + max_tokens=max_tokens, + functions=functions, + ) + + prompt = prune_raw_prompt_from_top( + self.model, self.context_length, prompt, options.max_tokens + ) + + if not raw: + prompt = self.template_prompt_like_messages(prompt) + + if log: + self.write_log(prompt) + + completion = await self._complete(prompt=prompt, options=options) + + # if log: + # self.write_log(f"Completion: \n\n{completion}") + + dev_data_logger.capture( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + posthog_logger.capture_event( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + + return completion + + async def stream_chat( + self, + messages: List[ChatMessage], + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, + log: bool = True, + ) -> Generator[Union[Any, List, Dict], None, None]: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature or self.temperature, + top_p=top_p or self.top_p, + top_k=top_k or self.top_k, + presence_penalty=presence_penalty or self.presence_penalty, + frequency_penalty=frequency_penalty or self.frequency_penalty, + stop=stop or self.stop_tokens, + max_tokens=max_tokens, + functions=functions, + ) + + messages = self.compile_chat_messages( + options=options, msgs=messages, functions=functions + ) + if self.template_messages is not None: + prompt = self.template_messages(messages) + else: + prompt = format_chat_messages(messages) + + if log: + self.write_log(prompt) + + completion = "" + + # Use the template_messages function if it exists and do a raw completion + if self.template_messages is None: + async for chunk in self._stream_chat(messages=messages, options=options): + yield chunk + if "content" in chunk: + completion += chunk["content"] + else: + async for chunk in self._stream_complete(prompt=prompt, options=options): + yield {"role": "assistant", "content": chunk} + completion += chunk + + # if log: + # self.write_log(f"Completion: \n\n{completion}") + + dev_data_logger.capture( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + posthog_logger.capture_event( + "tokens_generated", + {"model": self.model, "tokens": self.count_tokens(completion)}, + ) + + def _stream_complete( + self, prompt, options: CompletionOptions + ) -> Generator[str, None, None]: + """Stream the completion through generator.""" + raise 
NotImplementedError + + async def _complete( + self, prompt: str, options: CompletionOptions + ) -> Coroutine[Any, Any, str]: + """Return the completion of the text with the given temperature.""" + completion = "" + async for chunk in self._stream_complete(prompt=prompt, options=options): + completion += chunk + return completion + + async def _stream_chat( + self, messages: List[ChatMessage], options: CompletionOptions + ) -> Generator[Union[Any, List, Dict], None, None]: + """Stream the chat through generator.""" + if self.template_messages is None: + raise NotImplementedError( + "You must either implement template_messages or _stream_chat" + ) + + async for chunk in self._stream_complete( + prompt=self.template_messages(messages), options=options + ): + yield {"role": "assistant", "content": chunk} + + def count_tokens(self, text: str): + """Return the number of tokens in the given text.""" + return count_tokens(self.model, text) diff --git a/server/continuedev/libs/llm/ggml.py b/server/continuedev/libs/llm/ggml.py new file mode 100644 index 00000000..55d580a8 --- /dev/null +++ b/server/continuedev/libs/llm/ggml.py @@ -0,0 +1,226 @@ +import json +from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional + +from pydantic import Field + +from ...core.main import ChatMessage +from ..util.logging import logger +from .base import LLM, CompletionOptions +from .openai import CHAT_MODELS +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class GGML(LLM): + """ + See our [5 minute quickstart](https://github.com/continuedev/ggml-server-example) to run any model locally with ggml. While these models don't yet perform as well, they are free, entirely private, and run offline. + + Once the model is running on localhost:8000, change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.ggml import GGML + + config = ContinueConfig( + ... + models=Models( + default=GGML( + max_context_length=2048, + server_url="http://localhost:8000") + ) + ) + ``` + """ + + server_url: str = Field( + "http://localhost:8000", + description="URL of the OpenAI-compatible server where the model is being served", + ) + model: str = Field( + "ggml", description="The name of the model to use (optional for the GGML class)" + ) + + api_base: Optional[str] = Field(None, description="OpenAI API base URL.") + + api_type: Optional[Literal["azure", "openai"]] = Field( + None, description="OpenAI API type." + ) + + api_version: Optional[str] = Field( + None, description="OpenAI API version. For use with Azure OpenAI Service." + ) + + engine: Optional[str] = Field( + None, description="OpenAI engine. For use with Azure OpenAI Service." 
+ ) + + template_messages: Optional[ + Callable[[List[Dict[str, str]]], str] + ] = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def get_headers(self): + headers = { + "Content-Type": "application/json", + } + if self.api_key is not None: + if self.api_type == "azure": + headers["api-key"] = self.api_key + else: + headers["Authorization"] = f"Bearer {self.api_key}" + + return headers + + def get_full_server_url(self, endpoint: str): + endpoint = endpoint.lstrip("/").rstrip("/") + + if self.api_type == "azure": + if self.engine is None or self.api_version is None or self.api_base is None: + raise Exception( + "For Azure OpenAI Service, you must specify engine, api_version, and api_base." + ) + + return f"{self.api_base}/openai/deployments/{self.engine}/{endpoint}?api-version={self.api_version}" + else: + return f"{self.server_url}/v1/{endpoint}" + + async def _raw_stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self.create_client_session() as client_session: + async with client_session.post( + self.get_full_server_url(endpoint="completions"), + json={ + "prompt": prompt, + "stream": True, + **args, + }, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception( + f"Error calling /chat/completions endpoint: {resp.status}" + ) + + async for line in resp.content.iter_any(): + if line: + chunks = line.decode("utf-8") + for chunk in chunks.split("\n"): + if ( + chunk.startswith(": ping - ") + or chunk.startswith("data: [DONE]") + or chunk.strip() == "" + ): + continue + elif chunk.startswith("data: "): + chunk = chunk[6:] + try: + j = json.loads(chunk) + except Exception: + continue + if ( + "choices" in j + and len(j["choices"]) > 0 + and "text" in j["choices"][0] + ): + yield j["choices"][0]["text"] + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + + async def generator(): + async with self.create_client_session() as client_session: + async with client_session.post( + self.get_full_server_url(endpoint="chat/completions"), + json={"messages": messages, "stream": True, **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception( + f"Error calling /chat/completions endpoint: {resp.status}" + ) + + async for line, end in resp.content.iter_chunks(): + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if ( + chunk.strip() == "" + or json_chunk.startswith(": ping - ") + or json_chunk.startswith("data: [DONE]") + ): + continue + try: + yield json.loads(chunk[6:])["choices"][0]["delta"] + except: + pass + + # Because quite often the first attempt fails, and it works thereafter + try: + async for chunk in generator(): + yield chunk + except Exception as e: + logger.warning(f"Error calling /chat/completions endpoint: {e}") + async for chunk in generator(): + yield chunk + + async def _raw_complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: + args = self.collect_args(options) + + async with self.create_client_session() as client_session: + async with client_session.post( + self.get_full_server_url(endpoint="completions"), + json={ + "prompt": prompt, + **args, + }, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception( + f"Error calling /chat/completions endpoint: {resp.status}" + ) + + text = await 
resp.text() + try: + completion = json.loads(text)["choices"][0]["text"] + return completion + except Exception as e: + raise Exception( + f"Error calling /completion endpoint: {e}\n\nResponse text: {text}" + ) + + async def _complete(self, prompt: str, options: CompletionOptions): + completion = "" + if self.model in CHAT_MODELS: + async for chunk in self._stream_chat( + [{"role": "user", "content": prompt}], options + ): + if "content" in chunk: + completion += chunk["content"] + + else: + async for chunk in self._raw_stream_complete(prompt, options): + completion += chunk + + return completion + + async def _stream_complete(self, prompt, options: CompletionOptions): + if self.model in CHAT_MODELS: + async for chunk in self._stream_chat( + [{"role": "user", "content": prompt}], options + ): + if "content" in chunk: + yield chunk["content"] + + else: + async for chunk in self._raw_stream_complete(prompt, options): + yield chunk diff --git a/server/continuedev/libs/llm/google_palm_api.py b/server/continuedev/libs/llm/google_palm_api.py new file mode 100644 index 00000000..3379fefe --- /dev/null +++ b/server/continuedev/libs/llm/google_palm_api.py @@ -0,0 +1,50 @@ +from typing import List + +import requests +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM + + +class GooglePaLMAPI(LLM): + """ + The Google PaLM API is currently in public preview, so production applications are not supported yet. However, you can [create an API key in Google MakerSuite](https://makersuite.google.com/u/2/app/apikey) and begin trying out the `chat-bison-001` model. Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.hf_inference_api import GooglePaLMAPI + + config = ContinueConfig( + ... + models=Models( + default=GooglePaLMAPI( + model="chat-bison-001" + api_key="<MAKERSUITE_API_KEY>", + ) + ) + ``` + """ + + api_key: str = Field(..., description="Google PaLM API key") + + model: str = "chat-bison-001" + + async def _stream_complete(self, prompt, options): + api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}" + body = {"prompt": {"messages": [{"content": prompt}]}} + response = requests.post(api_url, json=body) + yield response.json()["candidates"][0]["content"] + + async def _stream_chat(self, messages: List[ChatMessage], options): + msg_lst = [] + for message in messages: + msg_lst.append({"content": message["content"]}) + + api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}" + body = {"prompt": {"messages": msg_lst}} + response = requests.post(api_url, json=body) + yield { + "content": response.json()["candidates"][0]["content"], + "role": "assistant", + } diff --git a/server/continuedev/libs/llm/hf_inference_api.py b/server/continuedev/libs/llm/hf_inference_api.py new file mode 100644 index 00000000..990ec7c8 --- /dev/null +++ b/server/continuedev/libs/llm/hf_inference_api.py @@ -0,0 +1,78 @@ +from typing import Callable, Dict, List, Union + +from huggingface_hub import InferenceClient +from pydantic import Field + +from .base import LLM, CompletionOptions +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class HuggingFaceInferenceAPI(LLM): + """ + Hugging Face Inference API is a great option for newly released language models. 
Sign up for an account and add billing [here](https://huggingface.co/settings/billing), access the Inference Endpoints [here](https://ui.endpoints.huggingface.co), click on “New endpoint”, and fill out the form (e.g. select a model like [WizardCoder-Python-34B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0)), and then deploy your model by clicking “Create Endpoint”. Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI + + config = ContinueConfig( + ... + models=Models( + default=HuggingFaceInferenceAPI( + endpoint_url="<INFERENCE_API_ENDPOINT_URL>", + hf_token="<HUGGING_FACE_TOKEN>", + ) + ) + ``` + """ + + model: str = Field( + "Hugging Face Inference API", + description="The name of the model to use (optional for the HuggingFaceInferenceAPI class)", + ) + hf_token: str = Field(..., description="Your Hugging Face API token") + endpoint_url: str = Field( + None, description="Your Hugging Face Inference API endpoint URL" + ) + + template_messages: Union[ + Callable[[List[Dict[str, str]]], str], None + ] = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options: CompletionOptions): + options.stop = None + args = super().collect_args(options) + + if "max_tokens" in args: + args["max_new_tokens"] = args["max_tokens"] + del args["max_tokens"] + if "stop" in args: + args["stop_sequences"] = args["stop"] + del args["stop"] + + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + client = InferenceClient(self.endpoint_url, token=self.hf_token) + + stream = client.text_generation(prompt, stream=True, details=True, **args) + + for r in stream: + # skip special tokens + if r.token.special: + continue + # stop if we encounter a stop sequence + if options.stop is not None: + if r.token.text in options.stop: + break + yield r.token.text diff --git a/server/continuedev/libs/llm/hf_tgi.py b/server/continuedev/libs/llm/hf_tgi.py new file mode 100644 index 00000000..62458db4 --- /dev/null +++ b/server/continuedev/libs/llm/hf_tgi.py @@ -0,0 +1,65 @@ +import json +from typing import Any, Callable, List + +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM, CompletionOptions +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class HuggingFaceTGI(LLM): + model: str = "huggingface-tgi" + server_url: str = Field( + "http://localhost:8080", description="URL of your TGI server" + ) + + template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options: CompletionOptions) -> Any: + args = super().collect_args(options) + args = {**args, "max_new_tokens": args.get("max_tokens", 1024), "best_of": 1} + args.pop("max_tokens", None) + args.pop("model", None) + args.pop("functions", None) + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self.create_client_session() as client_session: + async with client_session.post( + f"{self.server_url}/generate_stream", + json={"inputs": prompt, "parameters": args}, + headers={"Content-Type": "application/json"}, + 
proxy=self.proxy, + ) as resp: + async for line in resp.content.iter_any(): + if line: + text = line.decode("utf-8") + chunks = text.split("\n") + + for chunk in chunks: + if chunk.startswith("data: "): + chunk = chunk[len("data: ") :] + elif chunk.startswith("data:"): + chunk = chunk[len("data:") :] + + if chunk.strip() == "": + continue + + try: + json_chunk = json.loads(chunk) + except Exception as e: + print(f"Error parsing JSON: {e}") + continue + + yield json_chunk["token"]["text"] diff --git a/server/continuedev/libs/llm/hugging_face.py b/server/continuedev/libs/llm/hugging_face.py new file mode 100644 index 00000000..c2e934c0 --- /dev/null +++ b/server/continuedev/libs/llm/hugging_face.py @@ -0,0 +1,19 @@ +# TODO: This class is far out of date + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .llm import LLM + + +class HuggingFace(LLM): + def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"): + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = AutoModelForCausalLM.from_pretrained(model_path) + + def complete(self, prompt: str, **kwargs): + args = {"max_tokens": 100} + args.update(kwargs) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"]) + return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) diff --git a/server/continuedev/libs/llm/llamacpp.py b/server/continuedev/libs/llm/llamacpp.py new file mode 100644 index 00000000..bc856a52 --- /dev/null +++ b/server/continuedev/libs/llm/llamacpp.py @@ -0,0 +1,86 @@ +import json +from typing import Any, Callable, Dict + +from pydantic import Field + +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class LlamaCpp(LLM): + """ + Run the llama.cpp server binary to start the API server. If running on a remote server, be sure to set host to 0.0.0.0: + + ```shell + .\server.exe -c 4096 --host 0.0.0.0 -t 16 --mlock -m models\meta\llama\codellama-7b-instruct.Q8_0.gguf + ``` + + After it's up and running, change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.llamacpp import LlamaCpp + + config = ContinueConfig( + ... + models=Models( + default=LlamaCpp( + max_context_length=4096, + server_url="http://localhost:8080") + ) + ) + ``` + """ + + model: str = "llamacpp" + server_url: str = Field("http://localhost:8080", description="URL of the server") + + llama_cpp_args: Dict[str, Any] = Field( + {"stop": ["[INST]"]}, + description="A list of additional arguments to pass to llama.cpp. 
See [here](https://github.com/ggerganov/llama.cpp/tree/master/examples/server#api-endpoints) for the complete catalog of options.", + ) + + template_messages: Callable = llama2_template_messages + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options) -> Any: + args = super().collect_args(options) + if "max_tokens" in args: + args["n_predict"] = args["max_tokens"] + del args["max_tokens"] + if "frequency_penalty" in args: + del args["frequency_penalty"] + if "presence_penalty" in args: + del args["presence_penalty"] + + for k, v in self.llama_cpp_args.items(): + if k not in args: + args[k] = v + + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + headers = {"Content-Type": "application/json"} + + async def server_generator(): + async with self.create_client_session() as client_session: + async with client_session.post( + f"{self.server_url}/completion", + json={"prompt": prompt, "stream": True, **args}, + headers=headers, + proxy=self.proxy, + ) as resp: + async for line in resp.content: + content = line.decode("utf-8") + if content.strip() == "": + continue + yield json.loads(content[6:])["content"] + + async for chunk in server_generator(): + yield chunk diff --git a/server/continuedev/libs/llm/ollama.py b/server/continuedev/libs/llm/ollama.py new file mode 100644 index 00000000..82cbc852 --- /dev/null +++ b/server/continuedev/libs/llm/ollama.py @@ -0,0 +1,106 @@ +import json +from typing import Callable + +import aiohttp +from pydantic import Field + +from ...core.main import ContinueCustomException +from ..util.logging import logger +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class Ollama(LLM): + """ + [Ollama](https://ollama.ai/) is an application for Mac and Linux that makes it easy to locally run open-source models, including Llama-2. Download the app from the website, and it will walk you through setup in a couple of minutes. You can also read more in their [README](https://github.com/jmorganca/ollama). Continue can then be configured to use the `Ollama` LLM class: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.ollama import Ollama + + config = ContinueConfig( + ... 
+ models=Models( + default=Ollama(model="llama2") + ) + ) + ``` + """ + + model: str = "llama2" + server_url: str = Field( + "http://localhost:11434", description="URL of the Ollama server" + ) + + _client_session: aiohttp.ClientSession = None + + template_messages: Callable = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + class Config: + arbitrary_types_allowed = True + + async def start(self, **kwargs): + await super().start(**kwargs) + self._client_session = self.create_client_session() + try: + async with self._client_session.post( + f"{self.server_url}/api/generate", + proxy=self.proxy, + json={ + "prompt": "", + "model": self.model, + }, + ) as _: + pass + except Exception as e: + logger.warning(f"Error pre-loading Ollama model: {e}") + + async def stop(self): + await self._client_session.close() + + async def get_downloaded_models(self): + async with self._client_session.get( + f"{self.server_url}/api/tags", + proxy=self.proxy, + ) as resp: + js_data = await resp.json() + return list(map(lambda x: x["name"], js_data["models"])) + + async def _stream_complete(self, prompt, options): + async with self._client_session.post( + f"{self.server_url}/api/generate", + json={ + "template": prompt, + "model": self.model, + "system": self.system_message, + "options": {"temperature": options.temperature}, + }, + proxy=self.proxy, + ) as resp: + if resp.status == 400: + txt = await resp.text() + extra_msg = "" + if "no such file" in txt: + extra_msg = f"\n\nThis means that the model '{self.model}' is not downloaded.\n\nYou have the following models downloaded: {', '.join(await self.get_downloaded_models())}.\n\nTo download this model, run `ollama run {self.model}` in your terminal." + raise ContinueCustomException( + f"Ollama returned an error: {txt}{extra_msg}", + "Invalid request to Ollama", + ) + elif resp.status != 200: + raise ContinueCustomException( + f"Ollama returned an error: {await resp.text()}", + "Invalid request to Ollama", + ) + async for line in resp.content.iter_any(): + if line: + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + yield j["response"] diff --git a/server/continuedev/libs/llm/openai.py b/server/continuedev/libs/llm/openai.py new file mode 100644 index 00000000..ba29279b --- /dev/null +++ b/server/continuedev/libs/llm/openai.py @@ -0,0 +1,156 @@ +from typing import Callable, List, Literal, Optional + +import certifi +import openai +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM + +CHAT_MODELS = { + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-4", + "gpt-3.5-turbo-0613", + "gpt-4-32k", +} +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16_384, + "gpt-4": 8192, + "gpt-35-turbo-16k": 16_384, + "gpt-35-turbo-0613": 4096, + "gpt-35-turbo": 4096, + "gpt-4-32k": 32_768, +} + + +class OpenAI(LLM): + """ + The OpenAI class can be used to access OpenAI models like gpt-4 and gpt-3.5-turbo. + + If you are locally serving a model that uses an OpenAI-compatible server, you can simply change the `api_base` in the `OpenAI` class like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.openai import OpenAI + + config = ContinueConfig( + ... 
+ models=Models( + default=OpenAI( + api_key="EMPTY", + model="<MODEL_NAME>", + api_base="http://localhost:8000", # change to your server + ) + ) + ) + ``` + + Options for serving models locally with an OpenAI-compatible server include: + + - [text-gen-webui](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai#setup--installation) + - [FastChat](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) + - [LocalAI](https://localai.io/basics/getting_started/) + - [llama-cpp-python](https://github.com/abetlen/llama-cpp-python#web-server) + """ + + api_key: str = Field( + ..., + description="OpenAI API key", + ) + + proxy: Optional[str] = Field(None, description="Proxy URL to use for requests.") + + api_base: Optional[str] = Field(None, description="OpenAI API base URL.") + + api_type: Optional[Literal["azure", "openai"]] = Field( + None, description="OpenAI API type." + ) + + api_version: Optional[str] = Field( + None, description="OpenAI API version. For use with Azure OpenAI Service." + ) + + engine: Optional[str] = Field( + None, description="OpenAI engine. For use with Azure OpenAI Service." + ) + + async def start( + self, unique_id: Optional[str] = None, write_log: Callable[[str], None] = None + ): + await super().start(write_log=write_log, unique_id=unique_id) + + if self.context_length is None: + self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096) + + openai.api_key = self.api_key + if self.api_type is not None: + openai.api_type = self.api_type + if self.api_base is not None: + openai.api_base = self.api_base + if self.api_version is not None: + openai.api_version = self.api_version + + if self.verify_ssl is not None and self.verify_ssl is False: + openai.verify_ssl_certs = False + + if self.proxy is not None: + openai.proxy = self.proxy + + openai.ca_bundle_path = self.ca_bundle_path or certifi.where() + + def collect_args(self, options): + args = super().collect_args(options) + if self.engine is not None: + args["engine"] = self.engine + + if not args["model"].endswith("0613") and "functions" in args: + del args["functions"] + + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + args["stream"] = True + + if args["model"] in CHAT_MODELS: + async for chunk in await openai.ChatCompletion.acreate( + messages=[{"role": "user", "content": prompt}], + **args, + headers=self.headers, + ): + if len(chunk.choices) > 0 and "content" in chunk.choices[0].delta: + yield chunk.choices[0].delta.content + else: + async for chunk in await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers): + if len(chunk.choices) > 0: + yield chunk.choices[0].text + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + + async for chunk in await openai.ChatCompletion.acreate( + messages=messages, + stream=True, + **args, + headers=self.headers, + ): + if not hasattr(chunk, "choices") or len(chunk.choices) == 0: + continue + yield chunk.choices[0].delta + + async def _complete(self, prompt: str, options): + args = self.collect_args(options) + + if args["model"] in CHAT_MODELS: + resp = await openai.ChatCompletion.acreate( + messages=[{"role": "user", "content": prompt}], + **args, + headers=self.headers, + ) + return resp.choices[0].message.content + else: + return ( + (await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers)).choices[0].text + ) diff --git a/server/continuedev/libs/llm/openai_free_trial.py 
b/server/continuedev/libs/llm/openai_free_trial.py new file mode 100644 index 00000000..b6e707f9 --- /dev/null +++ b/server/continuedev/libs/llm/openai_free_trial.py @@ -0,0 +1,83 @@ +from typing import Callable, List, Optional + +from ...core.main import ChatMessage +from .base import LLM +from .openai import OpenAI +from .proxy_server import ProxyServer + + +class OpenAIFreeTrial(LLM): + """ + With the `OpenAIFreeTrial` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code. + + Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps: + + 1. Copy your API key from https://platform.openai.com/account/api-keys + 2. Open `~/.continue/config.py`. You can do this by using the '/config' command in Continue + 3. Change the default LLMs to look like this: + + ```python title="~/.continue/config.py" + API_KEY = "<API_KEY>" + config = ContinueConfig( + ... + models=Models( + default=OpenAIFreeTrial(model="gpt-4", api_key=API_KEY), + summarize=OpenAIFreeTrial(model="gpt-3.5-turbo", api_key=API_KEY) + ) + ) + ``` + + The `OpenAIFreeTrial` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead. + + These classes support any models available through the OpenAI API, assuming your API key has access, including "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", and "gpt-4-32k". + """ + + api_key: Optional[str] = None + + llm: Optional[LLM] = None + + def update_llm_properties(self): + if self.llm is not None: + self.llm.system_message = self.system_message + + async def start( + self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None + ): + await super().start(write_log=write_log, unique_id=unique_id) + if self.api_key is None or self.api_key.strip() == "": + self.llm = ProxyServer( + model=self.model, + verify_ssl=self.verify_ssl, + ca_bundle_path=self.ca_bundle_path, + ) + else: + self.llm = OpenAI( + api_key=self.api_key, + model=self.model, + verify_ssl=self.verify_ssl, + ca_bundle_path=self.ca_bundle_path, + ) + + await self.llm.start(write_log=write_log, unique_id=unique_id) + + async def stop(self): + await self.llm.stop() + + async def _complete(self, prompt: str, options): + self.update_llm_properties() + return await self.llm._complete(prompt, options) + + async def _stream_complete(self, prompt, options): + self.update_llm_properties() + resp = self.llm._stream_complete(prompt, options) + async for item in resp: + yield item + + async def _stream_chat(self, messages: List[ChatMessage], options): + self.update_llm_properties() + resp = self.llm._stream_chat(messages=messages, options=options) + async for item in resp: + yield item + + def count_tokens(self, text: str): + return self.llm.count_tokens(text) diff --git a/server/continuedev/libs/llm/prompt_utils.py b/server/continuedev/libs/llm/prompt_utils.py new file mode 100644 index 00000000..930b5220 --- /dev/null +++ b/server/continuedev/libs/llm/prompt_utils.py @@ -0,0 +1,76 @@ +from typing import Dict, List, Union + +from ...models.filesystem import RangeInFileWithContents +from ...models.filesystem_edit import FileEdit + + +class MarkdownStyleEncoderDecoder: + # Filename -> the part of the file you care about + range_in_files: List[RangeInFileWithContents] + + def 
__init__(self, range_in_files: List[RangeInFileWithContents]): + self.range_in_files = range_in_files + + def encode(self) -> str: + return "\n\n".join( + [ + f"File ({rif.filepath})\n```\n{rif.contents}\n```" + for rif in self.range_in_files + ] + ) + + def _suggestions_to_file_edits(self, suggestions: Dict[str, str]) -> List[FileEdit]: + file_edits: List[FileEdit] = [] + for suggestion_filepath, suggestion in suggestions.items(): + matching_rifs = list( + filter(lambda r: r.filepath == suggestion_filepath, self.range_in_files) + ) + if len(matching_rifs) > 0: + range_in_file = matching_rifs[0] + file_edits.append( + FileEdit( + range=range_in_file.range, + filepath=range_in_file.filepath, + replacement=suggestion, + ) + ) + + return file_edits + + def _decode_to_suggestions(self, completion: str) -> Dict[str, str]: + if len(self.range_in_files) == 0: + return {} + + if "```" not in completion: + completion = "```\n" + completion + "\n```" + if completion.strip().splitlines()[0].strip() == "```": + first_filepath = self.range_in_files[0].filepath + completion = f"File ({first_filepath})\n" + completion + + suggestions: Dict[str, str] = {} + current_file_lines: List[str] = [] + current_filepath: Union[str, None] = None + last_was_file = False + inside_file = False + for line in completion.splitlines(): + if line.strip().startswith("File ("): + last_was_file = True + current_filepath = line.strip()[6:-1] + elif last_was_file and line.startswith("```"): + last_was_file = False + inside_file = True + elif inside_file: + if line.startswith("```"): + inside_file = False + suggestions[current_filepath] = "\n".join(current_file_lines) + current_file_lines = [] + current_filepath = None + else: + current_file_lines.append(line) + + return suggestions + + def decode(self, completion: str) -> List[FileEdit]: + suggestions = self._decode_to_suggestions(completion) + file_edits = self._suggestions_to_file_edits(suggestions) + return file_edits diff --git a/server/continuedev/libs/llm/prompts/chat.py b/server/continuedev/libs/llm/prompts/chat.py new file mode 100644 index 00000000..036f1b1a --- /dev/null +++ b/server/continuedev/libs/llm/prompts/chat.py @@ -0,0 +1,174 @@ +from textwrap import dedent +from typing import Dict, List + +from anthropic import AI_PROMPT, HUMAN_PROMPT + + +def anthropic_template_messages(messages: List[Dict[str, str]]) -> str: + prompt = "" + + # Anthropic prompt must start with a Human turn + if ( + len(messages) > 0 + and messages[0]["role"] != "user" + and messages[0]["role"] != "system" + ): + prompt += f"{HUMAN_PROMPT} Hello." 
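+    # HUMAN_PROMPT ("\n\nHuman:") and AI_PROMPT ("\n\nAssistant:") are the turn
+    # markers exported by the anthropic package. Each message below is prefixed
+    # with the appropriate marker, and the prompt must end with AI_PROMPT so the
+    # model completes the assistant turn.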
+ for msg in messages: + prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} " + + prompt += AI_PROMPT + return prompt + + +def template_alpaca_messages(msgs: List[Dict[str, str]]) -> str: + prompt = "" + + if msgs[0]["role"] == "system": + prompt += f"{msgs[0]['content']}\n" + msgs.pop(0) + + for msg in msgs: + prompt += "### Instruction:\n" if msg["role"] == "user" else "### Response:\n" + prompt += f"{msg['content']}\n" + + prompt += "### Response:\n" + + return prompt + + +def raw_input_template(msgs: List[Dict[str, str]]) -> str: + return msgs[-1]["content"] + + +SQL_CODER_DEFAULT_SCHEMA = """\ +CREATE TABLE products ( + product_id INTEGER PRIMARY KEY, -- Unique ID for each product + name VARCHAR(50), -- Name of the product + price DECIMAL(10,2), -- Price of each unit of the product + quantity INTEGER -- Current quantity in stock +); + +CREATE TABLE customers ( + customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer + name VARCHAR(50), -- Name of the customer + address VARCHAR(100) -- Mailing address of the customer +); + +CREATE TABLE salespeople ( + salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson + name VARCHAR(50), -- Name of the salesperson + region VARCHAR(50) -- Geographic sales region +); + +CREATE TABLE sales ( + sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale + product_id INTEGER, -- ID of product sold + customer_id INTEGER, -- ID of customer who made purchase + salesperson_id INTEGER, -- ID of salesperson who made the sale + sale_date DATE, -- Date the sale occurred + quantity INTEGER -- Quantity of product sold +); + +CREATE TABLE product_suppliers ( + supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier + product_id INTEGER, -- Product ID supplied + supply_price DECIMAL(10,2) -- Unit price charged by supplier +); + +-- sales.product_id can be joined with products.product_id +-- sales.customer_id can be joined with customers.customer_id +-- sales.salesperson_id can be joined with salespeople.salesperson_id +-- product_suppliers.product_id can be joined with products.product_id +""" + + +def _sqlcoder_template_messages( + msgs: List[Dict[str, str]], schema: str = SQL_CODER_DEFAULT_SCHEMA +) -> str: + question = msgs[-1]["content"] + return f"""\ +Your task is to convert a question into a SQL query, given a Postgres database schema. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. +- When creating a ratio, always cast the numerator as float + +### Input: +Generate a SQL query that answers the question `{question}`. 
+This query will run on a database whose schema is represented in this string:
+{schema}
+
+### Response:
+Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
+```sql
+"""
+
+
+def sqlcoder_template_messages(schema: str = SQL_CODER_DEFAULT_SCHEMA):
+    if schema == "<MY_DATABASE_SCHEMA>" or schema == "":
+        schema = SQL_CODER_DEFAULT_SCHEMA
+
+    def fn(msgs):
+        return _sqlcoder_template_messages(msgs, schema=schema)
+
+    fn.__name__ = "sqlcoder_template_messages"
+    return fn
+
+
+def llama2_template_messages(msgs: List[Dict[str, str]]) -> str:
+    if len(msgs) == 0:
+        return ""
+
+    if msgs[0]["role"] == "assistant":
+        # These models aren't trained to handle an assistant message coming first,
+        # and typically these are just introduction messages from Continue
+        msgs.pop(0)
+
+    prompt = ""
+    has_system = msgs[0]["role"] == "system"
+
+    if has_system and msgs[0]["content"].strip() == "":
+        has_system = False
+        msgs = msgs[1:]
+
+    if has_system:
+        system_message = dedent(
+            f"""\
+            <<SYS>>
+            {msgs[0]["content"]}
+            <</SYS>>
+
+            """
+        )
+        if len(msgs) > 1:
+            prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
+        else:
+            prompt += f"[INST] {system_message} [/INST]"
+            return prompt
+
+    for i in range(2 if has_system else 0, len(msgs)):
+        if msgs[i]["role"] == "user":
+            prompt += f"[INST] {msgs[i]['content']} [/INST]"
+        else:
+            prompt += msgs[i]["content"] + " "
+
+    return prompt
+
+
+def code_llama_template_messages(msgs: List[Dict[str, str]]) -> str:
+    return f"[INST] {msgs[-1]['content']}\n[/INST]"
+
+
+def extra_space_template_messages(msgs: List[Dict[str, str]]) -> str:
+    return f" {msgs[-1]['content']}"
+
+
+def code_llama_python_template_messages(msgs: List[Dict[str, str]]) -> str:
+    return dedent(
+        f"""\
+        [INST]
+        You are an expert Python programmer and personal assistant, here is your task: {msgs[-1]['content']}
+        Your answer should start with a [PYTHON] tag and end with a [/PYTHON] tag.
+        [/INST]"""
+    )
diff --git a/server/continuedev/libs/llm/prompts/edit.py b/server/continuedev/libs/llm/prompts/edit.py
new file mode 100644
index 00000000..eaa694c5
--- /dev/null
+++ b/server/continuedev/libs/llm/prompts/edit.py
@@ -0,0 +1,27 @@
+from textwrap import dedent
+
+simplified_edit_prompt = dedent(
+    """\
+    Consider the following code:
+    ```
+    {{{code_to_edit}}}
+    ```
+    Edit the code to perfectly satisfy the following user request:
+    {{{user_input}}}
+    Output nothing except for the code.
No code block, no English explanation, no start/end tags.""" +) + +simplest_edit_prompt = dedent( + """\ + Here is the code before editing: + ``` + {{{code_to_edit}}} + ``` + + Here is the edit requested: + "{{{user_input}}}" + + Here is the code after editing:""" +) + +codellama_infill_edit_prompt = "{{file_prefix}}<FILL>{{file_suffix}}" diff --git a/server/continuedev/libs/llm/proxy_server.py b/server/continuedev/libs/llm/proxy_server.py new file mode 100644 index 00000000..7c3462eb --- /dev/null +++ b/server/continuedev/libs/llm/proxy_server.py @@ -0,0 +1,108 @@ +import json +import traceback +from typing import List + +import aiohttp + +from ...core.main import ChatMessage +from ..util.telemetry import posthog_logger +from .base import LLM + +# SERVER_URL = "http://127.0.0.1:8080" +SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" + +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-4": 8192, +} + + +class ProxyServer(LLM): + _client_session: aiohttp.ClientSession + + class Config: + arbitrary_types_allowed = True + + async def start( + self, + **kwargs, + ): + await super().start(**kwargs) + self._client_session = self.create_client_session() + + self.context_length = MAX_TOKENS_FOR_MODEL[self.model] + + async def stop(self): + await self._client_session.close() + + def get_headers(self): + return {"unique_id": self.unique_id} + + async def _complete(self, prompt: str, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{SERVER_URL}/complete", + json={"messages": [{"role": "user", "content": prompt}], **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + resp_text = await resp.text() + if resp.status != 200: + raise Exception(resp_text) + + return resp_text + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + async with self._client_session.post( + f"{SERVER_URL}/stream_chat", + json={"messages": messages, **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception(await resp.text()) + + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + json_chunk = "{}" if json_chunk == "" else json_chunk + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + loaded_chunk = json.loads(chunk) + yield loaded_chunk + + except Exception as e: + posthog_logger.capture_event( + "proxy_server_parse_error", + { + "error_title": "Proxy server stream_chat parsing failed", + "error_message": "\n".join( + traceback.format_exception(e) + ), + }, + ) + else: + break + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{SERVER_URL}/stream_complete", + json={"messages": [{"role": "user", "content": prompt}], **args}, + headers=self.get_headers(), + proxy=self.proxy, + ) as resp: + if resp.status != 200: + raise Exception(await resp.text()) + + async for line in resp.content.iter_any(): + if line: + decoded_line = line.decode("utf-8") + yield decoded_line diff --git a/server/continuedev/libs/llm/queued.py b/server/continuedev/libs/llm/queued.py new file mode 100644 index 00000000..2db749eb --- /dev/null +++ b/server/continuedev/libs/llm/queued.py @@ -0,0 +1,77 @@ +import asyncio +from typing import Any, List, Union + +from pydantic import Field + +from ...core.main import ChatMessage +from .base import 
LLM, CompletionOptions + + +class QueuedLLM(LLM): + """ + QueuedLLM exists to make up for LLM servers that cannot handle multiple requests at once. It uses a lock to ensure that only one request is being processed at a time. + + If you are already using another LLM class and are experiencing this problem, you can just wrap it with the QueuedLLM class like this: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.queued import QueuedLLM + + config = ContinueConfig( + ... + models=Models( + default=QueuedLLM(llm=<OTHER_LLM_CLASS>) + ) + ) + ``` + """ + + llm: LLM = Field(..., description="The LLM to wrap with a lock") + _lock: asyncio.Lock + + model: str = "queued" + + def dict(self, **kwargs): + return self.llm.dict(**kwargs) + + async def start(self, *args, **kwargs): + await super().start(*args, **kwargs) + await self.llm.start(*args, **kwargs) + self._lock = asyncio.Lock() + self.model = self.llm.model + self.template_messages = self.llm.template_messages + self.prompt_templates = self.llm.prompt_templates + self.context_length = self.llm.context_length + + async def stop(self): + await self.llm.stop() + + def collect_args(self, options: CompletionOptions): + return self.llm.collect_args(options) + + def compile_chat_messages( + self, + options: CompletionOptions, + msgs: List[ChatMessage], + functions: Union[List[Any], None] = None, + ): + return self.llm.compile_chat_messages(options, msgs, functions) + + def template_prompt_like_messages(self, prompt: str) -> str: + return self.llm.template_prompt_like_messages(prompt) + + async def _complete(self, prompt: str, options: CompletionOptions): + async with self._lock: + resp = await self.llm._complete(prompt, options) + return resp + + async def _stream_complete(self, prompt: str, options: CompletionOptions): + async with self._lock: + async for chunk in self.llm._stream_complete(prompt, options): + yield chunk + + async def _stream_chat( + self, messages: List[ChatMessage], options: CompletionOptions + ): + async with self._lock: + async for chunk in self.llm._stream_chat(messages, options): + yield chunk diff --git a/server/continuedev/libs/llm/replicate.py b/server/continuedev/libs/llm/replicate.py new file mode 100644 index 00000000..3423193b --- /dev/null +++ b/server/continuedev/libs/llm/replicate.py @@ -0,0 +1,78 @@ +import concurrent.futures +from typing import List + +import replicate +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM +from .prompts.edit import simplified_edit_prompt + + +class ReplicateLLM(LLM): + """ + Replicate is a great option for newly released language models or models that you've deployed through their platform. Sign up for an account [here](https://replicate.ai/), copy your API key, and then select any model from the [Replicate Streaming List](https://replicate.com/collections/streaming-language-models). Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.replicate import ReplicateLLM + + config = ContinueConfig( + ... + models=Models( + default=ReplicateLLM( + model="replicate/codellama-13b-instruct:da5676342de1a5a335b848383af297f592b816b950a43d251a0a9edd0113604b", + api_key="my-replicate-api-key") + ) + ) + ``` + + If you don't specify the `model` parameter, it will default to `replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781`. 
+ """ + + api_key: str = Field(..., description="Replicate API key") + + model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781" + + _client: replicate.Client = None + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + async def start(self, **kwargs): + await super().start(**kwargs) + self._client = replicate.Client(api_token=self.api_key) + + async def _complete(self, prompt: str, options): + def helper(): + output = self._client.run( + self.model, input={"message": prompt, "prompt": prompt} + ) + completion = "" + for item in output: + completion += item + + return completion + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(helper) + completion = future.result() + + return completion + + async def _stream_complete(self, prompt, options): + for item in self._client.run( + self.model, input={"message": prompt, "prompt": prompt} + ): + yield item + + async def _stream_chat(self, messages: List[ChatMessage], options): + for item in self._client.run( + self.model, + input={ + "message": messages[-1]["content"], + "prompt": messages[-1]["content"], + }, + ): + yield {"content": item, "role": "assistant"} diff --git a/server/continuedev/libs/llm/text_gen_interface.py b/server/continuedev/libs/llm/text_gen_interface.py new file mode 100644 index 00000000..225fd3b6 --- /dev/null +++ b/server/continuedev/libs/llm/text_gen_interface.py @@ -0,0 +1,114 @@ +import json +from typing import Any, Callable, Dict, List, Union + +import websockets +from pydantic import Field + +from ...core.main import ChatMessage +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplest_edit_prompt + + +class TextGenUI(LLM): + """ + TextGenUI is a comprehensive, open-source language model UI and local server. You can set it up with an OpenAI-compatible server plugin, but if for some reason that doesn't work, you can use this class like so: + + ```python title="~/.continue/config.py" + from continuedev.libs.llm.text_gen_interface import TextGenUI + + config = ContinueConfig( + ... 
+ models=Models( + default=TextGenUI( + model="<MODEL_NAME>", + ) + ) + ) + ``` + """ + + model: str = "text-gen-ui" + server_url: str = Field( + "http://localhost:5000", description="URL of your TextGenUI server" + ) + streaming_url: str = Field( + "http://localhost:5005", + description="URL of your TextGenUI streaming server (separate from main server URL)", + ) + + prompt_templates = { + "edit": simplest_edit_prompt, + } + + template_messages: Union[ + Callable[[List[Dict[str, str]]], str], None + ] = llama2_template_messages + + class Config: + arbitrary_types_allowed = True + + def collect_args(self, options) -> Any: + args = super().collect_args(options) + args = {**args, "max_new_tokens": options.max_tokens} + args.pop("max_tokens", None) + return args + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" + payload = json.dumps({"prompt": prompt, "stream": True, **args}) + async with websockets.connect( + f"{ws_url}/api/v1/stream", ping_interval=None + ) as websocket: + await websocket.send(payload) + + while True: + incoming_data = await websocket.recv() + incoming_data = json.loads(incoming_data) + + match incoming_data["event"]: + case "text_stream": + yield incoming_data["text"] + case "stream_end": + break + + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) + + async def generator(): + ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" + history = list(map(lambda x: x["content"], messages)) + payload = json.dumps( + { + "user_input": messages[-1]["content"], + "history": {"internal": [history], "visible": [history]}, + "stream": True, + **args, + } + ) + async with websockets.connect( + f"{ws_url}/api/v1/chat-stream", ping_interval=None + ) as websocket: + await websocket.send(payload) + + prev = "" + while True: + incoming_data = await websocket.recv() + incoming_data = json.loads(incoming_data) + + match incoming_data["event"]: + case "text_stream": + visible = incoming_data["history"]["visible"][-1] + if len(visible) > 0: + yield { + "role": "assistant", + "content": visible[-1].replace(prev, ""), + } + prev = visible[-1] + case "stream_end": + break + + async for chunk in generator(): + yield chunk diff --git a/server/continuedev/libs/llm/together.py b/server/continuedev/libs/llm/together.py new file mode 100644 index 00000000..35b3a424 --- /dev/null +++ b/server/continuedev/libs/llm/together.py @@ -0,0 +1,125 @@ +import json +from typing import Callable + +import aiohttp +from pydantic import Field + +from ...core.main import ContinueCustomException +from ..util.logging import logger +from .base import LLM +from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt + + +class TogetherLLM(LLM): + """ + The Together API is a cloud platform for running large AI models. You can sign up [here](https://api.together.xyz/signup), copy your API key on the initial welcome screen, and then hit the play button on any model from the [Together Models list](https://docs.together.ai/docs/models-inference). Change `~/.continue/config.py` to look like this: + + ```python title="~/.continue/config.py" + from continuedev.core.models import Models + from continuedev.libs.llm.together import TogetherLLM + + config = ContinueConfig( + ... 
+ models=Models( + default=TogetherLLM( + api_key="<API_KEY>", + model="togethercomputer/llama-2-13b-chat" + ) + ) + ) + ``` + """ + + api_key: str = Field(..., description="Together API key") + + model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct" + base_url: str = Field( + "https://api.together.xyz", + description="The base URL for your Together API instance", + ) + + _client_session: aiohttp.ClientSession = None + + template_messages: Callable = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + + async def start(self, **kwargs): + await super().start(**kwargs) + self._client_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) + + async def stop(self): + await self._client_session.close() + + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{self.base_url}/inference", + json={ + "prompt": prompt, + "stream_tokens": True, + **args, + }, + headers={"Authorization": f"Bearer {self.api_key}"}, + proxy=self.proxy, + ) as resp: + async for line in resp.content.iter_chunks(): + if line[1]: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith( + "data: [DONE]" + ): + continue + + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + if chunk.startswith("data: "): + chunk = chunk[6:] + if chunk == "[DONE]": + break + try: + json_chunk = json.loads(chunk) + except Exception as e: + logger.warning(f"Invalid JSON chunk: {chunk}\n\n{e}") + continue + if "choices" in json_chunk: + yield json_chunk["choices"][0]["text"] + + async def _complete(self, prompt: str, options): + args = self.collect_args(options) + + async with self._client_session.post( + f"{self.base_url}/inference", + json={"prompt": prompt, **args}, + headers={"Authorization": f"Bearer {self.api_key}"}, + proxy=self.proxy, + ) as resp: + text = await resp.text() + j = json.loads(text) + try: + if "choices" not in j["output"]: + raise Exception(text) + if "output" in j: + return j["output"]["choices"][0]["text"] + except Exception as e: + j = await resp.json() + if "error" in j: + if j["error"].startswith("invalid hexlify value"): + raise ContinueCustomException( + message=f"Invalid Together API key:\n\n{j['error']}", + title="Together API Error", + ) + else: + raise ContinueCustomException( + message=j["error"], title="Together API Error" + ) + + raise e diff --git a/server/continuedev/libs/util/calculate_diff.py b/server/continuedev/libs/util/calculate_diff.py new file mode 100644 index 00000000..99301ae7 --- /dev/null +++ b/server/continuedev/libs/util/calculate_diff.py @@ -0,0 +1,154 @@ +import difflib +from typing import List + +from ...models.filesystem import FileEdit +from ...models.main import Position, Range + + +def calculate_diff(filepath: str, original: str, updated: str) -> List[FileEdit]: + s = difflib.SequenceMatcher(None, original, updated) + offset = 0 # The indices are offset by previous deletions/insertions + edits = [] + for tag, i1, i2, j1, j2 in s.get_opcodes(): + i1, i2, j1, j2 = i1 + offset, i2 + offset, j1 + offset, j2 + offset + replacement = updated[j1:j2] + if tag == "equal": + pass + elif tag == "delete": + edits.append( + FileEdit.from_deletion(filepath, Range.from_indices(original, i1, i2)) + ) + offset -= i2 - i1 + elif tag == "insert": + edits.append( + FileEdit.from_insertion( + filepath, 
Position.from_index(original, i1), replacement + ) + ) + offset += j2 - j1 + elif tag == "replace": + edits.append( + FileEdit( + filepath=filepath, + range=Range.from_indices(original, i1, i2), + replacement=replacement, + ) + ) + offset += (j2 - j1) - (i2 - i1) + else: + raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag) + + return edits + + +def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit]: + # original_lines = original.splitlines() + # updated_lines = updated.splitlines() + # offset = 0 + # while len(original_lines) and len(updated_lines) and original_lines[0] == updated_lines[0]: + # original_lines = original_lines[1:] + # updated_lines = updated_lines[1:] + + # while len(original_lines) and len(updated_lines) and original_lines[-1] == updated_lines[-1]: + # original_lines = original_lines[:-1] + # updated_lines = updated_lines[:-1] + + # original = "\n".join(original_lines) + # updated = "\n".join(updated_lines) + + edits = [] + max_iterations = 1000 + i = 0 + while not original == updated: + # TODO - For some reason it can't handle a single newline at the end of the file? + s = difflib.SequenceMatcher(None, original, updated) + opcodes = s.get_opcodes() + for edit_index in range(len(opcodes)): + tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index] + replacement = updated[j1:j2] + if tag == "equal": + continue # ;) + elif tag == "delete": + edits.append( + FileEdit.from_deletion( + filepath, Range.from_indices(original, i1, i2) + ) + ) + elif tag == "insert": + edits.append( + FileEdit.from_insertion( + filepath, Position.from_index(original, i1), replacement + ) + ) + elif tag == "replace": + edits.append( + FileEdit( + filepath=filepath, + range=Range.from_indices(original, i1, i2), + replacement=replacement, + ) + ) + else: + raise Exception("Unexpected difflib.SequenceMatcher tag: " + tag) + break + + original = apply_edit_to_str(original, edits[-1]) + + i += 1 + if i > max_iterations: + raise Exception("Max iterations reached") + + return edits + + +def read_range_in_str(s: str, r: Range) -> str: + lines = s.splitlines()[r.start.line : r.end.line + 1] + if len(lines) == 0: + return "" + + lines[0] = lines[0][r.start.character :] + lines[-1] = lines[-1][: r.end.character + 1] + return "\n".join(lines) + + +def apply_edit_to_str(s: str, edit: FileEdit) -> str: + read_range_in_str(s, edit.range) + + # Split lines and deal with some edge cases (could obviously be nicer) + lines = s.splitlines() + if s.startswith("\n"): + lines.insert(0, "") + if s.endswith("\n"): + lines.append("") + + if len(lines) == 0: + lines = [""] + + end = Position(line=edit.range.end.line, character=edit.range.end.character) + if edit.range.end.line == len(lines) and edit.range.end.character == 0: + end = Position( + line=edit.range.end.line - 1, + character=len(lines[min(len(lines) - 1, edit.range.end.line - 1)]), + ) + + before_lines = lines[: edit.range.start.line] + after_lines = lines[end.line + 1 :] + between_str = ( + lines[min(len(lines) - 1, edit.range.start.line)][: edit.range.start.character] + + edit.replacement + + lines[min(len(lines) - 1, end.line)][end.character + 1 :] + ) + + Range( + start=edit.range.start, + end=Position( + line=edit.range.start.line + len(edit.replacement.splitlines()) - 1, + character=edit.range.start.character + + len(edit.replacement.splitlines()[-1]) + if edit.replacement != "" + else 0, + ), + ) + + lines = before_lines + between_str.splitlines() + after_lines + return "\n".join(lines) diff --git 
a/server/continuedev/libs/util/commonregex.py b/server/continuedev/libs/util/commonregex.py new file mode 100644 index 00000000..c2f6bb82 --- /dev/null +++ b/server/continuedev/libs/util/commonregex.py @@ -0,0 +1,144 @@ +# coding: utf-8 +import re +from typing import Any + +date = re.compile( + "(?:(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?\s+(?:of\s+)?(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)|(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)\s+(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?)(?:\,)?\s*(?:\d{4})?|[0-3]?\d[-\./][0-3]?\d[-\./]\d{2,4}", + re.IGNORECASE, +) +time = re.compile("\d{1,2}:\d{2} ?(?:[ap]\.?m\.?)?|\d[ap]\.?m\.?", re.IGNORECASE) +phone = re.compile( + """((?:(?<![\d-])(?:\+?\d{1,3}[-.\s*]?)?(?:\(?\d{3}\)?[-.\s*]?)?\d{3}[-.\s*]?\d{4}(?![\d-]))|(?:(?<![\d-])(?:(?:\(\+?\d{2}\))|(?:\+?\d{2}))\s*\d{2}\s*\d{3}\s*\d{4}(?![\d-])))""" +) +phones_with_exts = re.compile( + "((?:(?:\+?1\s*(?:[.-]\s*)?)?(?:\(\s*(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9])\s*\)|(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9]))\s*(?:[.-]\s*)?)?(?:[2-9]1[02-9]|[2-9][02-9]1|[2-9][02-9]{2})\s*(?:[.-]\s*)?(?:[0-9]{4})(?:\s*(?:#|x\.?|ext\.?|extension)\s*(?:\d+)?))", + re.IGNORECASE, +) +link = re.compile( + "(?i)((?:https?://|www\d{0,3}[.])?[a-z0-9.\-]+[.](?:(?:international)|(?:construction)|(?:contractors)|(?:enterprises)|(?:photography)|(?:immobilien)|(?:management)|(?:technology)|(?:directory)|(?:education)|(?:equipment)|(?:institute)|(?:marketing)|(?:solutions)|(?:builders)|(?:clothing)|(?:computer)|(?:democrat)|(?:diamonds)|(?:graphics)|(?:holdings)|(?:lighting)|(?:plumbing)|(?:training)|(?:ventures)|(?:academy)|(?:careers)|(?:company)|(?:domains)|(?:florist)|(?:gallery)|(?:guitars)|(?:holiday)|(?:kitchen)|(?:recipes)|(?:shiksha)|(?:singles)|(?:support)|(?:systems)|(?:agency)|(?:berlin)|(?:camera)|(?:center)|(?:coffee)|(?:estate)|(?:kaufen)|(?:luxury)|(?:monash)|(?:museum)|(?:photos)|(?:repair)|(?:social)|(?:tattoo)|(?:travel)|(?:viajes)|(?:voyage)|(?:build)|(?:cheap)|(?:codes)|(?:dance)|(?:email)|(?:glass)|(?:house)|(?:ninja)|(?:photo)|(?:shoes)|(?:solar)|(?:today)|(?:aero)|(?:arpa)|(?:asia)|(?:bike)|(?:buzz)|(?:camp)|(?:club)|(?:coop)|(?:farm)|(?:gift)|(?:guru)|(?:info)|(?:jobs)|(?:kiwi)|(?:land)|(?:limo)|(?:link)|(?:menu)|(?:mobi)|(?:moda)|(?:name)|(?:pics)|(?:pink)|(?:post)|(?:rich)|(?:ruhr)|(?:sexy)|(?:tips)|(?:wang)|(?:wien)|(?:zone)|(?:biz)|(?:cab)|(?:cat)|(?:ceo)|(?:com)|(?:edu)|(?:gov)|(?:int)|(?:mil)|(?:net)|(?:onl)|(?:org)|(?:pro)|(?:red)|(?:tel)|(?:uno)|(?:xxx)|(?:ac)|(?:ad)|(?:ae)|(?:af)|(?:ag)|(?:ai)|(?:al)|(?:am)|(?:an)|(?:ao)|(?:aq)|(?:ar)|(?:as)|(?:at)|(?:au)|(?:aw)|(?:ax)|(?:az)|(?:ba)|(?:bb)|(?:bd)|(?:be)|(?:bf)|(?:bg)|(?:bh)|(?:bi)|(?:bj)|(?:bm)|(?:bn)|(?:bo)|(?:br)|(?:bs)|(?:bt)|(?:bv)|(?:bw)|(?:by)|(?:bz)|(?:ca)|(?:cc)|(?:cd)|(?:cf)|(?:cg)|(?:ch)|(?:ci)|(?:ck)|(?:cl)|(?:cm)|(?:cn)|(?:co)|(?:cr)|(?:cu)|(?:cv)|(?:cw)|(?:cx)|(?:cy)|(?:cz)|(?:de)|(?:dj)|(?:dk)|(?:dm)|(?:do)|(?:dz)|(?:ec)|(?:ee)|(?:eg)|(?:er)|(?:es)|(?:et)|(?:eu)|(?:fi)|(?:fj)|(?:fk)|(?:fm)|(?:fo)|(?:fr)|(?:ga)|(?:gb)|(?:gd)|(?:ge)|(?:gf)|(?:gg)|(?:gh)|(?:gi)|(?:gl)|(?:gm)|(?:gn)|(?:gp)|(?:gq)|(?:gr)|(?:gs)|(?:gt)|(?:gu)|(?:gw)|(?:gy)|(?:hk)|(?:hm)|(?:hn)|(?:hr)|(?:ht)|(?:hu)|(?:id)|(?:ie)|(?:il)|(?:im)|(?:in)|(?:io)|(?:iq)|(?:ir)|(?:is)|(?:it)|(?:je)|(?:jm)|(?:jo)|(?:jp)|(?:ke)|(?:kg)
|(?:kh)|(?:ki)|(?:km)|(?:kn)|(?:kp)|(?:kr)|(?:kw)|(?:ky)|(?:kz)|(?:la)|(?:lb)|(?:lc)|(?:li)|(?:lk)|(?:lr)|(?:ls)|(?:lt)|(?:lu)|(?:lv)|(?:ly)|(?:ma)|(?:mc)|(?:md)|(?:me)|(?:mg)|(?:mh)|(?:mk)|(?:ml)|(?:mm)|(?:mn)|(?:mo)|(?:mp)|(?:mq)|(?:mr)|(?:ms)|(?:mt)|(?:mu)|(?:mv)|(?:mw)|(?:mx)|(?:my)|(?:mz)|(?:na)|(?:nc)|(?:ne)|(?:nf)|(?:ng)|(?:ni)|(?:nl)|(?:no)|(?:np)|(?:nr)|(?:nu)|(?:nz)|(?:om)|(?:pa)|(?:pe)|(?:pf)|(?:pg)|(?:ph)|(?:pk)|(?:pl)|(?:pm)|(?:pn)|(?:pr)|(?:ps)|(?:pt)|(?:pw)|(?:py)|(?:qa)|(?:re)|(?:ro)|(?:rs)|(?:ru)|(?:rw)|(?:sa)|(?:sb)|(?:sc)|(?:sd)|(?:se)|(?:sg)|(?:sh)|(?:si)|(?:sj)|(?:sk)|(?:sl)|(?:sm)|(?:sn)|(?:so)|(?:sr)|(?:st)|(?:su)|(?:sv)|(?:sx)|(?:sy)|(?:sz)|(?:tc)|(?:td)|(?:tf)|(?:tg)|(?:th)|(?:tj)|(?:tk)|(?:tl)|(?:tm)|(?:tn)|(?:to)|(?:tp)|(?:tr)|(?:tt)|(?:tv)|(?:tw)|(?:tz)|(?:ua)|(?:ug)|(?:uk)|(?:us)|(?:uy)|(?:uz)|(?:va)|(?:vc)|(?:ve)|(?:vg)|(?:vi)|(?:vn)|(?:vu)|(?:wf)|(?:ws)|(?:ye)|(?:yt)|(?:za)|(?:zm)|(?:zw))(?:/[^\s()<>]+[^\s`!()\[\]{};:'\".,<>?\xab\xbb\u201c\u201d\u2018\u2019])?)", + re.IGNORECASE, +) +email = re.compile( + "([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)", + re.IGNORECASE, +) +ip = re.compile( + "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)", + re.IGNORECASE, +) +ipv6 = re.compile( + "\s*(?!.*::.*::)(?:(?!:)|:(?=:))(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)){6}(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)[0-9a-f]{0,4}(?:(?<=::)|(?<!:)|(?<=:)(?<!::):)|(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)){3})\s*", + re.VERBOSE | re.IGNORECASE | re.DOTALL, +) +price = re.compile("[$]\s?[+-]?[0-9]{1,3}(?:(?:,?[0-9]{3}))*(?:\.[0-9]{1,2})?") +hex_color = re.compile("(#(?:[0-9a-fA-F]{8})|#(?:[0-9a-fA-F]{3}){1,2})\\b") +credit_card = re.compile("((?:(?:\\d{4}[- ]?){3}\\d{4}|\\d{15,16}))(?![\\d])") +btc_address = re.compile( + "(?<![a-km-zA-HJ-NP-Z0-9])[13][a-km-zA-HJ-NP-Z0-9]{26,33}(?![a-km-zA-HJ-NP-Z0-9])" +) +street_address = re.compile( + "\d{1,4} [\w\s]{1,20}(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\W?(?=\s|$)", + re.IGNORECASE, +) +zip_code = re.compile(r"\b\d{5}(?:[-\s]\d{4})?\b") +po_box = re.compile(r"P\.? ?O\.? 
Box \d+", re.IGNORECASE) +ssn = re.compile( + "(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}" +) + +regexes = { + "dates": date, + "times": time, + "phones": phone, + "phones_with_exts": phones_with_exts, + "emails": email, + "ips": ip, + "ipv6s": ipv6, + "prices": price, + "hex_colors": hex_color, + "credit_cards": credit_card, + "btc_addresses": btc_address, + "street_addresses": street_address, + "zip_codes": zip_code, + "po_boxes": po_box, + "ssn_number": ssn, +} + +placeholders = { + "dates": "<DATE>", + "times": "<TIME>", + "phones": "<PHONE>", + "phones_with_exts": "<PHONE_WITH_EXT>", + "emails": "<EMAIL>", + "ips": "<IP>", + "ipv6s": "<IPV6>", + "prices": "<PRICE>", + "hex_colors": "<HEX_COLOR>", + "credit_cards": "<CREDIT_CARD>", + "btc_addresses": "<BTC_ADDRESS>", + "street_addresses": "<STREET_ADDRESS>", + "zip_codes": "<ZIP_CODE>", + "po_boxes": "<PO_BOX>", + "ssn_number": "<SSN>", +} + + +class regex: + def __init__(self, obj, regex): + self.obj = obj + self.regex = regex + + def __call__(self, *args): + def regex_method(text=None): + return [x.strip() for x in self.regex.findall(text or self.obj.text)] + + return regex_method + + +class CommonRegex(object): + def __init__(self, text=""): + self.text = text + + for k, v in list(regexes.items()): + setattr(self, k, regex(self, v)(self)) + + if text: + for key in list(regexes.keys()): + method = getattr(self, key) + setattr(self, key, method()) + + +pii_parser = CommonRegex() + + +def clean_pii_from_str(text: str): + """Replace personally identifiable information (PII) with placeholders.""" + for regex_name, regex in list(regexes.items()): + placeholder = placeholders[regex_name] + text = regex.sub(placeholder, text) + + return text + + +def clean_pii_from_any(v: Any) -> Any: + """Replace personally identifiable information (PII) with placeholders. 
Not guaranteed to return same type as input.""" + if isinstance(v, str): + return clean_pii_from_str(v) + elif isinstance(v, dict): + cleaned_dict = {} + for key, value in v.items(): + cleaned_dict[key] = clean_pii_from_any(value) + return cleaned_dict + elif isinstance(v, list): + return [clean_pii_from_any(x) for x in v] + else: + # Try to convert to string + try: + orig_text = str(v) + cleaned_text = clean_pii_from_str(orig_text) + if orig_text != cleaned_text: + return cleaned_text + else: + return v + except: + return v diff --git a/server/continuedev/libs/util/copy_codebase.py b/server/continuedev/libs/util/copy_codebase.py new file mode 100644 index 00000000..78f38148 --- /dev/null +++ b/server/continuedev/libs/util/copy_codebase.py @@ -0,0 +1,121 @@ +import os +import shutil +from pathlib import Path +from typing import Iterable, List, Union + +from watchdog.events import PatternMatchingEventHandler +from watchdog.observers import Observer + +from ...core.autopilot import Autopilot +from ...models.filesystem import FileSystem +from ...models.main import ( + AddDirectory, + AddFile, + DeleteDirectory, + DeleteFile, + FileSystemEdit, + RenameDirectory, + RenameFile, + SequentialFileSystemEdit, +) +from .map_path import map_path + + +def create_copy(orig_root: str, copy_root: str = None, ignore: Iterable[str] = []): + # TODO: Make ignore a spec, like .gitignore + if copy_root is None: + copy_root = Path(orig_root) / ".continue-copy" + ignore.append(str(copy_root)) + ignore = set(ignore) + + os.mkdir(copy_root) + # I think you're messing up a lot of absolute paths here + for child in os.listdir(): + if os.path.isdir(child): + if child not in ignore: + os.mkdir(map_path(child)) + create_copy(Path(orig_root) / child, Path(copy_root) / child, ignore) + else: + os.symlink(child, map_path(child)) + else: + if child not in ignore: + shutil.copyfile(child, map_path(child)) + else: + os.symlink(child, map_path(child)) + + +# The whole usage of watchdog here should only be specific to RealFileSystem, you want to have a different "Observer" class for VirtualFileSystem, which would depend on being sent notifications +class CopyCodebaseEventHandler(PatternMatchingEventHandler): + def __init__( + self, + ignore_directories: List[str], + ignore_patterns: List[str], + autopilot: Autopilot, + orig_root: str, + copy_root: str, + filesystem: FileSystem, + ): + super().__init__( + ignore_directories=ignore_directories, ignore_patterns=ignore_patterns + ) + self.autopilot = autopilot + self.orig_root = orig_root + self.copy_root = copy_root + self.filesystem = filesystem + + # For now, we'll just make the update immediately, but eventually need to sync with autopilot. + # It should be the autopilot that makes the update right? It's just another action, everything comes from a single stream. + + def _event_to_edit(self, event) -> Union[FileSystemEdit, None]: + # NOTE: You'll need to map paths to create both an action within the copy filesystem (the one you take) and one in the original filesystem (the one you'll record and allow the user to accept). 
Basically just need a converter built in to the FileSystemEdit class + src = event.src_path() + if event.is_directory: + if event.event_type == "moved": + return RenameDirectory(src, event.dest_path()) + elif event.event_type == "deleted": + return DeleteDirectory(src) + elif event.event_type == "created": + return AddDirectory(src) + else: + if event.event_type == "moved": + return RenameFile(src, event.dest_path()) + elif event.event_type == "deleted": + return DeleteFile(src) + elif event.event_type == "created": + contents = self.filesystem.read(src) + # Unclear whether it will always pass a "modified" event right after if something like echo "abc" > newfile.txt happens + return AddFile(src, contents) + elif event.event_type == "modified": + # Watchdog doesn't pass the contents or edit, so have to get it myself and diff + updated = self.filesystem.read(src) + copy_filepath = map_path(src, self.orig_root, self.copy_root) + old = self.filesystem.read(copy_filepath) + + edits = calculate_diff(src, updated, old) + return SequentialFileSystemEdit(edits) + return None + + def on_any_event(self, event): + edit = self._event_to_edit(event) + if edit is None: + return + edit = edit.with_mapped_paths(self.orig_root, self.copy_root) + action = ManualEditAction(edit) + self.autopilot.act(action) + + +def maintain_copy_workspace( + autopilot: Autopilot, filesystem: FileSystem, orig_root: str, copy_root: str +): + observer = Observer() + event_handler = CopyCodebaseEventHandler( + [".git"], [], autopilot, orig_root, copy_root, filesystem + ) + observer.schedule(event_handler, orig_root, recursive=True) + observer.start() + try: + while observer.isAlive(): + observer.join(1) + finally: + observer.stop() + observer.join() diff --git a/server/continuedev/libs/util/count_tokens.py b/server/continuedev/libs/util/count_tokens.py new file mode 100644 index 00000000..d895a2cf --- /dev/null +++ b/server/continuedev/libs/util/count_tokens.py @@ -0,0 +1,206 @@ +import json +from typing import Dict, List, Union + +from ...core.main import ChatMessage +from .templating import render_templated_string + +# TODO move many of these into specific LLM.properties() function that +# contains max tokens, if its a chat model or not, default args (not all models +# want to be run at 0.5 temp). 
also lets custom models made for long contexts +# exist here (likg LLongMA) +aliases = { + "ggml": "gpt-3.5-turbo", + "claude-2": "gpt-3.5-turbo", +} +DEFAULT_MAX_TOKENS = 1024 +DEFAULT_ARGS = { + "max_tokens": DEFAULT_MAX_TOKENS, + "temperature": 0.5, +} + +already_saw_import_err = False + + +def encoding_for_model(model_name: str): + global already_saw_import_err + if already_saw_import_err: + return None + + try: + import tiktoken + from tiktoken_ext import openai_public # noqa: F401 + + try: + return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) + except Exception as _: + return tiktoken.encoding_for_model("gpt-3.5-turbo") + except Exception as e: + print("Error importing tiktoken", e) + already_saw_import_err = True + return None + + +def count_tokens(model_name: str, text: Union[str, None]): + if text is None: + return 0 + encoding = encoding_for_model(model_name) + if encoding is None: + # Make a safe estimate given that tokens are usually typically ~4 characters on average + return len(text) // 2 + return len(encoding.encode(text, disallowed_special=())) + + +def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int: + # Doing simpler, safer version of what is here: + # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + # every message follows <|start|>{role/name}\n{content}<|end|>\n + TOKENS_PER_MESSAGE = 4 + return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE + + +def prune_raw_prompt_from_top( + model_name: str, context_length: int, prompt: str, tokens_for_completion: int +): + max_tokens = context_length - tokens_for_completion + encoding = encoding_for_model(model_name) + + if encoding is None: + desired_length_in_chars = max_tokens * 2 + return prompt[-desired_length_in_chars:] + + tokens = encoding.encode(prompt, disallowed_special=()) + if len(tokens) <= max_tokens: + return prompt + else: + return encoding.decode(tokens[-max_tokens:]) + + +def prune_chat_history( + model_name: str, + chat_history: List[ChatMessage], + context_length: int, + tokens_for_completion: int, +): + total_tokens = tokens_for_completion + sum( + count_chat_message_tokens(model_name, message) for message in chat_history + ) + + # 1. Replace beyond last 5 messages with summary + i = 0 + while total_tokens > context_length and i < len(chat_history) - 5: + message = chat_history[0] + total_tokens -= count_tokens(model_name, message.content) + total_tokens += count_tokens(model_name, message.summary) + message.content = message.summary + i += 1 + + # 2. Remove entire messages until the last 5 + while ( + len(chat_history) > 5 + and total_tokens > context_length + and len(chat_history) > 0 + ): + message = chat_history.pop(0) + total_tokens -= count_tokens(model_name, message.content) + + # 3. Truncate message in the last 5, except last 1 + i = 0 + while ( + total_tokens > context_length + and len(chat_history) > 0 + and i < len(chat_history) - 1 + ): + message = chat_history[i] + total_tokens -= count_tokens(model_name, message.content) + total_tokens += count_tokens(model_name, message.summary) + message.content = message.summary + i += 1 + + # 4. Remove entire messages in the last 5, except last 1 + while total_tokens > context_length and len(chat_history) > 1: + message = chat_history.pop(0) + total_tokens -= count_tokens(model_name, message.content) + + # 5. 
Truncate last message + if total_tokens > context_length and len(chat_history) > 0: + message = chat_history[0] + message.content = prune_raw_prompt_from_top( + model_name, context_length, message.content, tokens_for_completion + ) + total_tokens = context_length + + return chat_history + + +# In case we've missed weird edge cases +TOKEN_BUFFER_FOR_SAFETY = 100 + + +def compile_chat_messages( + model_name: str, + msgs: Union[List[ChatMessage], None], + context_length: int, + max_tokens: int, + prompt: Union[str, None] = None, + functions: Union[List, None] = None, + system_message: Union[str, None] = None, +) -> List[Dict]: + """ + The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message + """ + + msgs_copy = [msg.copy(deep=True) for msg in msgs] if msgs is not None else [] + + if prompt is not None: + prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt) + msgs_copy += [prompt_msg] + + if system_message is not None and system_message.strip() != "": + # NOTE: System message takes second precedence to user prompt, so it is placed just before + # but move back to start after processing + rendered_system_message = render_templated_string(system_message) + system_chat_msg = ChatMessage( + role="system", + content=rendered_system_message, + summary=rendered_system_message, + ) + # insert at second-to-last position + msgs_copy.insert(-1, system_chat_msg) + + # Add tokens from functions + function_tokens = 0 + if functions is not None: + for function in functions: + function_tokens += count_tokens(model_name, json.dumps(function)) + + if max_tokens + function_tokens + TOKEN_BUFFER_FOR_SAFETY >= context_length: + raise ValueError( + f"max_tokens ({max_tokens}) is too close to context_length ({context_length}), which doesn't leave room for chat history. This would cause incoherent responses. Try increasing the context_length parameter of the model in your config file." 
+ ) + + msgs_copy = prune_chat_history( + model_name, + msgs_copy, + context_length, + function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY, + ) + + history = [msg.to_dict(with_functions=functions is not None) for msg in msgs_copy] + + # Move system message back to start + if ( + system_message is not None + and len(history) >= 2 + and history[-2]["role"] == "system" + ): + system_message_dict = history.pop(-2) + history.insert(0, system_message_dict) + + return history + + +def format_chat_messages(messages: List[ChatMessage]) -> str: + formatted = "" + for msg in messages: + formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n" + return formatted diff --git a/server/continuedev/libs/util/create_async_task.py b/server/continuedev/libs/util/create_async_task.py new file mode 100644 index 00000000..232d3fa1 --- /dev/null +++ b/server/continuedev/libs/util/create_async_task.py @@ -0,0 +1,38 @@ +import asyncio +import traceback +from typing import Callable, Coroutine, Optional + +import nest_asyncio + +from .logging import logger +from .telemetry import posthog_logger + +nest_asyncio.apply() + + +def create_async_task( + coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None +): + """asyncio.create_task and log errors by adding a callback""" + task = asyncio.create_task(coro) + + def callback(future: asyncio.Future): + try: + future.result() + except Exception as e: + formatted_tb = "\n".join(traceback.format_exception(e)) + logger.critical(f"Exception caught from async task: {formatted_tb}") + posthog_logger.capture_event( + "async_task_error", + { + "error_title": e.__str__() or e.__repr__(), + "error_message": "\n".join(traceback.format_exception(e)), + }, + ) + + # Log the error to the GUI + if on_error is not None: + asyncio.create_task(on_error(e)) + + task.add_done_callback(callback) + return task diff --git a/server/continuedev/libs/util/devdata.py b/server/continuedev/libs/util/devdata.py new file mode 100644 index 00000000..61b4351d --- /dev/null +++ b/server/continuedev/libs/util/devdata.py @@ -0,0 +1,67 @@ +""" +This file contains mechanisms for logging development data to files, SQL databases, and other formats. 
+""" + + +import json +from datetime import datetime +from typing import Any, Dict + +import aiohttp + +from .create_async_task import create_async_task +from .logging import logger +from .paths import getDevDataFilePath + + +class DevDataLogger: + user_token: str = None + data_server_url: str = None + + def setup(self, user_token: str = None, data_server_url: str = None): + self.user_token = user_token + self.data_server_url = data_server_url + + def _to_data_server(self, table_name: str, data: Dict[str, Any]): + async def _async_helper(self, table_name: str, data: Dict[str, Any]): + if self.user_token is None or self.data_server_url is None: + return + + async with aiohttp.ClientSession() as session: + await session.post( + f"{self.data_server_url}/event", + headers={"Authorization": f"Bearer {self.user_token}"}, + json={ + "table_name": table_name, + "data": data, + "user_token": self.user_token, + }, + ) + + create_async_task( + _async_helper(self, table_name, data), + lambda e: logger.warning(f"Failed to send dev data: {e}"), + ) + + def _static_columns(self): + return { + "user_token": self.user_token or "NO_USER_TOKEN", + "timestamp": datetime.now().isoformat(), + } + + def _to_local(self, table_name: str, data: Dict[str, Any]): + filepath = getDevDataFilePath(table_name) + with open(filepath, "a") as f: + json_line = json.dumps(data) + f.write(f"{json_line}\n") + + def capture(self, table_name: str, data: Dict[str, Any]): + try: + data = {**self._static_columns(), **data} + self._to_data_server(table_name, data) + self._to_local(table_name, data) + except Exception as e: + logger.warning(f"Failed to capture dev data: {e}") + + +dev_data_logger = DevDataLogger() diff --git a/server/continuedev/libs/util/edit_config.py b/server/continuedev/libs/util/edit_config.py new file mode 100644 index 00000000..4dc427d2 --- /dev/null +++ b/server/continuedev/libs/util/edit_config.py @@ -0,0 +1,149 @@ +import threading +from typing import Any, Dict, List + +import redbaron + +from .paths import getConfigFilePath + + +def get_config_source(): + config_file_path = getConfigFilePath() + with open(config_file_path, "r") as file: + source_code = file.read() + return source_code + + +def load_red(): + source_code = get_config_source() + + red = redbaron.RedBaron(source_code) + return red + + +def get_config_node(red): + for node in red: + if node.type == "assignment" and node.target.value == "config": + return node + else: + raise Exception("Config file appears to be improperly formatted") + + +def edit_property( + args: redbaron.RedBaron, key_path: List[str], value: redbaron.RedBaron +): + for i in range(len(args)): + node = args[i] + if node.type != "call_argument": + continue + + if node.target.value == key_path[0]: + if len(key_path) > 1: + edit_property(node.value.value[1].value, key_path[1:], value) + else: + args[i].value = value + return + + +edit_lock = threading.Lock() + + +def edit_config_property(key_path: List[str], value: redbaron.RedBaron): + with edit_lock: + red = load_red() + config = get_config_node(red) + config_args = config.value.value[1].value + edit_property(config_args, key_path, value) + + with open(getConfigFilePath(), "w") as file: + file.write(red.dumps()) + + +def add_config_import(line: str): + # check if the import already exists + source = get_config_source() + if line in source: + return + + with edit_lock: + red = load_red() + # if it doesn't exist, add it + red.insert(1, line) + + with open(getConfigFilePath(), "w") as file: + file.write(red.dumps()) + + 
+filtered_attrs = { + "class_name", + "name", + "llm", +} + +filtered_attrs_when_new = {"timeout", "prompt_templates"} + + +def escape_string(string: str) -> str: + return string.replace('"', '\\"').replace("'", "\\'") + + +def display_val(v: Any, k: str = None): + if k == "template_messages": + return v + + if isinstance(v, str): + return f'"{escape_string(v)}"' + return str(v) + + +def is_default(llm, k, v): + if k == "template_messages" and llm.__fields__[k].default is not None: + return llm.__fields__[k].default.__name__ == v + return v == llm.__fields__[k].default + + +def display_llm_class(llm, new: bool = False): + sep = ",\n\t\t\t" + args = sep.join( + [ + f"{k}={display_val(v, k)}" + for k, v in llm.dict().items() + if k not in filtered_attrs and v is not None and not is_default(llm, k, v) + ] + ) + return f"{llm.__class__.__name__}(\n\t\t\t{args}\n\t\t)" + + +def create_obj_node( + class_name: str, args: Dict[str, str], tabs: int = 1 +) -> redbaron.RedBaron: + args = [f"{key}={value}" for key, value in args.items()] + t = "\t" * tabs + new_line = "\n\t" + t + sep = "," + new_line + + return redbaron.RedBaron(f"{class_name}({new_line}{sep.join(args)}\n{t})")[0] + + +def create_string_node(string: str) -> redbaron.RedBaron: + string = escape_string(string) + if "\n" in string: + return redbaron.RedBaron(f'"""{string}"""')[0] + return redbaron.RedBaron(f'"{string}"')[0] + + +def create_literal_node(literal: str) -> redbaron.RedBaron: + return redbaron.RedBaron(literal)[0] + + +def create_float_node(float: float) -> redbaron.RedBaron: + return redbaron.RedBaron(f"{float}")[0] + + +# Example: +# edit_config_property( +# [ +# "models", +# "default", +# ], +# create_obj_node("OpenAI", {"api_key": '""', "model": '"gpt-4"'}), +# ) diff --git a/server/continuedev/libs/util/errors.py b/server/continuedev/libs/util/errors.py new file mode 100644 index 00000000..46074cfc --- /dev/null +++ b/server/continuedev/libs/util/errors.py @@ -0,0 +1,2 @@ +class SessionNotFound(Exception): + pass diff --git a/server/continuedev/libs/util/filter_files.py b/server/continuedev/libs/util/filter_files.py new file mode 100644 index 00000000..6ebaa274 --- /dev/null +++ b/server/continuedev/libs/util/filter_files.py @@ -0,0 +1,33 @@ +import fnmatch +from typing import List + +DEFAULT_IGNORE_DIRS = [ + ".git", + ".vscode", + ".idea", + ".vs", + ".venv", + "env", + ".env", + "node_modules", + "dist", + "build", + "target", + "out", + "bin", + ".pytest_cache", + ".vscode-test", + ".continue", + "__pycache__", +] + +DEFAULT_IGNORE_PATTERNS = DEFAULT_IGNORE_DIRS + list( + filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS) +) + + +def should_filter_path( + path: str, ignore_patterns: List[str] = DEFAULT_IGNORE_PATTERNS +) -> bool: + """Returns whether a file should be filtered""" + return any(fnmatch.fnmatch(path, pattern) for pattern in ignore_patterns) diff --git a/server/continuedev/libs/util/logging.py b/server/continuedev/libs/util/logging.py new file mode 100644 index 00000000..a4dc3562 --- /dev/null +++ b/server/continuedev/libs/util/logging.py @@ -0,0 +1,47 @@ +import logging +import os + +from .paths import getLogFilePath + +logfile_path = getLogFilePath() + +try: + # Truncate the logs that are more than a day old + if os.path.exists(logfile_path) and os.path.getsize(logfile_path) > 32 * 1024: + tail = None + with open(logfile_path, "rb") as f: + f.seek(-32 * 1024, os.SEEK_END) + tail = f.read().decode("utf-8") + + if tail is not None: + with open(logfile_path, "w") as f: + f.write(tail) + +except Exception as 
e: + print("Error truncating log file: {}".format(e)) + +# Create a logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +# Create a file handler +file_handler = logging.FileHandler(logfile_path) +file_handler.setLevel(logging.DEBUG) + +# Create a console handler +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) + +# Create a formatter +formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s") + +# Add the formatter to the handlers +file_handler.setFormatter(formatter) +console_handler.setFormatter(formatter) + +# Add the handlers to the logger +logger.addHandler(file_handler) +logger.addHandler(console_handler) + +# Log a test message +logger.debug("------ Begin Logs ------") diff --git a/server/continuedev/libs/util/map_path.py b/server/continuedev/libs/util/map_path.py new file mode 100644 index 00000000..1dddc2e9 --- /dev/null +++ b/server/continuedev/libs/util/map_path.py @@ -0,0 +1,16 @@ +from pathlib import Path + + +def map_path(path: str, orig_root: str, copy_root: str) -> Path: + path = Path(path) + if path.is_relative_to(orig_root): + if path.is_absolute(): + return Path(copy_root) / path.relative_to(orig_root) + else: + return path + else: + if path.is_absolute(): + return path + else: + # For this one, you need to know the directory from which the relative path is being used. + return Path(orig_root) / path diff --git a/server/continuedev/libs/util/paths.py b/server/continuedev/libs/util/paths.py new file mode 100644 index 00000000..22e4b5b9 --- /dev/null +++ b/server/continuedev/libs/util/paths.py @@ -0,0 +1,148 @@ +import os +import re +from typing import Optional + +from ..constants.default_config import default_config +from ..constants.main import ( + CONTINUE_GLOBAL_FOLDER, + CONTINUE_SERVER_FOLDER, + CONTINUE_SESSIONS_FOLDER, +) + + +def find_data_file(filename): + datadir = os.path.dirname(__file__) + return os.path.abspath(os.path.join(datadir, filename)) + + +def getGlobalFolderPath(): + path = os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER) + os.makedirs(path, exist_ok=True) + return path + + +def getSessionsFolderPath(): + path = os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER) + os.makedirs(path, exist_ok=True) + return path + + +def getServerFolderPath(): + path = os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER) + os.makedirs(path, exist_ok=True) + return path + + +def getDevDataFolderPath(): + path = os.path.join(getGlobalFolderPath(), "dev_data") + os.makedirs(path, exist_ok=True) + return path + + +def getDiffsFolderPath(): + path = os.path.join(getGlobalFolderPath(), "diffs") + os.makedirs(path, exist_ok=True) + return path + + +def getDevDataFilePath(table_name: str): + filepath = os.path.join(getDevDataFolderPath(), f"{table_name}.jsonl") + if not os.path.exists(filepath): + with open(filepath, "w") as f: + f.write("") + + return filepath + + +def getMeilisearchExePath(): + binary_name = "meilisearch.exe" if os.name == "nt" else "meilisearch" + path = os.path.join(getServerFolderPath(), binary_name) + return path + + +def getSessionFilePath(session_id: str): + path = os.path.join(getSessionsFolderPath(), f"{session_id}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + return path + + +def getSessionsListFilePath(): + path = os.path.join(getSessionsFolderPath(), "sessions.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + if not os.path.exists(path): + with open(path, "w") as f: + f.write("[]") + return path + + +def 
migrateConfigFile(existing: str) -> Optional[str]:
+    if existing.strip() == "":
+        return default_config
+
+    migrated = (
+        existing.replace("MaybeProxyOpenAI", "OpenAIFreeTrial")
+        .replace("maybe_proxy_openai", "openai_free_trial")
+        .replace("unused=", "saved=")
+        .replace("medium=", "summarize=")
+    )
+    if migrated != existing:
+        return migrated
+
+    return None
+
+
+def getConfigFilePath() -> str:
+    path = os.path.join(getGlobalFolderPath(), "config.py")
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+
+    if not os.path.exists(path):
+        with open(path, "w") as f:
+            f.write(default_config)
+    else:
+        # Make any necessary migrations
+        with open(path, "r") as f:
+            existing_content = f.read()
+
+        migrated = migrateConfigFile(existing_content)
+
+        if migrated is not None:
+            with open(path, "w") as f:
+                f.write(migrated)
+
+    return path
+
+
+def convertConfigImports(shorten: bool) -> str:
+    path = getConfigFilePath()
+    # Make any necessary migrations
+    with open(path, "r") as f:
+        existing_content = f.read()
+
+    if shorten:
+        migrated = existing_content.replace(
+            "from continuedev.src.continuedev.", "from continuedev."
+        )
+    else:
+        migrated = re.sub(
+            r"(?<!src\.)continuedev\.(?!src)",
+            "continuedev.",
+            existing_content,
+        )
+
+    with open(path, "w") as f:
+        f.write(migrated)
+
+
+def getLogFilePath():
+    path = os.path.join(getGlobalFolderPath(), "continue.log")
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    return path
+
+
+def getSavedContextGroupsPath():
+    path = os.path.join(getGlobalFolderPath(), "saved_context_groups.json")
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    if not os.path.exists(path):
+        with open(path, "w") as f:
+            f.write("{}")
+    return path
diff --git a/server/continuedev/libs/util/queue.py b/server/continuedev/libs/util/queue.py
new file mode 100644
index 00000000..e1f98cc6
--- /dev/null
+++ b/server/continuedev/libs/util/queue.py
@@ -0,0 +1,17 @@
+import asyncio
+from typing import Any, Dict
+
+
+class AsyncSubscriptionQueue:
+    # The correct way to do this is probably to keep request IDs
+    queues: Dict[str, asyncio.Queue] = {}
+
+    def post(self, messageType: str, data: Any):
+        if messageType not in self.queues:
+            self.queues.update({messageType: asyncio.Queue()})
+        self.queues[messageType].put_nowait(data)
+
+    async def get(self, message_type: str) -> Any:
+        if message_type not in self.queues:
+            self.queues.update({message_type: asyncio.Queue()})
+        return await self.queues[message_type].get()
diff --git a/server/continuedev/libs/util/ripgrep.py b/server/continuedev/libs/util/ripgrep.py
new file mode 100644
index 00000000..f7e0af9a
--- /dev/null
+++ b/server/continuedev/libs/util/ripgrep.py
@@ -0,0 +1,25 @@
+import os
+
+
+def get_rg_path():
+    # Fall back to an `rg` binary on the PATH if no bundled copy is found
+    rg_path = "rg"
+    if os.name == "nt":
+        paths_to_try = [
+            f"C:\\Users\\{os.getlogin()}\\AppData\\Local\\Programs\\Microsoft VS Code\\resources\\app\\node_modules.asar.unpacked\\@vscode\\ripgrep\\bin\\rg.exe",
+            f"C:\\Users\\{os.getlogin()}\\AppData\\Local\\Programs\\Microsoft VS Code\\resources\\app\\node_modules.asar.unpacked\\vscode-ripgrep\\bin\\rg.exe",
+        ]
+        for path in paths_to_try:
+            if os.path.exists(path):
+                rg_path = path
+                break
+    elif os.name == "posix":
+        if "darwin" in os.sys.platform:
+            rg_path = "/Applications/Visual Studio Code.app/Contents/Resources/app/node_modules.asar.unpacked/@vscode/ripgrep/bin/rg"
+        else:
+            rg_path = "/usr/share/code/resources/app/node_modules.asar.unpacked/vscode-ripgrep/bin/rg"
+
+    if not os.path.exists(rg_path):
+        rg_path = "rg"
+
+    return rg_path
diff --git 
a/server/continuedev/libs/util/step_name_to_steps.py b/server/continuedev/libs/util/step_name_to_steps.py
new file mode 100644
index 00000000..25fd8ba3
--- /dev/null
+++ b/server/continuedev/libs/util/step_name_to_steps.py
@@ -0,0 +1,47 @@
+from typing import Dict
+
+from ...core.main import Step
+from ...core.steps import UserInputStep
+from ...libs.util.logging import logger
+from ...plugins.recipes.AddTransformRecipe.main import AddTransformRecipe
+from ...plugins.recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
+from ...plugins.recipes.DDtoBQRecipe.main import DDtoBQRecipe
+from ...plugins.recipes.DeployPipelineAirflowRecipe.main import (
+    DeployPipelineAirflowRecipe,
+)
+from ...plugins.steps.chat import SimpleChatStep
+from ...plugins.steps.clear_history import ClearHistoryStep
+from ...plugins.steps.comment_code import CommentCodeStep
+from ...plugins.steps.feedback import FeedbackStep
+from ...plugins.steps.help import HelpStep
+from ...plugins.steps.main import EditHighlightedCodeStep
+from ...plugins.steps.on_traceback import DefaultOnTracebackStep
+from ...plugins.steps.open_config import OpenConfigStep
+
+# This mapping is used to convert a string in the ContinueConfig JSON to the corresponding Step class.
+# Used for example in slash_commands and steps_on_startup
+step_name_to_step_class = {
+    "UserInputStep": UserInputStep,
+    "EditHighlightedCodeStep": EditHighlightedCodeStep,
+    "SimpleChatStep": SimpleChatStep,
+    "CommentCodeStep": CommentCodeStep,
+    "FeedbackStep": FeedbackStep,
+    "AddTransformRecipe": AddTransformRecipe,
+    "CreatePipelineRecipe": CreatePipelineRecipe,
+    "DDtoBQRecipe": DDtoBQRecipe,
+    "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe,
+    "DefaultOnTracebackStep": DefaultOnTracebackStep,
+    "ClearHistoryStep": ClearHistoryStep,
+    "OpenConfigStep": OpenConfigStep,
+    "HelpStep": HelpStep,
+}
+
+
+def get_step_from_name(step_name: str, params: Dict) -> Step:
+    try:
+        return step_name_to_step_class[step_name](**params)
+    except Exception:
+        logger.error(
+            f"Incorrect parameters for step {step_name}. Parameters provided were: {params}"
+        )
+        raise
diff --git a/server/continuedev/libs/util/strings.py b/server/continuedev/libs/util/strings.py
new file mode 100644
index 00000000..f2b6035f
--- /dev/null
+++ b/server/continuedev/libs/util/strings.py
@@ -0,0 +1,64 @@
+from typing import Tuple
+
+
+def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
+    lines = s.splitlines()
+    if len(lines) == 0:
+        return "", ""
+
+    # Longest common whitespace prefix, seeded with the first line's leading whitespace
+    lcp = lines[0][: len(lines[0]) - len(lines[0].lstrip())]
+    # Iterate through the lines
+    for i in range(1, len(lines)):
+        # Empty lines are wildcards
+        if lines[i].strip() == "":
+            continue  # hey that's us!
+        # Iterate through the leading whitespace characters of the current line
+        for j in range(0, len(lcp)):
+            # If it doesn't have the same whitespace as lcp, then update lcp
+            if j >= len(lines[i]) or lcp[j] != lines[i][j]:
+                lcp = lcp[:j]
+                if lcp == "":
+                    return s, ""
+                break
+
+    # Remove exactly the common prefix from each line
+    return (
+        "\n".join(line[len(lcp):] if line.startswith(lcp) else line for line in lines),
+        lcp,
+    )
+
+
+def strip_code_block(s: str) -> str:
+    """
+    Strips the code block from a string, if it has one. 
+ """ + if s.startswith("```\n") and s.endswith("\n```"): + return s[4:-4] + elif s.startswith("```") and s.endswith("```"): + return s[3:-3] + elif s.startswith("`") and s.endswith("`"): + return s[1:-1] + return s + + +def remove_quotes_and_escapes(output: str) -> str: + """ + Clean up the output of the completion API, removing unnecessary escapes and quotes + """ + output = output.strip() + + # Replace smart quotes + output = output.replace("“", '"') + output = output.replace("”", '"') + output = output.replace("‘", "'") + output = output.replace("’", "'") + + # Remove escapes + output = output.replace('\\"', '"') + output = output.replace("\\'", "'") + output = output.replace("\\n", "\n") + output = output.replace("\\t", "\t") + output = output.replace("\\\\", "\\") + if (output.startswith('"') and output.endswith('"')) or ( + output.startswith("'") and output.endswith("'") + ): + output = output[1:-1] + + return output diff --git a/server/continuedev/libs/util/telemetry.py b/server/continuedev/libs/util/telemetry.py new file mode 100644 index 00000000..1772fe20 --- /dev/null +++ b/server/continuedev/libs/util/telemetry.py @@ -0,0 +1,108 @@ +import os +import socket +from typing import Any, Dict, Optional + +from dotenv import load_dotenv + +from ..constants.main import CONTINUE_SERVER_VERSION_FILE +from .commonregex import clean_pii_from_any +from .paths import getServerFolderPath + +load_dotenv() +in_codespaces = os.getenv("CODESPACES") == "true" +POSTHOG_API_KEY = "phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs" + + +def is_connected(): + try: + # connect to the host -- tells us if the host is actually reachable + socket.create_connection(("www.google.com", 80)) + return True + except OSError: + pass + return False + + +class PostHogLogger: + unique_id: str = "NO_UNIQUE_ID" + allow_anonymous_telemetry: bool = False + ide_info: Optional[Dict] = None + posthog = None + + def __init__(self, api_key: str): + self.api_key = api_key + + def setup( + self, unique_id: str, allow_anonymous_telemetry: bool, ide_info: Optional[Dict] + ): + self.unique_id = unique_id or "NO_UNIQUE_ID" + self.allow_anonymous_telemetry = allow_anonymous_telemetry or False + self.ide_info = ide_info + + # Capture initial event + self.capture_event("session_start", {"os": os.name}) + + def capture_event(self, event_name: str, event_properties: Any): + """Safely capture event. 
Telemetry should never be the reason Continue doesn't work""" + try: + self._capture_event(event_name, event_properties) + except Exception as e: + print(f"Failed to capture event: {e}") + pass + + _found_disconnected: bool = False + + def _capture_event(self, event_name: str, event_properties: Any): + # logger.debug( + # f"Logging to PostHog: {event_name} ({self.unique_id}, {self.allow_anonymous_telemetry}): {event_properties}") + telemetry_path = os.path.expanduser("~/.continue/telemetry.log") + + # Make sure the telemetry file exists + if not os.path.exists(telemetry_path): + os.makedirs(os.path.dirname(telemetry_path), exist_ok=True) + open(telemetry_path, "w").close() + + with open(telemetry_path, "a") as f: + str_to_write = f"{event_name}: {event_properties}\n{self.unique_id}\n{self.allow_anonymous_telemetry}\n\n" + f.write(str_to_write) + + if not self.allow_anonymous_telemetry: + return + + # Clean PII from event properties + event_properties = clean_pii_from_any(event_properties) + + # Add additional properties that are on every event + if in_codespaces: + event_properties["codespaces"] = True + + server_version_file = os.path.join( + getServerFolderPath(), CONTINUE_SERVER_VERSION_FILE + ) + if os.path.exists(server_version_file): + with open(server_version_file, "r") as f: + event_properties["server_version"] = f.read() + + # Add operating system + event_properties["os"] = os.name + if self.ide_info: + event_properties["ide_name"] = self.ide_info.get("name", None) + event_properties["ide_version"] = self.ide_info.get("version", None) + event_properties["ide_remote_name"] = self.ide_info.get("remoteName", None) + + # Send event to PostHog + if self.posthog is None: + from posthog import Posthog + + # The personal API key is necessary only if you want to use local evaluation of feature flags. + self.posthog = Posthog(self.api_key, host="https://app.posthog.com") + + if is_connected(): + self.posthog.capture(self.unique_id, event_name, event_properties) + else: + if not self._found_disconnected: + self._found_disconnected = True + raise ConnectionError("No internet connection") + + +posthog_logger = PostHogLogger(api_key=POSTHOG_API_KEY) diff --git a/server/continuedev/libs/util/templating.py b/server/continuedev/libs/util/templating.py new file mode 100644 index 00000000..8d6a32fc --- /dev/null +++ b/server/continuedev/libs/util/templating.py @@ -0,0 +1,76 @@ +import os +from typing import Callable, Dict, List, Union + +import chevron + +from ...core.main import ChatMessage + + +def get_vars_in_template(template): + """ + Get the variables in a template + """ + return [ + token[1] + for token in chevron.tokenizer.tokenize(template) + if token[0] == "variable" + ] + + +def escape_var(var: str) -> str: + """ + Escape a variable so it can be used in a template + """ + return var.replace(os.path.sep, "").replace(".", "") + + +def render_templated_string(template: str) -> str: + """ + Render system message or other templated string with mustache syntax. + Right now it only supports rendering absolute file paths as their contents. 
+    """
+    vars = get_vars_in_template(template)
+
+    args = {}
+    for var in vars:
+        if var.startswith(os.path.sep):
+            # Escape vars which are filenames, because mustache doesn't allow / in variable names
+            escaped_var = escape_var(var)
+            template = template.replace(var, escaped_var)
+
+            if os.path.exists(var):
+                with open(var, "r") as f:
+                    args[escaped_var] = f.read()
+            else:
+                args[escaped_var] = ""
+
+    return chevron.render(template, args)
+
+
+"""
+A PromptTemplate can either be a template string (mustache syntax, e.g. {{user_input}}) or
+a function which takes the history and a dictionary of additional key-value pairs and returns
+either a string or a list of ChatMessages.
+If a string is returned, it will be assumed that the chat history should be ignored.
+"""
+PromptTemplate = Union[
+    str, Callable[[List[ChatMessage], Dict[str, str]], Union[str, List[ChatMessage]]]
+]
+
+
+def render_prompt_template(
+    template: PromptTemplate, history: List[ChatMessage], other_data: Dict[str, str]
+) -> str:
+    """
+    Render a prompt template.
+    """
+    if isinstance(template, str):
+        data = {
+            "history": history,
+            **other_data,
+        }
+        if len(history) > 0 and history[0].role == "system":
+            data["system_message"] = history.pop(0).content
+
+        return chevron.render(template, data)
+    else:
+        return template(history, other_data)
diff --git a/server/continuedev/libs/util/traceback/traceback_parsers.py b/server/continuedev/libs/util/traceback/traceback_parsers.py
new file mode 100644
index 00000000..58a4f728
--- /dev/null
+++ b/server/continuedev/libs/util/traceback/traceback_parsers.py
@@ -0,0 +1,56 @@
+from boltons import tbutils
+
+from ....models.main import Traceback
+
+PYTHON_TRACEBACK_PREFIX = "Traceback (most recent call last):"
+
+
+def get_python_traceback(output: str) -> str:
+    if PYTHON_TRACEBACK_PREFIX in output:
+        tb_string = output.split(PYTHON_TRACEBACK_PREFIX)[-1]
+
+        # Then need to remove any lines below the traceback. Do this by noticing that
+        # the last line of the traceback is the first (other than the prefix) that doesn't begin with whitespace
+        lines = list(filter(lambda x: x.strip() != "", tb_string.splitlines()))
+        for i in range(len(lines) - 1):
+            if not lines[i].startswith(" "):
+                tb_string = "\n".join(lines[: i + 1])
+                break
+
+        return PYTHON_TRACEBACK_PREFIX + "\n" + tb_string
+    elif "SyntaxError" in output:
+        return "SyntaxError" + output.split("SyntaxError")[-1]
+    else:
+        return None
+
+
+def get_javascript_traceback(output: str) -> str:
+    lines = output.splitlines()
+    first_line = None
+    for i in range(len(lines) - 1):
+        segs = lines[i].split(":")
+        if (
+            len(segs) > 1
+            and segs[0] != ""
+            and segs[1].startswith(" ")
+            and lines[i + 1].strip().startswith("at")
+        ):
+            first_line = lines[i]
+            break
+
+    if first_line is not None:
+        return "\n".join(lines[lines.index(first_line) :])
+    else:
+        return None
+
+
+def parse_python_traceback(tb_string: str) -> Traceback:
+    # Remove anchor lines - tbutils doesn't always get them right
+    tb_string = "\n".join(
+        filter(
+            lambda x: x.strip().replace("~", "").replace("^", "") != "",
+            tb_string.splitlines(),
+        )
+    )
+    exc = tbutils.ParsedException.from_string(tb_string)
+    return Traceback.from_tbutil_parsed_exc(exc)
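For orientation, here is a minimal usage sketch of a few of the helpers added in this changeset. It assumes the package is importable as continuedev.libs.util.* (matching the file paths above) and that chevron is installed; the example strings and the manual_template name are illustrative only, not part of the diff.

from continuedev.libs.util.filter_files import should_filter_path
from continuedev.libs.util.strings import dedent_and_get_common_whitespace
from continuedev.libs.util.templating import render_prompt_template

# Directory names from DEFAULT_IGNORE_DIRS are filtered; ordinary source paths are not
assert should_filter_path(".git") is True
assert should_filter_path("src/main.py") is False

# Dedent a snippet and recover the common leading whitespace
body, prefix = dedent_and_get_common_whitespace("    a\n    b")
assert (body, prefix) == ("a\nb", "    ")

# A PromptTemplate given as a mustache string is rendered with chevron
template = "Summarize the following request:\n\n{{user_input}}"
print(render_prompt_template(template, history=[], other_data={"user_input": "add logging"}))

# A PromptTemplate given as a callable receives (history, other_data) directly
def manual_template(history, other_data):
    return f"{len(history)} prior messages; request: {other_data['user_input']}"

print(render_prompt_template(manual_template, history=[], other_data={"user_input": "add logging"}))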