File size: 1,374 Bytes
50fee82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from transformers import T5Tokenizer, MT5ForConditionalGeneration

class T5_B(object):
    def __init__(self, model: str = "google/t5-large-ssm", device = 'cuda:0'):
        self.device = device
        self.tokenizer = T5Tokenizer.from_pretrained(model)
        if device == 'multigpu':
            self.model = MT5ForConditionalGeneration.from_pretrained(model).eval()
            self.model.parallelize()
        else:
            self.model = MT5ForConditionalGeneration.from_pretrained(model).to(device).eval()

    def predict(self, question: str):
        device = 'cuda:0' if self.device == 'multigpu' else self.device
        encode = self.tokenizer(question, return_tensors='pt').to(device)
        answer = self.model.generate(encode.input_ids)[0]
        decoded = self.tokenizer.decode(answer, skip_special_tokens=True)
        return decoded

    def predict_batch(self, question_list):
        assert type(question_list) == type([])
        device = 'cuda:0' if self.device == 'multigpu' else self.device
        encode = self.tokenizer(question_list, return_tensors='pt', padding = True).to(device)
        answer = self.model.generate(**encode)
        #return answer
        decoded = [self.tokenizer.decode(ans, skip_special_tokens=True) for ans in answer]
        #decoded = self.tokenizer.decode(answer, skip_special_tokens=True)
        return decoded