diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-06-04 00:49:43 -0400 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-06-04 00:49:43 -0400 |
commit | 73e1f6c9e4634dbb4bbdde391a6730f5f90653d1 (patch) | |
tree | 8e6c00b3664c57fb3ff3ba296dc99382abb37460 | |
parent | e4a40479cc1aacb1d91481e047ba2790c41ec16c (diff) | |
parent | 939986e90772d0bbbf6c63ead00093da4c57c553 (diff) | |
download | sncontinue-73e1f6c9e4634dbb4bbdde391a6730f5f90653d1.tar.gz sncontinue-73e1f6c9e4634dbb4bbdde391a6730f5f90653d1.tar.bz2 sncontinue-73e1f6c9e4634dbb4bbdde391a6730f5f90653d1.zip |
Merge branch 'chroma'
-rw-r--r-- | continuedev/src/continuedev/core/policy.py | 12 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/chroma/query.py | 217 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/chroma/replace.py | 19 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/chroma/update.py | 161 | ||||
-rw-r--r-- | continuedev/src/continuedev/steps/chroma.py | 29 |
5 files changed, 195 insertions, 243 deletions
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index a946200e..4287bb6e 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,12 +1,12 @@ from typing import List, Tuple, Type +from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma from ..steps.steps_on_startup import StepsOnStartupStep from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe from .main import Step, Validator, History, Policy from .observation import Observation, TracebackObservation, UserInputObservation from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, MessageStep, EmptyStep, SetupContinueWorkspaceStep from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe -# from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep from ..steps.comment_code import CommentCodeStep @@ -17,7 +17,7 @@ class DemoPolicy(Policy): def next(self, history: History) -> Step: # At the very start, run initial Steps spcecified in the config if history.get_current() is None: - return MessageStep(message="Welcome to Continue!") >> SetupContinueWorkspaceStep() >> StepsOnStartupStep() + return MessageStep(message="Welcome to Continue!") >> SetupContinueWorkspaceStep() >> CreateCodebaseIndexChroma() >> StepsOnStartupStep() observation = history.get_current().observation if observation is not None and isinstance(observation, UserInputObservation): @@ -28,10 +28,10 @@ class DemoPolicy(Policy): return CreatePipelineRecipe() elif "/comment" in observation.user_input.lower(): return CommentCodeStep() - # elif "/ask" in observation.user_input: - # return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) - # elif "/edit" in observation.user_input: - # return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) + elif "/ask" in observation.user_input: + return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:])) + elif "/edit" in observation.user_input: + return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:])) elif "/step" in observation.user_input: return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:])) return StarCoderEditHighlightedCodeStep(user_input=observation.user_input) diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py index 5a1c89b3..c27329f0 100644 --- a/continuedev/src/continuedev/libs/chroma/query.py +++ b/continuedev/src/continuedev/libs/chroma/query.py @@ -1,78 +1,183 @@ +import json import subprocess -import sys -from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_storage +from typing import List, Tuple +from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_storage, Document +from llama_index.langchain_helpers.text_splitter import TokenTextSplitter import os -from typer import Typer -from enum import Enum -from .update import update_codebase_index, create_codebase_index, index_dir_for, get_current_branch -from .replace import replace_additional_index +from .update import filter_ignored_files, load_gpt_index_documents +from functools import cached_property + + +class ChromaIndexManager: + workspace_dir: str + + def __init__(self, workspace_dir: str): + self.workspace_dir = workspace_dir + + @cached_property + def current_commit(self) -> str: + """Get the current commit.""" + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip() + + @cached_property + def current_branch(self) -> str: + """Get the current branch.""" + return subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip() -app = Typer() + @cached_property + def index_dir(self) -> str: + return os.path.join(self.workspace_dir, ".continue", "chroma", self.current_branch) + @cached_property + def git_root_dir(self): + """Get the root directory of a Git repository.""" + try: + return subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], cwd=self.workspace_dir).strip().decode() + except subprocess.CalledProcessError: + return None -def query_codebase_index(query: str) -> str: - """Query the codebase index.""" - branch = subprocess.check_output( - ["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("utf-8").strip() - path = index_dir_for(branch) - if not os.path.exists(path): - print("No index found for the codebase at ", path) - return "" + def check_index_exists(self): + return os.path.exists(os.path.join(self.index_dir, "metadata.json")) - storage_context = StorageContext.from_defaults( - persist_dir=index_dir_for(branch)) - index = load_index_from_storage(storage_context) - # index = GPTVectorStoreIndex.load_from_disk(path) - engine = index.as_query_engine() - return engine.query(query) + def create_codebase_index(self): + """Create a new index for the current branch.""" + if not self.check_index_exists(): + os.makedirs(self.index_dir) + else: + return + documents = load_gpt_index_documents(self.workspace_dir) -def query_additional_index(query: str) -> str: - """Query the additional index.""" - index = GPTVectorStoreIndex.load_from_disk('data/additional_index.json') - return index.query(query) + chunks = {} + doc_chunks = [] + for doc in documents: + text_splitter = TokenTextSplitter() + try: + text_chunks = text_splitter.split_text(doc.text) + except: + print("ERROR (probably found special token): ", doc.text) + continue + filename = doc.extra_info["filename"] + chunks[filename] = len(text_chunks) + for i, text in enumerate(text_chunks): + doc_chunks.append(Document(text, doc_id=f"{filename}::{i}")) + with open(f"{self.index_dir}/metadata.json", "w") as f: + json.dump({"commit": self.current_commit, + "chunks": chunks}, f, indent=4) -class IndexTypeOption(str, Enum): - codebase = "codebase" - additional = "additional" + index = GPTVectorStoreIndex([]) + for chunk in doc_chunks: + index.insert(chunk) -@app.command() -def query(context: IndexTypeOption, query: str): - if context == IndexTypeOption.additional: - response = query_additional_index(query) - elif context == IndexTypeOption.codebase: - response = query_codebase_index(query) - else: - print("Error: unknown context") - print({"response": response}) + # d = 1536 # Dimension of text-ada-embedding-002 + # faiss_index = faiss.IndexFlatL2(d) + # index = GPTFaissIndex(documents, faiss_index=faiss_index) + # index.save_to_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") + index.storage_context.persist(persist_dir=self.index_dir) -@app.command() -def check_index_exists(root_path: str): - branch = get_current_branch() - exists = os.path.exists(index_dir_for(branch)) - print({"exists": exists}) + print("Codebase index created") + def get_modified_deleted_files(self) -> Tuple[List[str], List[str]]: + """Get a list of all files that have been modified since the last commit.""" + metadata = f"{self.index_dir}/metadata.json" + with open(metadata, "r") as f: + previous_commit = json.load(f)["commit"] -@app.command() -def update(): - update_codebase_index() - print("Updated codebase index") + modified_deleted_files = subprocess.check_output( + ["git", "diff", "--name-only", previous_commit, self.current_commit]).decode("utf-8").strip() + modified_deleted_files = modified_deleted_files.split("\n") + modified_deleted_files = [f for f in modified_deleted_files if f] + deleted_files = [ + f for f in modified_deleted_files if not os.path.exists(os.path.join(self.workspace_dir, f))] + modified_files = [ + f for f in modified_deleted_files if os.path.exists(os.path.join(self.workspace_dir, f))] -@app.command("create") -def create_index(): - create_codebase_index() - print("Created file index") + return filter_ignored_files(modified_files, self.index_dir), filter_ignored_files(deleted_files, self.index_dir) + def update_codebase_index(self): + """Update the index with a list of files.""" -@app.command() -def replace_additional_index(info: str): - replace_additional_index() - print("Replaced additional index") + if not self.check_index_exists(): + self.create_codebase_index() + else: + # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") + index = GPTVectorStoreIndex.load_from_disk( + f"{self.index_dir}/index.json") + modified_files, deleted_files = self.get_modified_deleted_files() + with open(f"{self.index_dir}/metadata.json", "r") as f: + metadata = json.load(f) -if __name__ == '__main__': - app() + for file in deleted_files: + + num_chunks = metadata["chunks"][file] + for i in range(num_chunks): + index.delete(f"{file}::{i}") + + del metadata["chunks"][file] + + print(f"Deleted {file}") + + for file in modified_files: + + if file in metadata["chunks"]: + + num_chunks = metadata["chunks"][file] + + for i in range(num_chunks): + index.delete(f"{file}::{i}") + + print(f"Deleted old version of {file}") + + with open(file, "r") as f: + text = f.read() + + text_splitter = TokenTextSplitter() + text_chunks = text_splitter.split_text(text) + + for i, text in enumerate(text_chunks): + index.insert(Document(text, doc_id=f"{file}::{i}")) + + metadata["chunks"][file] = len(text_chunks) + + print(f"Inserted new version of {file}") + + metadata["commit"] = self.current_commit + + with open(f"{self.index_dir}/metadata.json", "w") as f: + json.dump(metadata, f, indent=4) + + print("Codebase index updated") + + def query_codebase_index(self, query: str) -> str: + """Query the codebase index.""" + if not self.check_index_exists(): + print("No index found for the codebase at ", self.index_dir) + return "" + + storage_context = StorageContext.from_defaults( + persist_dir=self.index_dir) + index = load_index_from_storage(storage_context) + # index = GPTVectorStoreIndex.load_from_disk(path) + engine = index.as_query_engine() + return engine.query(query) + + def query_additional_index(self, query: str) -> str: + """Query the additional index.""" + index = GPTVectorStoreIndex.load_from_disk( + os.path.join(self.index_dir, 'additional_index.json')) + return index.query(query) + + def replace_additional_index(self, info: str): + """Replace the additional index with the given info.""" + with open(f'{self.index_dir}/additional_context.txt', 'w') as f: + f.write(info) + documents = [Document(info)] + index = GPTVectorStoreIndex(documents) + index.save_to_disk(f'{self.index_dir}/additional_index.json') + print("Additional index replaced") diff --git a/continuedev/src/continuedev/libs/chroma/replace.py b/continuedev/src/continuedev/libs/chroma/replace.py deleted file mode 100644 index 1868b152..00000000 --- a/continuedev/src/continuedev/libs/chroma/replace.py +++ /dev/null @@ -1,19 +0,0 @@ -import sys -from llama_index import GPTVectorStoreIndex, Document - - -def replace_additional_index(info: str): - """Replace the additional index.""" - with open('data/additional_context.txt', 'w') as f: - f.write(info) - documents = [Document(info)] - index = GPTVectorStoreIndex(documents) - index.save_to_disk('data/additional_index.json') - print("Additional index replaced") - - -if __name__ == "__main__": - """python3 replace.py <info>""" - info = sys.argv[1] if len(sys.argv) > 1 else None - if info: - replace_additional_index(info) diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py index 3b9eb743..23ed950f 100644 --- a/continuedev/src/continuedev/libs/chroma/update.py +++ b/continuedev/src/continuedev/libs/chroma/update.py @@ -1,11 +1,9 @@ # import faiss -import json import os import subprocess -from llama_index.langchain_helpers.text_splitter import TokenTextSplitter -from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Document -from typing import List, Generator, Tuple +from llama_index import SimpleDirectoryReader, Document +from typing import List from dotenv import load_dotenv load_dotenv() @@ -21,7 +19,7 @@ FILE_TYPES_TO_IGNORE = [ ] -def further_filter(files: List[str], root_dir: str): +def filter_ignored_files(files: List[str], root_dir: str): """Further filter files before indexing.""" for file in files: if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'): @@ -29,14 +27,6 @@ def further_filter(files: List[str], root_dir: str): yield root_dir + "/" + file -def get_git_root_dir(path: str): - """Get the root directory of a Git repository.""" - try: - return subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], cwd=path).strip().decode() - except subprocess.CalledProcessError: - return None - - def get_git_ignored_files(root_dir: str): """Get the list of ignored files in a Git repository.""" try: @@ -59,7 +49,7 @@ def get_input_files(root_dir: str): ignored_files = set(get_git_ignored_files(root_dir)) all_files = set(get_all_files(root_dir)) nonignored_files = all_files - ignored_files - return further_filter(nonignored_files, root_dir) + return filter_ignored_files(nonignored_files, root_dir) def load_gpt_index_documents(root: str) -> List[Document]: @@ -68,146 +58,3 @@ def load_gpt_index_documents(root: str) -> List[Document]: input_files = get_input_files(root) # Use SimpleDirectoryReader to load the files into Documents return SimpleDirectoryReader(root, input_files=input_files, file_metadata=lambda filename: {"filename": filename}).load_data() - - -def index_dir_for(branch: str) -> str: - return f"/Users/natesesti/Desktop/continue/continuedev/src/continuedev/libs/data/{branch}" - - -def get_git_root_dir(): - return "/Users/natesesti/Desktop/continue/extension/examples/python" - result = subprocess.run(['git', 'rev-parse', '--show-toplevel'], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - return result.stdout.decode().strip() - - -def get_current_branch() -> str: - return subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("utf-8").strip() - - -def get_current_commit() -> str: - return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() - - -def create_codebase_index(): - """Create a new index for the current branch.""" - branch = get_current_branch() - if not os.path.exists(index_dir_for(branch)): - os.makedirs(index_dir_for(branch)) - - print("ROOT DIRECTORY: ", get_git_root_dir()) - documents = load_gpt_index_documents(get_git_root_dir()) - - chunks = {} - doc_chunks = [] - for doc in documents: - text_splitter = TokenTextSplitter() - text_chunks = text_splitter.split_text(doc.text) - filename = doc.extra_info["filename"] - chunks[filename] = len(text_chunks) - for i, text in enumerate(text_chunks): - doc_chunks.append(Document(text, doc_id=f"{filename}::{i}")) - - with open(f"{index_dir_for(branch)}/metadata.json", "w") as f: - json.dump({"commit": get_current_commit(), - "chunks": chunks}, f, indent=4) - - index = GPTVectorStoreIndex([]) - - for chunk in doc_chunks: - index.insert(chunk) - - # d = 1536 # Dimension of text-ada-embedding-002 - # faiss_index = faiss.IndexFlatL2(d) - # index = GPTFaissIndex(documents, faiss_index=faiss_index) - # index.save_to_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") - - index.storage_context.persist( - persist_dir=index_dir_for(branch)) - - print("Codebase index created") - - -def get_modified_deleted_files() -> Tuple[List[str], List[str]]: - """Get a list of all files that have been modified since the last commit.""" - branch = get_current_branch() - current_commit = get_current_commit() - - metadata = f"{index_dir_for(branch)}/metadata.json" - with open(metadata, "r") as f: - previous_commit = json.load(f)["commit"] - - modified_deleted_files = subprocess.check_output( - ["git", "diff", "--name-only", previous_commit, current_commit]).decode("utf-8").strip() - modified_deleted_files = modified_deleted_files.split("\n") - modified_deleted_files = [f for f in modified_deleted_files if f] - - root = get_git_root_dir() - deleted_files = [ - f for f in modified_deleted_files if not os.path.exists(root + "/" + f)] - modified_files = [ - f for f in modified_deleted_files if os.path.exists(root + "/" + f)] - - return further_filter(modified_files, index_dir_for(branch)), further_filter(deleted_files, index_dir_for(branch)) - - -def update_codebase_index(): - """Update the index with a list of files.""" - branch = get_current_branch() - - if not os.path.exists(index_dir_for(branch)): - create_codebase_index() - else: - # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index") - index = GPTVectorStoreIndex.load_from_disk( - f"{index_dir_for(branch)}/index.json") - modified_files, deleted_files = get_modified_deleted_files() - - with open(f"{index_dir_for(branch)}/metadata.json", "r") as f: - metadata = json.load(f) - - for file in deleted_files: - - num_chunks = metadata["chunks"][file] - for i in range(num_chunks): - index.delete(f"{file}::{i}") - - del metadata["chunks"][file] - - print(f"Deleted {file}") - - for file in modified_files: - - if file in metadata["chunks"]: - - num_chunks = metadata["chunks"][file] - - for i in range(num_chunks): - index.delete(f"{file}::{i}") - - print(f"Deleted old version of {file}") - - with open(file, "r") as f: - text = f.read() - - text_splitter = TokenTextSplitter() - text_chunks = text_splitter.split_text(text) - - for i, text in enumerate(text_chunks): - index.insert(Document(text, doc_id=f"{file}::{i}")) - - metadata["chunks"][file] = len(text_chunks) - - print(f"Inserted new version of {file}") - - metadata["commit"] = get_current_commit() - - with open(f"{index_dir_for(branch)}/metadata.json", "w") as f: - json.dump(metadata, f, indent=4) - - print("Codebase index updated") - - -if __name__ == "__main__": - """python3 update.py""" - update_codebase_index() diff --git a/continuedev/src/continuedev/steps/chroma.py b/continuedev/src/continuedev/steps/chroma.py index 59a8b6e0..7bb9389e 100644 --- a/continuedev/src/continuedev/steps/chroma.py +++ b/continuedev/src/continuedev/steps/chroma.py @@ -1,12 +1,27 @@ from textwrap import dedent from typing import Coroutine, Union from ..core.observation import Observation, TextObservation -from ..core.main import Step, ContinueSDK +from ..core.main import Step +from ..core.sdk import ContinueSDK from .core.core import EditFileStep -from ..libs.chroma.query import query_codebase_index +from ..libs.chroma.query import ChromaIndexManager from .core.core import EditFileStep +class CreateCodebaseIndexChroma(Step): + name: str = "Create Codebase Index" + hide: bool = True + + async def describe(self, llm) -> Coroutine[str, None, None]: + return "Indexing the codebase..." + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory()) + if not index.check_index_exists(): + self.hide = False + index.create_codebase_index() + + class AnswerQuestionChroma(Step): question: str _answer: Union[str, None] = None @@ -19,7 +34,8 @@ class AnswerQuestionChroma(Step): return self._answer async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - results = query_codebase_index(self.question) + index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory()) + results = index.query_codebase_index(self.question) code_snippets = "" @@ -41,7 +57,9 @@ class AnswerQuestionChroma(Step): Here is the answer:""") answer = (await sdk.models.gpt35()).complete(prompt) - print(answer) + # Make paths relative to the workspace directory + answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") + self._answer = answer await sdk.ide.setFileOpen(files[0]) @@ -52,7 +70,8 @@ class EditFileChroma(Step): hide: bool = True async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - results = query_codebase_index(self.request) + index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory()) + results = index.query_codebase_index(self.request) resource_name = list( results.source_nodes[0].node.relationships.values())[0] |