# chat
from QWEN import ChatQWEN
from langchain_core.prompts import ChatPromptTemplate

# vector store and embeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
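# Note: "QWEN" above is a project-local module (not shown here); ChatQWEN is
# assumed to be a LangChain-compatible chat model wrapping Qwen. Any other
# BaseChatModel could stand in for it, e.g. a locally served model via Ollama:
#
#     from langchain_ollama import ChatOllama
#     llm = ChatOllama(model="qwen2.5")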


def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
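    """Open the persisted Chroma vector store, embedding queries with a multilingual model."""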
    # set up the embedding model used for both indexing and querying
    embeddings = HuggingFaceEmbeddings(
        model_name=MODEL_NAME,
        model_kwargs={
            "device": "cuda",  # requires a CUDA GPU; change to "cpu" if unavailable
            "trust_remote_code": True,  # gte-multilingual-base ships custom model code
        },
        encode_kwargs={"normalize_embeddings": True},
    )

    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
    return db
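
# Note: load_db assumes the Chroma directory was already populated by a separate
# ingestion step (not part of this file), e.g. something along the lines of:
#
#     Chroma.from_documents(chunks, embeddings, persist_directory="chromadb/")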


def query_db(db, query_text, k=3):
    """Return the concatenated context and source list for the k best matches."""
    # search the DB for the k most relevant chunks
    results = db.similarity_search_with_relevance_scores(query_text, k=k)

    # concatenate the retrieved chunks into a single context block
    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
    sources = "\n".join(doc.metadata["source"] for doc, _score in results)

    return context_text, sources


def load_chain():
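    """Build the prompt -> model pipeline (an LCEL chain)."""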
    # chat prompt: a fixed system message plus a human message with placeholders
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
            ),
            (
                "human",
                """Answer the question based only on the following context:

{context}

---

Answer the question based on the above context, in the question's original language: {question}""",
            ),
        ]
    )

    # instantiate the chat model
    llm = ChatQWEN()

    # LCEL pipeline: the prompt renders messages that are passed to the model
    chain = prompt | llm

    return chain


def query(question, db, chain):
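    """Retrieve context for the question, run the chain, and print/return the results."""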
    context, sources = query_db(db, question)

    print(f"Context:\n{context}\n*************************")

    # run the chain on the retrieved context
    answer = chain.invoke(
        {
            "context": context,
            "question": question,
        }
    ).content
    print(f"Answer:\n{answer}\n*************************")

    print(f"Sources:\n{sources}")

    return answer, sources


if __name__ == "__main__":
    # load the vector store and build the chain
    db = load_db()
    chain = load_chain()

    # sample question in Portuguese: "Van Helsing's hair color"
    question = "Cor do cabelo de Van Helsing"

    answer, sources = query(question, db, chain)