|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
import torch |
|
|
|
class ModelHandler:
    """Wraps a Hugging Face seq2seq model for positive-tone headline rewriting.

    Loads the ``asritha22bce/bart-positive-tone`` checkpoint (downloaded from
    the Hugging Face Hub on first use) and exposes a simple
    preprocess -> inference -> postprocess pipeline.
    """

    def __init__(self):
        # NOTE: loading triggers a network download if the checkpoint is not
        # cached locally.
        self.model_path = "asritha22bce/bart-positive-tone"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        # Fix: put the model in evaluation mode so dropout (and similar
        # train-only layers) are disabled during generation. Without this,
        # outputs can be non-deterministic and lower quality.
        self.model.eval()

    def preprocess(self, text):
        """Tokenize *text* into PyTorch tensors ready for generation.

        Args:
            text: a string (or list of strings) to encode.

        Returns:
            A ``BatchEncoding`` of input tensors (``input_ids``,
            ``attention_mask``), truncated/padded to the model's limits.
        """
        return self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    def inference(self, inputs, max_length=50):
        """Generate text from tokenized *inputs*.

        Args:
            inputs: tokenizer output (tensors) from :meth:`preprocess`.
            max_length: maximum length of the generated sequence
                (default 50, matching the original hard-coded value).

        Returns:
            The decoded string for the first generated sequence.
        """
        # no_grad avoids building an autograd graph during generation,
        # saving memory and time.
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, max_length=max_length)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

    def postprocess(self, output):
        """Wrap the generated string in the API response schema."""
        return {"positive_headline": output}
|
|
|
handler = ModelHandler() |
|
|
|
def handle_request(text):
    """Run the full positive-tone rewriting pipeline on *text*.

    Tokenizes the input, generates with the shared module-level handler,
    and returns the response dict ``{"positive_headline": ...}``.
    """
    return handler.postprocess(
        handler.inference(
            handler.preprocess(text)
        )
    )
|
|