Bijoy09 committed
Commit 66a2bf7 · verified · 1 Parent(s): 41ec2d9

Create app.py

Files changed (1)
app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
+
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from fastapi.responses import JSONResponse
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ import os
+ import re
+ import logging
+
+ app = FastAPI()
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Set the cache directory for Hugging Face
+ os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
+
+ # Load model and tokenizer
+ model_name = "BIJOY087/Bangla_barta_shurkha_mobilebert"
+ try:
+     model = AutoModelForSequenceClassification.from_pretrained(model_name)
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     logger.info("Model and tokenizer loaded successfully")
+ except Exception as e:
+     logger.error(f"Failed to load model or tokenizer: {e}")
+     raise RuntimeError(f"Failed to load model or tokenizer: {e}")
+
+ class TextRequest(BaseModel):
+     text: str
+
+ class BatchTextRequest(BaseModel):
+     texts: list[str]
+
+ # Regular expression to detect Bangla characters
+ bangla_regex = re.compile('[\u0980-\u09FF]')
+
+ def contains_bangla(text):
+     return bool(bangla_regex.search(text))
+
+ @app.post("/batch_predict/")
+ async def batch_predict(request: BatchTextRequest):
+     try:
+         model.eval()
+
+         # Prepare the batch results
+         results = []
+
+         for idx, text in enumerate(request.texts):
+
+             # Check if text contains Bangla characters
+             if not contains_bangla(text):
+                 results.append({"id": idx + 1, "text": text, "prediction": "other"})
+                 continue
+
+             # Encode and predict for texts containing Bangla characters
+             inputs = tokenizer.encode_plus(
+                 text,
+                 add_special_tokens=True,
+                 max_length=64,
+                 truncation=True,
+                 padding='max_length',
+                 return_attention_mask=True,
+                 return_tensors='pt'
+             )
+
+             with torch.no_grad():
+                 logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+             prediction = torch.argmax(logits, dim=1).item()
+             label = "Spam" if prediction == 1 else "Ham"
+             results.append({"id": idx + 1, "text": text, "prediction": label})
+
+         logger.info(f"Batch prediction results: {results}")
+         return JSONResponse(content={"results": results}, media_type="application/json; charset=utf-8")
+
+     except Exception as e:
+         logger.error(f"Batch prediction failed: {e}")
+         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
+
+ @app.get("/")
+ async def root():
+     return {"message": "Welcome to the MobileBERT API"}