PranavKeshav commited on
Commit
8cf5b0e
·
verified ·
1 Parent(s): f13119b

Update retrieval_qa_pipeline.py

Browse files
Files changed (1) hide show
  1. retrieval_qa_pipeline.py +118 -0
retrieval_qa_pipeline.py CHANGED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # retrieval_qa_pipeline.py
2
+
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.llms import HuggingFacePipeline
7
+ from langchain.chains import RetrievalQA
8
+ from datasets import load_dataset
9
+
10
+ def load_model_and_tokenizer(model_name: str):
11
+ """
12
+ Load the pre-trained model and tokenizer from the Hugging Face Hub.
13
+
14
+ Args:
15
+ model_name (str): The Hugging Face repository name of the model.
16
+
17
+ Returns:
18
+ model: The loaded model.
19
+ tokenizer: The loaded tokenizer.
20
+ """
21
+ print(f"Loading model and tokenizer from {model_name}...")
22
+ model = AutoModelForCausalLM.from_pretrained(model_name)
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
24
+ return model, tokenizer
25
+
26
+ def load_dataset_from_hf(dataset_name: str):
27
+ """
28
+ Load the dataset from the Hugging Face Hub.
29
+
30
+ Args:
31
+ dataset_name (str): The Hugging Face repository name of the dataset.
32
+
33
+ Returns:
34
+ texts (list): The text descriptions from the dataset.
35
+ metadatas (list): Metadata for each text (e.g., upf_code).
36
+ """
37
+ print(f"Loading dataset from {dataset_name}...")
38
+ dataset = load_dataset(dataset_name)
39
+ texts = dataset["train"]["power_intent_description"]
40
+ metadatas = [{"upf_code": code} for code in dataset["train"]["upf_code"]]
41
+ return texts, metadatas
42
+
43
+ def load_faiss_index(faiss_index_path: str):
44
+ """
45
+ Load the FAISS index and associated embeddings.
46
+
47
+ Args:
48
+ faiss_index_path (str): Path to the saved FAISS index.
49
+
50
+ Returns:
51
+ vectorstore (FAISS): The FAISS vector store.
52
+ """
53
+ print(f"Loading FAISS index from {faiss_index_path}...")
54
+ embeddings = HuggingFaceEmbeddings() # Default embeddings
55
+ vectorstore = FAISS.load_local(faiss_index_path, embeddings)
56
+ return vectorstore
57
+
58
+ def build_retrieval_qa_pipeline(model, tokenizer, vectorstore):
59
+ """
60
+ Build the retrieval-based QA pipeline.
61
+
62
+ Args:
63
+ model: The pre-trained model.
64
+ tokenizer: The tokenizer associated with the model.
65
+ vectorstore (FAISS): The FAISS vector store for retrieval.
66
+
67
+ Returns:
68
+ qa_chain (RetrievalQA): The retrieval-based QA pipeline.
69
+ """
70
+ print("Building the retrieval-based QA pipeline...")
71
+ hf_pipeline = pipeline(
72
+ "text-generation",
73
+ model=model,
74
+ tokenizer=tokenizer,
75
+ max_length=2048,
76
+ temperature=0.7,
77
+ top_p=0.95,
78
+ repetition_penalty=1.15
79
+ )
80
+
81
+ llm = HuggingFacePipeline(pipeline=hf_pipeline)
82
+ retriever = vectorstore.as_retriever()
83
+ qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
84
+
85
+ return qa_chain
86
+
87
+ def main():
88
+ # Replace these names with your model and dataset repo names
89
+ model_name = "username/my_fine_tuned_model"
90
+ dataset_name = "PranavKeshav/upf_code"
91
+ faiss_index_path = "faiss_index"
92
+
93
+ print("Starting pipeline setup...")
94
+
95
+ # Load model and tokenizer
96
+ model, tokenizer = load_model_and_tokenizer(model_name)
97
+
98
+ # Load dataset
99
+ texts, metadatas = load_dataset_from_hf(dataset_name)
100
+
101
+ # Load FAISS index
102
+ vectorstore = load_faiss_index(faiss_index_path)
103
+
104
+ # Build QA pipeline
105
+ qa_chain = build_retrieval_qa_pipeline(model, tokenizer, vectorstore)
106
+
107
+ # Test the pipeline
108
+ print("Pipeline is ready! You can now ask questions.")
109
+ while True:
110
+ query = input("Enter your query (or type 'exit' to quit): ")
111
+ if query.lower() == "exit":
112
+ print("Exiting...")
113
+ break
114
+ response = qa_chain.run(query)
115
+ print(f"Response: {response}")
116
+
117
+ if __name__ == "__main__":
118
+ main()