# llama_QA / handler.py
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer


class ModelHandler:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self):
        """Initialize the model and tokenizer."""
        if self.initialized:
            return
        try:
            # Load the model and tokenizer from the directory containing this file
            model_path = os.path.dirname(os.path.abspath(__file__))
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.float16,  # float16 halves memory use and suits T4-class GPUs
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {e}") from e

    def predict(self, input_data):
        """
        Process the input data and generate an answer from the model.

        Args:
            input_data (dict): A dict with the user's question under the "question" key.

        Returns:
            dict: {"answer": <generated text>} on success, or {"error": <message>} on failure.
        """
        if not self.initialized:
            self.initialize()
        try:
            # Extract the question from input_data
            question = input_data.get("question", "")
            if not question:
                return {"error": "No question provided."}
            # Build the prompt around the user's question
            # (the Arabic labels "السؤال:" / "الإجابة:" mean "Question:" / "Answer:")
            alpaca_prompt = f"""
السؤال: {question}
الإجابة:
"""
            formatted_prompt = alpaca_prompt.strip()
            # Tokenize the input and move the tensors to the model's device
            inputs = self.tokenizer([formatted_prompt], return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Generate the answer; do_sample=True is required for temperature/top_k/top_p to apply
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=True,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    use_cache=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            # Decode the output (it contains the prompt followed by the generated answer)
            decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # Keep only the text after the "الإجابة:" (Answer:) label
            clean_output = decoded_output[0].split("الإجابة:", 1)[-1].strip()
            # Clear the CUDA cache if using GPU to release memory between requests
            if self.device == "cuda":
                torch.cuda.empty_cache()
            return {"answer": clean_output}
        except Exception as e:
            return {"error": f"Prediction error: {e}"}


# Create a global handler instance
handler = ModelHandler()


def predict(input_data):
    """Wrapper function for the handler's predict method."""
    return handler.predict(input_data)
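

# Minimal local usage sketch, assuming this file sits alongside the model weights and
# tokenizer files so that from_pretrained() can load them from the same directory.
# The question text below is a hypothetical example input.
if __name__ == "__main__":
    sample_input = {"question": "ما هو تعلم الآلة؟"}  # "What is machine learning?"
    result = predict(sample_input)
    print(result.get("answer", result.get("error")))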