summary | refs | log | tree | commit | diff
path: root/extension/scripts/index.py
blob: 3afc91319a2f391dced61b182a59e3facb00a227 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import sys
import os
from typing import TextIO
from chroma import update_collection, query_collection, create_collection, collection_exists, get_current_branch
from typer import Typer

# Typer application object; the @app.command(...) functions in this file
# register themselves on it, and app() is invoked when run as a script.
app = Typer()

class SilenceStdoutContextManager:
    """Context manager that discards everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is swapped for a handle on ``os.devnull``; on exit
    the previously installed stream is put back and the null handle is closed.
    Exceptions raised inside the ``with`` body are not suppressed.

    One instance can be reused for consecutive ``with`` blocks, but it keeps a
    single saved stream, so nesting the same instance is not supported.
    """

    # Stream that was installed as sys.stdout when the block was entered.
    _original_stdout: TextIO
    # Handle on os.devnull that receives the discarded output.
    _devnull: TextIO

    def __enter__(self):
        """Redirect ``sys.stdout`` to the null device and return ``self``."""
        # Open first: if this fails, sys.stdout has not been touched yet.
        self._devnull = open(os.devnull, 'w')
        self._original_stdout = sys.stdout
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the saved stream, then close the null-device handle."""
        # Restore before closing so stdout comes back even if close() raises,
        # and close our own handle rather than whatever sys.stdout currently
        # is, in case the wrapped code replaced sys.stdout itself.
        sys.stdout = self._original_stdout
        self._devnull.close()

# Single shared instance; each command in this file wraps its chroma call in
# `with silence:` so that only the command's own final print reaches stdout.
silence = SilenceStdoutContextManager()

@app.command("exists")
def exists(cwd: str):
    # CLI command "exists": call collection_exists(cwd) with stdout muted and
    # print its result under the "exists" key.
    # (Kept as comments, not a docstring: Typer would show a docstring as the
    # command's help text.)
    with silence:
        found = collection_exists(cwd)
    # NOTE(review): print() of a dict emits its Python repr, not JSON —
    # confirm the consumer of this output expects that.
    payload = {"exists": found}
    print(payload)

@app.command("create")
def create(cwd: str):
    # CLI command "create": look up the branch via get_current_branch(cwd),
    # pass it with cwd to create_collection, then print a success marker.
    # Both helper calls run with stdout muted.
    # (Kept as comments, not a docstring: Typer would show a docstring as the
    # command's help text.)
    with silence:
        current_branch = get_current_branch(cwd)
        create_collection(current_branch, cwd)
    # NOTE(review): print() of a dict emits its Python repr, not JSON —
    # confirm the consumer of this output expects that.
    outcome = {"success": True}
    print(outcome)

@app.command("update")
def update(cwd: str):
    # CLI command "update": call update_collection(cwd) with stdout muted,
    # then print a success marker.
    # (Kept as comments, not a docstring: Typer would show a docstring as the
    # command's help text.)
    with silence:
        update_collection(cwd)
    # NOTE(review): print() of a dict emits its Python repr, not JSON —
    # confirm the consumer of this output expects that.
    outcome = {"success": True}
    print(outcome)

@app.command("query")
def query(query: str, n_results: int, cwd: str):
    # CLI command "query": call query_collection(query, n_results, cwd) with
    # stdout muted, then print {"results": [...]} where each entry pairs
    # resp["ids"][0][i] with resp["documents"][0][i].
    # (Kept as comments, not a docstring: Typer would show a docstring as the
    # command's help text.)
    with silence:
        resp = query_collection(query, n_results, cwd)

    # Only the first inner list of ids/documents is read; presumably that is
    # the result set for the single query text — verify against chroma.
    results = [
        {"id": doc_id, "document": resp["documents"][0][position]}
        for position, doc_id in enumerate(resp["ids"][0])
    ]

    # NOTE(review): print() of a dict emits its Python repr, not JSON —
    # confirm the consumer of this output expects that.
    payload = {"results": results}
    print(payload)

# Dispatch to the Typer CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()