from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
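# NOTE: on LangChain >= 0.2 these classes moved out of the top-level
# `langchain` package, e.g. `from langchain_community.vectorstores import FAISS`
# and `from langchain_community.embeddings import HuggingFaceEmbeddings`.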

class Handler:
    def __init__(self):
        # Load the fine-tuned model and tokenizer
        print("Loading model and tokenizer...")
        self.model = AutoModelForCausalLM.from_pretrained(
            "anirudh248/upf_code_generator_final", device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained("anirudh248/upf_code_generator_final")

        # Load the FAISS index and embeddings
        print("Loading FAISS index and embeddings...")
        self.embeddings = HuggingFaceEmbeddings()
        self.vectorstore = FAISS.load_local(
            "faiss_index", self.embeddings, allow_dangerous_deserialization=True
        )

        # Create the Hugging Face pipeline for text generation. The model was
        # loaded with device_map="auto", so accelerate has already placed it on
        # the available device(s); passing an explicit `device` argument here
        # would raise a ValueError in transformers.
        print("Creating Hugging Face pipeline...")
        self.hf_pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            do_sample=True,  # required for temperature/top_p to take effect
            temperature=0.7,
            max_new_tokens=2048,
            top_p=0.95,
            repetition_penalty=1.15,
        )

        # Wrap the pipeline in LangChain
        self.llm = HuggingFacePipeline(pipeline=self.hf_pipeline)

        # Create the retriever and RetrievalQA chain
        self.retriever = self.vectorstore.as_retriever()
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=False,
        )

    def __call__(self, request):
        try:
            # Get the prompt from the request
            prompt = request.json.get("prompt")
            if not prompt:
                return {"error": "Prompt is required"}, 400

            # Generate UPF code using the QA chain
            response = self.qa_chain.run(prompt)

            # Return the response
            return {"response": response}
        except Exception as e:
            return {"error": str(e)}, 500
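

if __name__ == "__main__":
    # Minimal local smoke test: a sketch, assuming __call__ receives a
    # Flask-style request whose `.json` attribute is a dict. The
    # SimpleNamespace stand-in below is hypothetical and only mimics that
    # interface; in deployment the serving framework drives the handler.
    from types import SimpleNamespace

    handler = Handler()
    fake_request = SimpleNamespace(
        json={"prompt": "Generate UPF code for a design with two power domains."}
    )
    # Returns {"response": ...} on success, or ({"error": ...}, status) on failure
    result = handler(fake_request)
    print(result)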