File size: 931 Bytes
5964e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

class ModelHandler:
    """Wrap a seq2seq model behind a preprocess -> inference -> postprocess pipeline.

    The default checkpoint, ``asritha22bce/bart-positive-tone``, presumably
    rewrites input text in a positive tone (judging by its name and the
    ``positive_headline`` output key) -- confirm against the model card.
    """

    # Checkpoint loaded when no ``model_path`` is supplied.
    DEFAULT_MODEL_PATH = "asritha22bce/bart-positive-tone"

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """Load the model and its tokenizer.

        Args:
            model_path: Identifier or local path handed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        self.model_path = model_path
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    def preprocess(self, text):
        """Tokenize *text* into PyTorch tensors ready for ``generate``.

        Args:
            text: A single string or a list of strings.

        Returns:
            The tokenizer's encoding as PyTorch (``"pt"``) tensors. Over-long
            input is truncated, and a batch is padded to a common length.
        """
        return self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    def inference(self, inputs, max_length=50):
        """Generate and decode output text for already-tokenized *inputs*.

        Args:
            inputs: Mapping of keyword arguments for ``model.generate``,
                normally the value returned by ``preprocess``.
            max_length: ``max_length`` forwarded to ``generate``.

        Returns:
            The decoded string when the output holds a single sequence,
            otherwise a list with one decoded string per input row.
        """
        with torch.no_grad():  # generation only; no gradients are needed
            output_ids = self.model.generate(**inputs, max_length=max_length)
        # Decode every row so a batched call returns all results instead of
        # silently keeping only the first one.
        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        # A single input keeps the historical plain-string return value.
        return decoded[0] if len(decoded) == 1 else decoded

    def postprocess(self, output):
        """Wrap *output* in the response dict under the ``"positive_headline"`` key."""
        return {"positive_headline": output}

# Shared handler: built once at import time, so the model and tokenizer are
# loaded a single time and reused by every call to handle_request.
handler = ModelHandler()


def handle_request(text):
    """Run *text* through the shared handler's full pipeline.

    The text is tokenized, passed to generation, and the decoded result is
    wrapped by ``postprocess``; that wrapped value is returned.
    """
    encoded = handler.preprocess(text)
    return handler.postprocess(handler.inference(encoded))