summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-06-04 00:49:43 -0400
committerNate Sesti <sestinj@gmail.com>2023-06-04 00:49:43 -0400
commit73e1f6c9e4634dbb4bbdde391a6730f5f90653d1 (patch)
tree8e6c00b3664c57fb3ff3ba296dc99382abb37460
parente4a40479cc1aacb1d91481e047ba2790c41ec16c (diff)
parent939986e90772d0bbbf6c63ead00093da4c57c553 (diff)
downloadsncontinue-73e1f6c9e4634dbb4bbdde391a6730f5f90653d1.tar.gz
sncontinue-73e1f6c9e4634dbb4bbdde391a6730f5f90653d1.tar.bz2
sncontinue-73e1f6c9e4634dbb4bbdde391a6730f5f90653d1.zip
Merge branch 'chroma'
-rw-r--r--continuedev/src/continuedev/core/policy.py12
-rw-r--r--continuedev/src/continuedev/libs/chroma/query.py217
-rw-r--r--continuedev/src/continuedev/libs/chroma/replace.py19
-rw-r--r--continuedev/src/continuedev/libs/chroma/update.py161
-rw-r--r--continuedev/src/continuedev/steps/chroma.py29
5 files changed, 195 insertions, 243 deletions
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index a946200e..4287bb6e 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -1,12 +1,12 @@
from typing import List, Tuple, Type
+from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma
from ..steps.steps_on_startup import StepsOnStartupStep
from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
from .main import Step, Validator, History, Policy
from .observation import Observation, TracebackObservation, UserInputObservation
from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, MessageStep, EmptyStep, SetupContinueWorkspaceStep
from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe
-# from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma
from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep
from ..steps.comment_code import CommentCodeStep
@@ -17,7 +17,7 @@ class DemoPolicy(Policy):
def next(self, history: History) -> Step:
# At the very start, run initial Steps specified in the config
if history.get_current() is None:
- return MessageStep(message="Welcome to Continue!") >> SetupContinueWorkspaceStep() >> StepsOnStartupStep()
+ return MessageStep(message="Welcome to Continue!") >> SetupContinueWorkspaceStep() >> CreateCodebaseIndexChroma() >> StepsOnStartupStep()
observation = history.get_current().observation
if observation is not None and isinstance(observation, UserInputObservation):
@@ -28,10 +28,10 @@ class DemoPolicy(Policy):
return CreatePipelineRecipe()
elif "/comment" in observation.user_input.lower():
return CommentCodeStep()
- # elif "/ask" in observation.user_input:
- # return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:]))
- # elif "/edit" in observation.user_input:
- # return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
+ elif "/ask" in observation.user_input:
+ return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:]))
+ elif "/edit" in observation.user_input:
+ return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
elif "/step" in observation.user_input:
return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:]))
return StarCoderEditHighlightedCodeStep(user_input=observation.user_input)
diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py
index 5a1c89b3..c27329f0 100644
--- a/continuedev/src/continuedev/libs/chroma/query.py
+++ b/continuedev/src/continuedev/libs/chroma/query.py
@@ -1,78 +1,183 @@
+import json
import subprocess
-import sys
-from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_storage
+from typing import List, Tuple
+from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_storage, Document
+from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
import os
-from typer import Typer
-from enum import Enum
-from .update import update_codebase_index, create_codebase_index, index_dir_for, get_current_branch
-from .replace import replace_additional_index
+from .update import filter_ignored_files, load_gpt_index_documents
+from functools import cached_property
+
+
+class ChromaIndexManager:
+ workspace_dir: str
+
+ def __init__(self, workspace_dir: str):
+ self.workspace_dir = workspace_dir
+
+ @cached_property
+ def current_commit(self) -> str:
+ """Get the current commit."""
+ return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip()
+
+ @cached_property
+ def current_branch(self) -> str:
+ """Get the current branch."""
+ return subprocess.check_output(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=self.workspace_dir).decode("utf-8").strip()
-app = Typer()
+ @cached_property
+ def index_dir(self) -> str:
+ return os.path.join(self.workspace_dir, ".continue", "chroma", self.current_branch)
+ @cached_property
+ def git_root_dir(self):
+ """Get the root directory of a Git repository."""
+ try:
+ return subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], cwd=self.workspace_dir).strip().decode()
+ except subprocess.CalledProcessError:
+ return None
-def query_codebase_index(query: str) -> str:
- """Query the codebase index."""
- branch = subprocess.check_output(
- ["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("utf-8").strip()
- path = index_dir_for(branch)
- if not os.path.exists(path):
- print("No index found for the codebase at ", path)
- return ""
+ def check_index_exists(self):
+ return os.path.exists(os.path.join(self.index_dir, "metadata.json"))
- storage_context = StorageContext.from_defaults(
- persist_dir=index_dir_for(branch))
- index = load_index_from_storage(storage_context)
- # index = GPTVectorStoreIndex.load_from_disk(path)
- engine = index.as_query_engine()
- return engine.query(query)
+ def create_codebase_index(self):
+ """Create a new index for the current branch."""
+ if not self.check_index_exists():
+ os.makedirs(self.index_dir)
+ else:
+ return
+ documents = load_gpt_index_documents(self.workspace_dir)
-def query_additional_index(query: str) -> str:
- """Query the additional index."""
- index = GPTVectorStoreIndex.load_from_disk('data/additional_index.json')
- return index.query(query)
+ chunks = {}
+ doc_chunks = []
+ for doc in documents:
+ text_splitter = TokenTextSplitter()
+ try:
+ text_chunks = text_splitter.split_text(doc.text)
+ except:
+ print("ERROR (probably found special token): ", doc.text)
+ continue
+ filename = doc.extra_info["filename"]
+ chunks[filename] = len(text_chunks)
+ for i, text in enumerate(text_chunks):
+ doc_chunks.append(Document(text, doc_id=f"{filename}::{i}"))
+ with open(f"{self.index_dir}/metadata.json", "w") as f:
+ json.dump({"commit": self.current_commit,
+ "chunks": chunks}, f, indent=4)
-class IndexTypeOption(str, Enum):
- codebase = "codebase"
- additional = "additional"
+ index = GPTVectorStoreIndex([])
+ for chunk in doc_chunks:
+ index.insert(chunk)
-@app.command()
-def query(context: IndexTypeOption, query: str):
- if context == IndexTypeOption.additional:
- response = query_additional_index(query)
- elif context == IndexTypeOption.codebase:
- response = query_codebase_index(query)
- else:
- print("Error: unknown context")
- print({"response": response})
+ # d = 1536 # Dimension of text-ada-embedding-002
+ # faiss_index = faiss.IndexFlatL2(d)
+ # index = GPTFaissIndex(documents, faiss_index=faiss_index)
+ # index.save_to_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index")
+ index.storage_context.persist(persist_dir=self.index_dir)
-@app.command()
-def check_index_exists(root_path: str):
- branch = get_current_branch()
- exists = os.path.exists(index_dir_for(branch))
- print({"exists": exists})
+ print("Codebase index created")
+ def get_modified_deleted_files(self) -> Tuple[List[str], List[str]]:
+ """Get a list of all files that have been modified since the last commit."""
+ metadata = f"{self.index_dir}/metadata.json"
+ with open(metadata, "r") as f:
+ previous_commit = json.load(f)["commit"]
-@app.command()
-def update():
- update_codebase_index()
- print("Updated codebase index")
+ modified_deleted_files = subprocess.check_output(
+ ["git", "diff", "--name-only", previous_commit, self.current_commit]).decode("utf-8").strip()
+ modified_deleted_files = modified_deleted_files.split("\n")
+ modified_deleted_files = [f for f in modified_deleted_files if f]
+ deleted_files = [
+ f for f in modified_deleted_files if not os.path.exists(os.path.join(self.workspace_dir, f))]
+ modified_files = [
+ f for f in modified_deleted_files if os.path.exists(os.path.join(self.workspace_dir, f))]
-@app.command("create")
-def create_index():
- create_codebase_index()
- print("Created file index")
+ return filter_ignored_files(modified_files, self.index_dir), filter_ignored_files(deleted_files, self.index_dir)
+ def update_codebase_index(self):
+ """Update the index with a list of files."""
-@app.command()
-def replace_additional_index(info: str):
- replace_additional_index()
- print("Replaced additional index")
+ if not self.check_index_exists():
+ self.create_codebase_index()
+ else:
+ # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index")
+ index = GPTVectorStoreIndex.load_from_disk(
+ f"{self.index_dir}/index.json")
+ modified_files, deleted_files = self.get_modified_deleted_files()
+ with open(f"{self.index_dir}/metadata.json", "r") as f:
+ metadata = json.load(f)
-if __name__ == '__main__':
- app()
+ for file in deleted_files:
+
+ num_chunks = metadata["chunks"][file]
+ for i in range(num_chunks):
+ index.delete(f"{file}::{i}")
+
+ del metadata["chunks"][file]
+
+ print(f"Deleted {file}")
+
+ for file in modified_files:
+
+ if file in metadata["chunks"]:
+
+ num_chunks = metadata["chunks"][file]
+
+ for i in range(num_chunks):
+ index.delete(f"{file}::{i}")
+
+ print(f"Deleted old version of {file}")
+
+ with open(file, "r") as f:
+ text = f.read()
+
+ text_splitter = TokenTextSplitter()
+ text_chunks = text_splitter.split_text(text)
+
+ for i, text in enumerate(text_chunks):
+ index.insert(Document(text, doc_id=f"{file}::{i}"))
+
+ metadata["chunks"][file] = len(text_chunks)
+
+ print(f"Inserted new version of {file}")
+
+ metadata["commit"] = self.current_commit
+
+ with open(f"{self.index_dir}/metadata.json", "w") as f:
+ json.dump(metadata, f, indent=4)
+
+ print("Codebase index updated")
+
+ def query_codebase_index(self, query: str) -> str:
+ """Query the codebase index."""
+ if not self.check_index_exists():
+ print("No index found for the codebase at ", self.index_dir)
+ return ""
+
+ storage_context = StorageContext.from_defaults(
+ persist_dir=self.index_dir)
+ index = load_index_from_storage(storage_context)
+ # index = GPTVectorStoreIndex.load_from_disk(path)
+ engine = index.as_query_engine()
+ return engine.query(query)
+
+ def query_additional_index(self, query: str) -> str:
+ """Query the additional index."""
+ index = GPTVectorStoreIndex.load_from_disk(
+ os.path.join(self.index_dir, 'additional_index.json'))
+ return index.query(query)
+
+ def replace_additional_index(self, info: str):
+ """Replace the additional index with the given info."""
+ with open(f'{self.index_dir}/additional_context.txt', 'w') as f:
+ f.write(info)
+ documents = [Document(info)]
+ index = GPTVectorStoreIndex(documents)
+ index.save_to_disk(f'{self.index_dir}/additional_index.json')
+ print("Additional index replaced")
diff --git a/continuedev/src/continuedev/libs/chroma/replace.py b/continuedev/src/continuedev/libs/chroma/replace.py
deleted file mode 100644
index 1868b152..00000000
--- a/continuedev/src/continuedev/libs/chroma/replace.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import sys
-from llama_index import GPTVectorStoreIndex, Document
-
-
-def replace_additional_index(info: str):
- """Replace the additional index."""
- with open('data/additional_context.txt', 'w') as f:
- f.write(info)
- documents = [Document(info)]
- index = GPTVectorStoreIndex(documents)
- index.save_to_disk('data/additional_index.json')
- print("Additional index replaced")
-
-
-if __name__ == "__main__":
- """python3 replace.py <info>"""
- info = sys.argv[1] if len(sys.argv) > 1 else None
- if info:
- replace_additional_index(info)
diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py
index 3b9eb743..23ed950f 100644
--- a/continuedev/src/continuedev/libs/chroma/update.py
+++ b/continuedev/src/continuedev/libs/chroma/update.py
@@ -1,11 +1,9 @@
# import faiss
-import json
import os
import subprocess
-from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
-from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Document
-from typing import List, Generator, Tuple
+from llama_index import SimpleDirectoryReader, Document
+from typing import List
from dotenv import load_dotenv
load_dotenv()
@@ -21,7 +19,7 @@ FILE_TYPES_TO_IGNORE = [
]
-def further_filter(files: List[str], root_dir: str):
+def filter_ignored_files(files: List[str], root_dir: str):
"""Further filter files before indexing."""
for file in files:
if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'):
@@ -29,14 +27,6 @@ def further_filter(files: List[str], root_dir: str):
yield root_dir + "/" + file
-def get_git_root_dir(path: str):
- """Get the root directory of a Git repository."""
- try:
- return subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], cwd=path).strip().decode()
- except subprocess.CalledProcessError:
- return None
-
-
def get_git_ignored_files(root_dir: str):
"""Get the list of ignored files in a Git repository."""
try:
@@ -59,7 +49,7 @@ def get_input_files(root_dir: str):
ignored_files = set(get_git_ignored_files(root_dir))
all_files = set(get_all_files(root_dir))
nonignored_files = all_files - ignored_files
- return further_filter(nonignored_files, root_dir)
+ return filter_ignored_files(nonignored_files, root_dir)
def load_gpt_index_documents(root: str) -> List[Document]:
@@ -68,146 +58,3 @@ def load_gpt_index_documents(root: str) -> List[Document]:
input_files = get_input_files(root)
# Use SimpleDirectoryReader to load the files into Documents
return SimpleDirectoryReader(root, input_files=input_files, file_metadata=lambda filename: {"filename": filename}).load_data()
-
-
-def index_dir_for(branch: str) -> str:
- return f"/Users/natesesti/Desktop/continue/continuedev/src/continuedev/libs/data/{branch}"
-
-
-def get_git_root_dir():
- return "/Users/natesesti/Desktop/continue/extension/examples/python"
- result = subprocess.run(['git', 'rev-parse', '--show-toplevel'],
- stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- return result.stdout.decode().strip()
-
-
-def get_current_branch() -> str:
- return subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("utf-8").strip()
-
-
-def get_current_commit() -> str:
- return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
-
-
-def create_codebase_index():
- """Create a new index for the current branch."""
- branch = get_current_branch()
- if not os.path.exists(index_dir_for(branch)):
- os.makedirs(index_dir_for(branch))
-
- print("ROOT DIRECTORY: ", get_git_root_dir())
- documents = load_gpt_index_documents(get_git_root_dir())
-
- chunks = {}
- doc_chunks = []
- for doc in documents:
- text_splitter = TokenTextSplitter()
- text_chunks = text_splitter.split_text(doc.text)
- filename = doc.extra_info["filename"]
- chunks[filename] = len(text_chunks)
- for i, text in enumerate(text_chunks):
- doc_chunks.append(Document(text, doc_id=f"{filename}::{i}"))
-
- with open(f"{index_dir_for(branch)}/metadata.json", "w") as f:
- json.dump({"commit": get_current_commit(),
- "chunks": chunks}, f, indent=4)
-
- index = GPTVectorStoreIndex([])
-
- for chunk in doc_chunks:
- index.insert(chunk)
-
- # d = 1536 # Dimension of text-ada-embedding-002
- # faiss_index = faiss.IndexFlatL2(d)
- # index = GPTFaissIndex(documents, faiss_index=faiss_index)
- # index.save_to_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index")
-
- index.storage_context.persist(
- persist_dir=index_dir_for(branch))
-
- print("Codebase index created")
-
-
-def get_modified_deleted_files() -> Tuple[List[str], List[str]]:
- """Get a list of all files that have been modified since the last commit."""
- branch = get_current_branch()
- current_commit = get_current_commit()
-
- metadata = f"{index_dir_for(branch)}/metadata.json"
- with open(metadata, "r") as f:
- previous_commit = json.load(f)["commit"]
-
- modified_deleted_files = subprocess.check_output(
- ["git", "diff", "--name-only", previous_commit, current_commit]).decode("utf-8").strip()
- modified_deleted_files = modified_deleted_files.split("\n")
- modified_deleted_files = [f for f in modified_deleted_files if f]
-
- root = get_git_root_dir()
- deleted_files = [
- f for f in modified_deleted_files if not os.path.exists(root + "/" + f)]
- modified_files = [
- f for f in modified_deleted_files if os.path.exists(root + "/" + f)]
-
- return further_filter(modified_files, index_dir_for(branch)), further_filter(deleted_files, index_dir_for(branch))
-
-
-def update_codebase_index():
- """Update the index with a list of files."""
- branch = get_current_branch()
-
- if not os.path.exists(index_dir_for(branch)):
- create_codebase_index()
- else:
- # index = GPTFaissIndex.load_from_disk(f"{index_dir_for(branch)}/index.json", faiss_index_save_path=f"{index_dir_for(branch)}/index_faiss_core.index")
- index = GPTVectorStoreIndex.load_from_disk(
- f"{index_dir_for(branch)}/index.json")
- modified_files, deleted_files = get_modified_deleted_files()
-
- with open(f"{index_dir_for(branch)}/metadata.json", "r") as f:
- metadata = json.load(f)
-
- for file in deleted_files:
-
- num_chunks = metadata["chunks"][file]
- for i in range(num_chunks):
- index.delete(f"{file}::{i}")
-
- del metadata["chunks"][file]
-
- print(f"Deleted {file}")
-
- for file in modified_files:
-
- if file in metadata["chunks"]:
-
- num_chunks = metadata["chunks"][file]
-
- for i in range(num_chunks):
- index.delete(f"{file}::{i}")
-
- print(f"Deleted old version of {file}")
-
- with open(file, "r") as f:
- text = f.read()
-
- text_splitter = TokenTextSplitter()
- text_chunks = text_splitter.split_text(text)
-
- for i, text in enumerate(text_chunks):
- index.insert(Document(text, doc_id=f"{file}::{i}"))
-
- metadata["chunks"][file] = len(text_chunks)
-
- print(f"Inserted new version of {file}")
-
- metadata["commit"] = get_current_commit()
-
- with open(f"{index_dir_for(branch)}/metadata.json", "w") as f:
- json.dump(metadata, f, indent=4)
-
- print("Codebase index updated")
-
-
-if __name__ == "__main__":
- """python3 update.py"""
- update_codebase_index()
diff --git a/continuedev/src/continuedev/steps/chroma.py b/continuedev/src/continuedev/steps/chroma.py
index 59a8b6e0..7bb9389e 100644
--- a/continuedev/src/continuedev/steps/chroma.py
+++ b/continuedev/src/continuedev/steps/chroma.py
@@ -1,12 +1,27 @@
from textwrap import dedent
from typing import Coroutine, Union
from ..core.observation import Observation, TextObservation
-from ..core.main import Step, ContinueSDK
+from ..core.main import Step
+from ..core.sdk import ContinueSDK
from .core.core import EditFileStep
-from ..libs.chroma.query import query_codebase_index
+from ..libs.chroma.query import ChromaIndexManager
from .core.core import EditFileStep
+class CreateCodebaseIndexChroma(Step):
+ name: str = "Create Codebase Index"
+ hide: bool = True
+
+ async def describe(self, llm) -> Coroutine[str, None, None]:
+ return "Indexing the codebase..."
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory())
+ if not index.check_index_exists():
+ self.hide = False
+ index.create_codebase_index()
+
+
class AnswerQuestionChroma(Step):
question: str
_answer: Union[str, None] = None
@@ -19,7 +34,8 @@ class AnswerQuestionChroma(Step):
return self._answer
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- results = query_codebase_index(self.question)
+ index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory())
+ results = index.query_codebase_index(self.question)
code_snippets = ""
@@ -41,7 +57,9 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
answer = (await sdk.models.gpt35()).complete(prompt)
- print(answer)
+ # Make paths relative to the workspace directory
+ answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
+
self._answer = answer
await sdk.ide.setFileOpen(files[0])
@@ -52,7 +70,8 @@ class EditFileChroma(Step):
hide: bool = True
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- results = query_codebase_index(self.request)
+ index = ChromaIndexManager(await sdk.ide.getWorkspaceDirectory())
+ results = index.query_codebase_index(self.request)
resource_name = list(
results.source_nodes[0].node.relationships.values())[0]