Spaces:
Sleeping
Sleeping
import argparse | |
import os | |
from typing import List | |
import google.generativeai as genai | |
import chromadb | |
from chromadb.utils import embedding_functions | |
# Shared Gemini model handle used by get_gemini_response() for every query.
model = genai.GenerativeModel(model_name="gemini-pro")
def build_prompt(query: str, context: List[str]) -> str:
    """
    Assembles the single prompt string sent to the LLM.

    The prompt starts with fixed instructions telling the model to answer
    only from the supplied context (and to say "I am not sure" before
    guessing when the context is insufficient), followed by the question
    and the context passages joined with single spaces.

    Args:
    query (str): The original query.
    context (List[str]): The context passages returned by embedding search.

    Returns:
    The complete prompt for the LLM (str).
    """
    instructions = (
        "I am going to ask you a question, which I would like you to answer"
        " based only on the provided context, and not any other information."
        " If there is not enough information in the context to answer the question,"
        ' say "I am not sure", then try to make a guess.'
        " Break your answer up into nicely readable paragraphs."
    )
    joined_context = " ".join(context)
    # The question part deliberately keeps its leading space and the absence
    # of a separator after "have:" so the output text is unchanged.
    question_part = (
        f" The question is '{query}'. Here is all the context you have:"
        f"{joined_context}"
    )
    # Instructions and question are separated by one extra space.
    return " ".join((instructions, question_part))
def get_gemini_response(query: str, context: List[str]) -> str:
    """
    Asks Gemini to answer the query using the retrieved context.

    The prompt is produced by build_prompt() and sent through the
    module-level `model`.

    Args:
    query (str): The original query.
    context (List[str]): The context of the query, returned by embedding search.

    Returns:
    The text of the model's response.
    """
    prompt = build_prompt(query, context)
    reply = model.generate_content(prompt)
    return reply.text
def main(
    collection_name: str = "documents_collection", persist_directory: str = "."
) -> None:
    """
    Runs an interactive question-answering loop over a Chroma collection.

    For each query the 5 most relevant documents are retrieved from the
    collection and passed to Gemini as context; the answer is printed
    together with the `filename` and `line_number` metadata of each
    retrieved document. The loop ends cleanly when the user presses
    Ctrl+C or sends end-of-input.

    Args:
    collection_name (str): Name of the Chroma collection to query
        (looked up with get_collection).
    persist_directory (str): Directory passed to the persistent Chroma client.
    """
    # Take the key from the environment; prompt when it is missing or empty.
    google_api_key = os.environ.get("GOOGLE_API_KEY")
    if not google_api_key:
        google_api_key = input("Please enter your Google API Key: ")
    # Configure the Gemini client explicitly regardless of where the key came from.
    genai.configure(api_key=google_api_key)

    # Instantiate a persistent chroma client in the persist_directory.
    # This will automatically load any previously saved collections.
    # Learn more at docs.trychroma.com
    client = chromadb.PersistentClient(path=persist_directory)

    # Queries are embedded with the Google embedding function.
    embedding_function = embedding_functions.GoogleGenerativeAIEmbeddingFunction(
        api_key=google_api_key, task_type="RETRIEVAL_QUERY"
    )

    # Get the collection.
    collection = client.get_collection(
        name=collection_name, embedding_function=embedding_function
    )

    # Simple input loop. The prompt advertises Ctrl+C as the way to quit, so
    # the interrupt (and end-of-input) is handled instead of surfacing a traceback.
    try:
        while True:
            # Ignore surrounding whitespace so blank-looking input is rejected too.
            query = input("Query: ").strip()
            if not query:
                print("Please enter a question. Ctrl+C to Quit.\n")
                continue
            print("\nThinking...\n")

            # Query the collection to get the 5 most relevant results
            results = collection.query(
                query_texts=[query], n_results=5, include=["documents", "metadatas"]
            )
            sources = "\n".join(
                f"{result['filename']}: line {result['line_number']}"
                for result in results["metadatas"][0]  # type: ignore
            )

            # Get the response from Gemini
            response = get_gemini_response(query, results["documents"][0])  # type: ignore

            # Output, with sources
            print(response)
            print("\n")
            print(f"Source documents:\n{sources}")
            print("\n")
    except (KeyboardInterrupt, EOFError):
        print("\nExiting.")
if __name__ == "__main__":
    # Command-line entry point: parse the collection location and name,
    # then start the interactive query loop.
    parser = argparse.ArgumentParser(
        description="Ask questions about the documents stored in a Chroma collection"
    )
    parser.add_argument(
        "--persist_directory",
        type=str,
        default="chroma_storage",
        help="The directory where the Chroma collection is stored",
    )
    parser.add_argument(
        "--collection_name",
        type=str,
        default="documents_collection",
        help="The name of the Chroma collection",
    )

    # Parse arguments
    args = parser.parse_args()

    main(
        collection_name=args.collection_name,
        persist_directory=args.persist_directory,
    )