|
import torch |
|
|
|
|
|
from QWEN import ChatQWEN |
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
|
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain_chroma import Chroma |
|
|
|
|
|
def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
    """Open the persistent Chroma vector store backed by a GTE embedder.

    Args:
        CHROMA_PATH: Directory holding the persisted Chroma collection.
        MODEL_NAME: HuggingFace model id used to embed queries/documents.

    Returns:
        A ``Chroma`` instance wired to the embedding model.
    """
    # Run the embedding model on GPU when one is available.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    embedder = HuggingFaceEmbeddings(
        model_name=MODEL_NAME,
        model_kwargs={"device": device, "trust_remote_code": True},
        # Normalized embeddings so similarity scores are cosine-like.
        encode_kwargs={"normalize_embeddings": True},
    )

    return Chroma(persist_directory=CHROMA_PATH, embedding_function=embedder)
|
|
|
|
|
def query_db(db, query_text, k=3):
    """Retrieve the top-*k* most relevant documents for *query_text*.

    Args:
        db: Vector store exposing ``similarity_search_with_relevance_scores``
            (a Chroma instance in this project).
        query_text: The user question to search for.
        k: Number of documents to retrieve. Defaults to 3, matching the
            previously hard-coded value, so existing callers are unaffected.

    Returns:
        tuple[str, str]: ``(context_text, sources)`` — the page contents
        joined by a ``---`` separator, and the newline-joined ``source``
        metadata entry of each hit.
    """
    results = db.similarity_search_with_relevance_scores(query_text, k=k)

    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
    # .get() keeps a document that lacks a "source" metadata key from
    # raising KeyError and aborting the whole query.
    sources = "\n".join(doc.metadata.get("source", "unknown") for doc, _score in results)

    return context_text, sources
|
|
|
|
|
def load_chain():
    """Assemble the prompt | LLM runnable used to answer questions.

    Returns:
        A langchain runnable: the chat prompt piped into ``ChatQWEN``.
    """
    system_message = (
        "system",
        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    )
    # Template variables: {context} (retrieved documents) and {question}.
    human_message = (
        "human",
        """Answer the question based only on the following context:



{context}



---



Answer the question based on the above context in question's original language: {question}""",
    )

    prompt = ChatPromptTemplate([system_message, human_message])
    llm = ChatQWEN()

    return prompt | llm
|
|
|
|
|
def query(question, db, chain):
    """Answer *question* end-to-end: retrieve, invoke the chain, and print.

    Args:
        question: The user question.
        db: Vector store passed through to :func:`query_db`.
        chain: Runnable accepting ``{"context", "question"}`` and returning
            a message object with a ``.content`` attribute.

    Returns:
        tuple: ``(answer, sources)`` as produced by the chain and retriever.
    """
    context, sources = query_db(db, question)
    print(f"Context:\n{context}\n*************************")

    response = chain.invoke({"context": context, "question": question})
    answer = response.content
    print(f"Answer:\n{answer}\n*************************")

    print(f"Sources:\n{sources}")
    return answer, sources
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test the pipeline end to end with a sample question.
    db = load_db()
    chain = load_chain()

    question = "Cor do cabelo de Van Helsing"

    # The previous version inlined a copy of query()'s body here (same
    # retrieval, same invoke, same prints); delegate to query() so the
    # logic lives in one place.
    answer, sources = query(question, db, chain)
|