File size: 2,482 Bytes
314bc09 5e41d4e d3a1fe2 5e41d4e 4ff8f6d 5e41d4e 4ff8f6d 5e41d4e 4ff8f6d 5e41d4e d3a1fe2 5e41d4e 4ff8f6d 5e41d4e d3a1fe2 2566a62 4ff8f6d 5e41d4e 2566a62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
# chat
from QWEN import ChatQWEN
from langchain_core.prompts import ChatPromptTemplate
# db related
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
def load_db(
    CHROMA_PATH="chromadb/",
    MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
    device="cuda",
):
    """Load a persisted Chroma vector store backed by a HuggingFace embedding model.

    Args:
        CHROMA_PATH: Directory where the Chroma collection is persisted.
        MODEL_NAME: HuggingFace model id used to embed queries/documents.
        device: Torch device for the embedding model. Defaults to "cuda"
            (previously hard-coded); pass "cpu" on machines without a GPU.

    Returns:
        A Chroma instance ready for similarity search.
    """
    # Normalized embeddings so cosine/relevance scores are well-behaved.
    embeddings = HuggingFaceEmbeddings(
        model_name=MODEL_NAME,
        model_kwargs={
            "device": device,
            # gte-multilingual-base ships custom code; must opt in to run it.
            "trust_remote_code": True,
        },
        encode_kwargs={"normalize_embeddings": True},
    )
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
    return db
def query_db(db, query_text, k=3):
    """Retrieve the top-k documents for *query_text* and pack them for prompting.

    Args:
        db: A vector store exposing ``similarity_search_with_relevance_scores``.
        query_text: Natural-language query.
        k: Number of documents to retrieve. Defaults to 3 (previously hard-coded).

    Returns:
        Tuple ``(context_text, sources)`` where ``context_text`` joins the
        retrieved page contents with ``---`` separators and ``sources`` joins
        each document's ``metadata["source"]`` with newlines.
    """
    # Search the DB; relevance scores are returned but not used here.
    results = db.similarity_search_with_relevance_scores(query_text, k=k)
    # Gather the retrieved passages into one prompt-ready context string.
    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
    sources = "\n".join(doc.metadata["source"] for doc, _score in results)
    return context_text, sources
def load_chain():
    """Build the prompt -> LLM pipeline used to answer questions.

    Returns:
        A runnable chain that accepts ``{"context": ..., "question": ...}``
        and produces the model's chat response.
    """
    # System persona plus a human turn that injects retrieved context.
    system_message = (
        "system",
        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    )
    human_message = (
        "human",
        """Answer the question based only on the following context:
{context}
---
Answer the question based on the above context in question's original language: {question}""",
    )
    chat_prompt = ChatPromptTemplate([system_message, human_message])
    # Compose the prompt with the Qwen chat model into a single runnable.
    model = ChatQWEN()
    return chat_prompt | model
def query(question, db, chain):
    """Answer *question* using context retrieved from *db* via *chain*.

    Retrieves context, prints it, invokes the chain, prints the answer and
    the source list, and returns ``(answer, sources)``.
    """
    context, sources = query_db(db, question)
    print(f"Context:\n{context}\n*************************")
    # Ask the model with the retrieved context spliced into the prompt.
    response = chain.invoke({"context": context, "question": question})
    answer = response.content
    print(f"Answer:\n{answer}\n*************************")
    print(f"Sources:\n{sources}")
    return answer, sources
if __name__ == "__main__":
    # Wire up the retrieval database and the prompt->LLM chain, then run one
    # sample query. Previously this block duplicated the retrieve/invoke/print
    # logic inline; `query` already implements exactly that flow.
    db = load_db()
    chain = load_chain()
    question = "Cor do cabelo de Van Helsing"
    answer, sources = query(question, db, chain)
|