import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class ModelHandler:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self):
        """Load the model and tokenizer from the directory containing this script."""
        if self.initialized:
            return

        try:
            model_path = os.path.dirname(os.path.abspath(__file__))
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                # float16 is only reliable on GPU; fall back to float32 on CPU.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model.eval()
            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {e}") from e

    def predict(self, input_data):
        """
        Process the input data and generate an answer from the model.

        Args:
            input_data (dict): Must contain the question under the 'question' key.

        Returns:
            dict: {"answer": ...} on success, {"error": ...} on failure.
        """
        if not self.initialized:
            self.initialize()

        try:
            question = input_data.get("question", "")
            if not question:
                return {"error": "No question provided."}

            # Arabic QA prompt: "السؤال" means "Question", "الإجابة" means "Answer".
            formatted_prompt = f"السؤال: {question}\nالإجابة:"

            inputs = self.tokenizer([formatted_prompt], return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    # do_sample=True is required for temperature/top_k/top_p to
                    # take effect; otherwise generate() decodes greedily and
                    # silently ignores these sampling parameters.
                    do_sample=True,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    use_cache=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens so the prompt is not
            # echoed back as part of the answer.
            generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            answer = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            # Strip any prompt labels the model may have repeated.
            clean_output = answer.replace("السؤال:", "").replace("الإجابة:", "").strip()

            # Release cached GPU memory between requests.
            if self.device == "cuda":
                torch.cuda.empty_cache()

            return {"answer": clean_output}

        except Exception as e:
            return {"error": f"Prediction error: {e}"}


# Module-level handler instance shared across invocations.
handler = ModelHandler()


def predict(input_data):
    """Wrapper around the handler's predict method."""
    return handler.predict(input_data)
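

# A minimal usage sketch, assuming the fine-tuned weights and tokenizer files
# sit alongside this script; the example question is hypothetical and asks
# "What is the capital of Egypt?" in Arabic.
if __name__ == "__main__":
    print(predict({"question": "ما هي عاصمة مصر؟"}))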