Spaces:
Sleeping
Sleeping
import argparse
import os
from pathlib import Path

import chromadb
import google.generativeai as genai
from chromadb.utils import embedding_functions
from tqdm import tqdm
def _read_documents(documents_directory: str) -> tuple:
    """Read every non-empty line of every regular file in ``documents_directory``.

    Returns:
        A ``(documents, metadatas)`` pair of parallel lists. Each document is one
        stripped, non-empty line. Its metadata dict records the source
        ``filename`` and the 1-based ``line_number`` within that file.

    Raises:
        OSError: If the directory cannot be listed or a file cannot be opened.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    documents = []
    metadatas = []
    # Sort so ids are assigned in a deterministic order across runs (directory
    # listing order is arbitrary), and skip sub-directories, which cannot be
    # opened as text files.
    paths = sorted(
        (path for path in Path(documents_directory).iterdir() if path.is_file()),
        key=lambda path: path.name,
    )
    for path in paths:
        with path.open("r", encoding="utf-8") as file:
            lines = file.readlines()
        for line_number, line in enumerate(tqdm(lines, desc=f"Reading {path.name}"), 1):
            # Strip whitespace and skip lines that are empty afterwards.
            line = line.strip()
            if not line:
                continue
            documents.append(line)
            metadatas.append({"filename": path.name, "line_number": line_number})
    return documents, metadatas


def _resolve_google_api_key() -> str:
    """Return the Google API key from ``GOOGLE_API_KEY``, prompting if unset or blank.

    Raises:
        ValueError: If neither the environment nor the prompt yields a key.
    """
    api_key = os.environ.get("GOOGLE_API_KEY", "").strip()
    if not api_key:
        api_key = input("Please enter your Google API Key: ").strip()
    if not api_key:
        raise ValueError("A Google API key is required (set GOOGLE_API_KEY).")
    return api_key


def _make_embedding_function(api_key: str):
    """Build chromadb's Google Generative AI embedding function for ``api_key``."""
    # NOTE(review): some chromadb versions appear to name this class
    # ``GoogleGenerativeAiEmbeddingFunction`` (lower-case "i"). Accept either
    # spelling so a naming difference does not surface as an AttributeError.
    # Confirm against the installed chromadb version.
    factory = getattr(
        embedding_functions, "GoogleGenerativeAIEmbeddingFunction", None
    ) or getattr(embedding_functions, "GoogleGenerativeAiEmbeddingFunction")
    return factory(api_key=api_key)


def main(
    documents_directory: str = "documents",
    collection_name: str = "documents_collection",
    persist_directory: str = ".",
    batch_size: int = 100,
) -> None:
    """Load every non-empty line of the files in a directory into a Chroma collection.

    Each line becomes one document. Its metadata records the source filename and
    the 1-based line number. Documents are embedded via chromadb's Google
    Generative AI embedding function and added to a persistent collection. The
    collection is created if it does not exist, so repeated runs append to it.

    Args:
        documents_directory: Directory whose regular files are read as UTF-8 text.
        collection_name: Name of the Chroma collection to create or extend.
        persist_directory: Storage path for the persistent Chroma client.
        batch_size: Number of documents sent per ``collection.add`` call.

    Raises:
        ValueError: If ``batch_size`` is not positive or no API key is supplied.
        OSError: If the documents directory or one of its files cannot be read.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    documents, metadatas = _read_documents(documents_directory)

    # Instantiate a persistent chroma client in the persist_directory.
    # Learn more at docs.trychroma.com
    client = chromadb.PersistentClient(path=persist_directory)

    # Configure genai the same way whether the key came from the environment or
    # the prompt, then hand the same key to the embedding function.
    google_api_key = _resolve_google_api_key()
    genai.configure(api_key=google_api_key)
    embedding_function = _make_embedding_function(google_api_key)

    # If the collection already exists, we just return it. This allows us to add
    # more data to an existing collection.
    collection = client.get_or_create_collection(
        name=collection_name, embedding_function=embedding_function
    )

    # Continue numbering ids from the current count.
    # NOTE(review): if documents were ever deleted from the collection, these
    # count-based ids can collide with ids that are still present.
    count = collection.count()
    print(f"Collection already contains {count} documents")
    ids = [str(i) for i in range(count, count + len(documents))]

    # Progress is measured in documents rather than batches, so the final
    # partial batch is counted exactly.
    with tqdm(total=len(documents), desc="Adding documents") as progress:
        for start in range(0, len(documents), batch_size):
            batch = slice(start, start + batch_size)
            batch_ids = ids[batch]
            collection.add(
                ids=batch_ids,
                documents=documents[batch],
                metadatas=metadatas[batch],  # type: ignore
            )
            progress.update(len(batch_ids))

    new_count = collection.count()
    print(f"Added {new_count - count} documents")
if __name__ == "__main__":
    # Command-line entry point: gather the source directory, the collection name
    # and the storage location, then pass them to main().
    cli = argparse.ArgumentParser(
        description="Load documents from a directory into a Chroma collection"
    )

    # One (flag, default, help) row per string-valued option, registered in
    # this order so the generated --help text lists them the same way.
    option_specs = (
        (
            "--data_directory",
            "documents",
            "The directory where your text files are stored",
        ),
        (
            "--collection_name",
            "documents_collection",
            "The name of the Chroma collection",
        ),
        (
            "--persist_directory",
            "chroma_storage",
            "The directory where you want to store the Chroma collection",
        ),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()
    main(
        documents_directory=options.data_directory,
        collection_name=options.collection_name,
        persist_directory=options.persist_directory,
    )