# upf_code_generator_final / retrieval_qa_pipeline.py

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# NOTE: the three integration imports below assume a pre-0.2 LangChain
# release; on newer versions the same classes live in langchain_community
# (or the langchain_huggingface package for the Hugging Face integrations).
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline

from langchain.chains import RetrievalQA
from datasets import load_dataset


def load_model_and_tokenizer(model_name: str):
    """
    Load the pre-trained model and tokenizer from the Hugging Face Hub.

    Args:
        model_name (str): The Hugging Face repository name of the model.

    Returns:
        model: The loaded model.
        tokenizer: The loaded tokenizer.
    """
    print(f"Loading model and tokenizer from {model_name}...")
    # device_map="auto" lets accelerate place layers on the available GPU(s)/CPU.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

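# If the model does not fit in GPU memory at full precision, it can be
# loaded in half precision instead. Sketch only (torch is not imported
# above); adjust to your hardware.
# import torch
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, device_map="auto", torch_dtype=torch.float16
# )
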
def load_dataset_from_hf(dataset_name: str):
    """
    Load the dataset from the Hugging Face Hub.

    Args:
        dataset_name (str): The Hugging Face repository name of the dataset.

    Returns:
        texts (list): The text descriptions from the dataset.
        metadatas (list): Metadata for each text (e.g., upf_code).
    """
    print(f"Loading dataset from {dataset_name}...")
    dataset = load_dataset(dataset_name)
    texts = dataset["train"]["power_intent_description"]
    # Attach the corresponding UPF code to each description so the
    # retriever can surface it alongside the matched text.
    metadatas = [{"upf_code": code} for code in dataset["train"]["upf_code"]]
    return texts, metadatas

def load_faiss_index(faiss_index_path: str):
    """
    Load the FAISS index and associated embeddings.

    Args:
        faiss_index_path (str): Path to the saved FAISS index.

    Returns:
        vectorstore (FAISS): The FAISS vector store.
    """
    print(f"Loading FAISS index from {faiss_index_path}...")
    embeddings = HuggingFaceEmbeddings()  # Default model: sentence-transformers/all-mpnet-base-v2
    # allow_dangerous_deserialization opts in to unpickling the stored index;
    # only enable it for index files you created or trust.
    vectorstore = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
    return vectorstore

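# The script expects a prebuilt index at faiss_index_path. If you don't have
# one yet, something like the sketch below can create it from the texts and
# metadatas returned by load_dataset_from_hf. This helper is not part of the
# original pipeline; the function name is illustrative, and the embedding
# model must match the one used in load_faiss_index.
def build_faiss_index(texts, metadatas, faiss_index_path: str):
    """Build a FAISS index from the dataset texts and persist it to disk."""
    embeddings = HuggingFaceEmbeddings()
    vectorstore = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
    vectorstore.save_local(faiss_index_path)
    return vectorstore
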
def build_retrieval_qa_pipeline(model, tokenizer, vectorstore):
    """
    Build the retrieval-based QA pipeline.

    Args:
        model: The pre-trained model.
        tokenizer: The tokenizer associated with the model.
        vectorstore (FAISS): The FAISS vector store for retrieval.

    Returns:
        qa_chain (RetrievalQA): The retrieval-based QA pipeline.
    """
    print("Building the retrieval-based QA pipeline...")
    hf_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # max_length counts prompt + generation; raise it (or switch to
        # max_new_tokens) if the retrieved context is long.
        max_length=2048,
        # do_sample=True is required for temperature/top_p to take effect;
        # without it, transformers ignores them and decodes greedily.
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    retriever = vectorstore.as_retriever()
    # Defaults to the "stuff" chain type: retrieved documents are inserted
    # into a single prompt ahead of the question.
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    return qa_chain

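# RetrievalQA also accepts a custom prompt through chain_type_kwargs, which
# can help steer the model toward emitting UPF code. A minimal sketch,
# assuming the default "stuff" chain; the template wording is illustrative,
# not from the original script.
# from langchain.prompts import PromptTemplate
#
# upf_prompt = PromptTemplate(
#     input_variables=["context", "question"],
#     template=(
#         "Use the following power-intent examples to answer.\n"
#         "{context}\n\nQuestion: {question}\nUPF code:"
#     ),
# )
# qa_chain = RetrievalQA.from_chain_type(
#     llm=llm, retriever=retriever, chain_type_kwargs={"prompt": upf_prompt}
# )
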
def main():
    # Replace these names with your model and dataset repo names
    model_name = "anirudh248/upf_code_generator_final"
    dataset_name = "PranavKeshav/upf_dataset"
    faiss_index_path = "faiss_index"

    print("Starting pipeline setup...")

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_name)

    # Load dataset. texts/metadatas are only needed if the FAISS index has
    # to be (re)built (see the build_faiss_index sketch above); the QA chain
    # itself reads from the prebuilt index loaded below.
    texts, metadatas = load_dataset_from_hf(dataset_name)

    # Load FAISS index
    vectorstore = load_faiss_index(faiss_index_path)

    # Build QA pipeline
    qa_chain = build_retrieval_qa_pipeline(model, tokenizer, vectorstore)

    # Interactive query loop
    print("Pipeline is ready! You can now ask questions.")
    while True:
        query = input("Enter your query (or type 'exit' to quit): ").strip()
        if query.lower() == "exit":
            print("Exiting...")
            break
        # .run() is deprecated on newer LangChain releases in favor of .invoke()
        response = qa_chain.run(query)
        print(f"Response: {response}")


if __name__ == "__main__":
    main()