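"""Retail customer-support chatbot built on a small RAG pipeline: hybrid
retrieval (BM25 plus a FAISS index over transformer embeddings), cross-encoder
reranking, and answer generation with a Groq-hosted Llama 3 model, served
through a Gradio chat interface."""
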
import gradio as gr
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from litellm import completion
import os
import torch
from sentence_transformers import CrossEncoder

# Read the Groq API key from the environment; never commit a real key to the repo.
if not os.environ.get("GROQ_API_KEY"):
    raise RuntimeError("Set the GROQ_API_KEY environment variable before launching.")

PROMPT = """/
You are a virtual representative of a retail company and a consultant for customers.
To generate answers, use only information from the context!
Do not ask additional questions, but simply offer the product available in the context!
Your goal is to answer customers' questions, thus helping them.
You should advise the customer in choosing products using the context.
If you could not find a specific answer:
- Answer "I do not know. For more information, please contact: +380954673526" and nothing more.
You always maintain a polite, professional tone. The format of the answer should be simple, understandable and clear. Avoid long explanations if they are not necessary.

"""

# Bi-encoder used to embed documents and queries for dense retrieval.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
# Lightweight cross-encoder used to rerank the merged candidate set.
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6")

def load_documents(file_paths):
    """Read each file and return it as a single whole-document string."""
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            documents.append(file.read().strip())
    return documents

def load_documents_with_chunking(file_paths, chunk_size=500):
    """Split each file into fixed-size character chunks (no overlap)."""
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
            for i in range(0, len(text), chunk_size):
                documents.append(text[i:i + chunk_size])
    return documents
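
# A sketch of an overlapping variant (not used below). Sliding the window by
# chunk_size - overlap characters keeps text that would otherwise be cut at a
# chunk boundary intact in at least one chunk; `overlap` is an illustrative
# parameter, not part of the original script.
def load_documents_with_overlap(file_paths, chunk_size=500, overlap=100):
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
            for i in range(0, len(text), chunk_size - overlap):
                documents.append(text[i:i + chunk_size])
    return documents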

class Retriever:
    """Hybrid retriever: lexical BM25 plus dense FAISS search over embeddings."""

    def __init__(self, documents, tokenizer, model):
        self.documents = documents
        self.bm25 = BM25Okapi([doc.split() for doc in documents])
        self.tokenizer = tokenizer
        self.model = model
        self.index = self.create_faiss_index()

    def create_faiss_index(self):
        # Exact (flat) L2 index over the document embeddings.
        embeddings = self.embed_documents(self.documents)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index

    def embed_documents(self, docs):
        # Mean-pool the last hidden state over all token positions. Padding
        # tokens are averaged in too; masked pooling (sketched below) is closer
        # to how sentence-transformers pools this model.
        tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            embeddings = self.model(**tokens).last_hidden_state.mean(dim=1).numpy()
        return embeddings
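
    # For reference, a masked mean-pooling sketch (illustrative, not wired in):
    #
    #     out = self.model(**tokens)
    #     mask = tokens["attention_mask"].unsqueeze(-1).float()
    #     embeddings = ((out.last_hidden_state * mask).sum(1) / mask.sum(1)).numpy()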

    def search_bm25(self, query, top_k=5):
        # Lexical match: score whitespace-tokenized documents against the query terms.
        query_terms = query.split()
        scores = self.bm25.get_scores(query_terms)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in top_indices]

    def search_semantic(self, query, top_k=5):
        # Dense match: embed the query the same way as the documents, then k-NN search.
        query_embedding = self.embed_documents([query])
        distances, indices = self.index.search(query_embedding, top_k)
        return [self.documents[i] for i in indices[0]]

class Reranker:
    """Cross-encoder reranker: reads each (query, document) pair jointly."""

    def __init__(self, reranker):
        self.model = reranker

    def rank(self, query, documents):
        # Score every candidate against the query and return them best-first.
        pairs = [(query, doc) for doc in documents]
        scores = self.model.predict(pairs)
        return [documents[i] for i in np.argsort(scores)[::-1]]
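
# Design note: BM25 and the FAISS index are cheap recall stages; the
# cross-encoder is slower but more accurate, so it is applied only to the
# small merged candidate set.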

class QAChatbot:
    """Glues retrieval, reranking, and LLM generation together."""

    def __init__(self, indexer, reranker):
        self.indexer = indexer
        self.reranker = reranker

    def generate_answer(self, query):
        # Merge lexical and semantic candidates; set() deduplicates them
        # (the reranker restores a meaningful order afterwards).
        bm25_results = self.indexer.search_bm25(query)
        semantic_results = self.indexer.search_semantic(query)
        combined_results = list(set(bm25_results + semantic_results))

        ranked_docs = self.reranker.rank(query, combined_results)

        # Only the top-3 reranked chunks are passed to the LLM as context.
        context = "\n".join(ranked_docs[:3])
        response = completion(
            model="groq/llama3-8b-8192",
            messages=[
                {"role": "system", "content": PROMPT},
                {
                    "role": "user",
                    "content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
                },
            ],
        )
        return response


def chatbot_interface(query, history):
    # `chatbot` is the module-level QAChatbot built in the __main__ block;
    # rebuilding the indexes here would redo the embedding work on every message.
    answer = chatbot.generate_answer(query)
    return answer["choices"][0]["message"]["content"]

# type="messages" keeps the chat history as OpenAI-style role/content dicts.
iface = gr.ChatInterface(fn=chatbot_interface, type="messages")

if __name__ == "__main__":
    # Build the pipeline once at startup, then serve the chat UI.
    file_paths = ["Company_eng.txt", "base_eng.txt"]
    documents = load_documents(file_paths)

    indexer = Retriever(documents, tokenizer, model)
    reranker = Reranker(reranker_model)

    chatbot = QAChatbot(indexer, reranker)
    iface.launch()
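
# To try it locally (assuming this file is saved as app.py; the two .txt
# knowledge-base filenames are taken from the script and must sit next to it):
#   GROQ_API_KEY=... python app.py
# then open the local URL that Gradio prints.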