# Hugging Face Spaces page header (scrape residue) -- Space status: Sleeping
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from fastapi.responses import JSONResponse | |
import torch | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
import os | |
import re | |
import logging | |
# FastAPI application instance for this service.
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the cache directory for Hugging Face: keep a value already present in
# the environment, otherwise fall back to /app/cache.
# NOTE(review): this runs after `transformers` has been imported at the top
# of the file -- confirm the library still honours the variable when it is
# set this late (passing `cache_dir=` to `from_pretrained` is the explicit
# alternative).
os.environ.setdefault('TRANSFORMERS_CACHE', '/app/cache')

# Load model and tokenizer once, at import time, so request handlers can
# reuse them. A failure here aborts start-up instead of leaving the module
# half-initialised.
model_name = "BIJOY087/Bangla_barta_shurkha_mobilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    # Log with the traceback and chain the original error so the root cause
    # stays visible in the RuntimeError raised to the caller.
    logger.exception("Failed to load model or tokenizer: %s", e)
    raise RuntimeError(f"Failed to load model or tokenizer: {e}") from e
else:
    logger.info("Model and tokenizer loaded successfully")
class TextRequest(BaseModel):
    """Request body carrying a single text."""

    # The text supplied by the client.
    text: str
class BatchTextRequest(BaseModel):
    """Request body carrying a list of texts to classify in one call."""

    # Texts to classify; results are reported in this order.
    texts: list[str]
# Regular expression to detect Bangla characters: matches any single code
# point in the range U+0980..U+09FF.
bangla_regex = re.compile('[\u0980-\u09FF]')


def contains_bangla(text: str) -> bool:
    """Return True if *text* has at least one character in U+0980..U+09FF.

    Args:
        text: The string to inspect.

    Returns:
        True when the pattern matches anywhere in *text*, False otherwise
        (including for the empty string).
    """
    return bangla_regex.search(text) is not None
async def batch_predict(request: BatchTextRequest):
    """Classify every text in the request as "Spam", "Ham" or "other".

    Texts without any Bangla character are labelled "other" without running
    the model. Every other text is tokenized (truncated and padded to 64
    tokens) and passed through the classifier; predicted class index 1 maps
    to "Spam" and any other index maps to "Ham".

    Args:
        request: Request body holding the list of texts to classify.

    Returns:
        A JSONResponse whose body is ``{"results": [...]}`` with one entry
        per input text, in input order. Each entry has a 1-based ``id``, the
        original ``text`` and the ``prediction`` label.

    Raises:
        HTTPException: With status 500 if anything fails during prediction.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed source -- confirm it is registered on `app`.
    # NOTE(review): inference runs synchronously inside this coroutine (there
    # is no await), so other requests wait while a batch is processed.
    try:
        model.eval()

        # Prepare the batch results
        results = []
        for idx, text in enumerate(request.texts, start=1):
            # Texts without Bangla characters bypass the model entirely.
            if not contains_bangla(text):
                results.append({"id": idx, "text": text, "prediction": "other"})
                continue

            # Encode and predict for texts containing Bangla characters.
            # Calling the tokenizer directly replaces the deprecated
            # `encode_plus` with the same arguments.
            inputs = tokenizer(
                text,
                add_special_tokens=True,
                max_length=64,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt',
            )
            # No gradients are needed for inference.
            with torch.no_grad():
                logits = model(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                ).logits
            prediction = torch.argmax(logits, dim=1).item()
            label = "Spam" if prediction == 1 else "Ham"
            results.append({"id": idx, "text": text, "prediction": label})

        logger.info("Batch prediction results: %s", results)
        return JSONResponse(
            content={"results": results},
            media_type="application/json; charset=utf-8",
        )
    except Exception as e:
        # Keep the traceback in the server log; the client only receives a
        # generic message.
        logger.exception("Batch prediction failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail="Batch prediction failed. Please try again.",
        ) from e
async def root() -> dict[str, str]:
    """Return the static welcome message for the API.

    Returns:
        A dict with a single ``message`` key.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed source -- confirm it is registered on `app`.
    return {"message": "Welcome to the MobileBERT API"}