roneliav committed on
Commit
30aebe0
·
1 Parent(s): 0439c26

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +63 -0
pipeline.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
3
+
4
def get_markers_for_model():
    """Return the separator strings used in the model's generated text.

    Returns:
        A ``Namespace`` with two attributes:
        ``separator_different_qa`` ("&&"), the string on which the output is
        split into individual question-answer pairs, and
        ``separator_output_question_answer`` ("? "), the string on which a
        single pair is split into its question and its answer.
    """
    return Namespace(
        separator_different_qa="&&",
        separator_output_question_answer="? ",
    )
9
+
10
def load_trained_model(name_or_path):
    """Load a seq2seq model together with its tokenizer.

    Args:
        name_or_path: Identifier passed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(model, tokenizer)`` tuple -- note the model comes first.
    """
    # The tokenizer is loaded before the model, as in the original order.
    loaded_tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
    return loaded_model, loaded_tokenizer
14
+
15
+
16
class QADiscourse_Pipeline(Text2TextGenerationPipeline):
    """Text-to-text generation pipeline that returns question-answer pairs.

    The decoded model output is split on the "different QA" separator into
    individual pairs, and every pair is split on the question/answer
    separator into a ``{"question": ..., "answer": ...}`` dict.
    """

    def __init__(self, model_repo: str, **kwargs):
        """Load the model and tokenizer from *model_repo* and set up the pipeline.

        Args:
            model_repo: Name or path handed to ``load_trained_model``.
            **kwargs: Extra options forwarded to the parent pipeline
                constructor. They used to be accepted and silently dropped.
        """
        model, tokenizer = load_trained_model(model_repo)
        # The framework is always "pt" here; drop a caller-supplied value so
        # it cannot collide with the explicit keyword below.
        kwargs.pop("framework", None)
        super().__init__(model, tokenizer, framework="pt", **kwargs)
        self.special_tokens = get_markers_for_model()

    def preprocess(self, inputs, **kwargs):
        """Delegate input preparation to the parent pipeline.

        ``inputs`` is a string or a list of strings; this is the hook where
        string preprocessing could be applied before delegating.
        ``**kwargs`` carries any preprocess parameters the parent pipeline
        routes here, so passing them no longer raises ``TypeError``.
        """
        return super().preprocess(inputs, **kwargs)

    def _forward(self, *args, **kwargs):
        """Run the parent pipeline's forward step and return its output unchanged."""
        return super()._forward(*args, **kwargs)

    def postprocess(self, model_outputs, **kwargs):
        """Decode the generated ids and parse them into question-answer dicts.

        Args:
            model_outputs: Mapping whose ``"output_ids"`` entry holds the
                generated token ids.
            **kwargs: Postprocess parameters routed by the parent pipeline.
                They are accepted so such calls do not raise ``TypeError``,
                but they are not used: decoding options are fixed below.

        Returns:
            A list with one entry per separated segment of the decoded text:
            a ``{"question", "answer"}`` dict, or ``None`` for a segment that
            lacks the question/answer separator.
        """
        # NOTE(review): squeeze() assumes a single generated sequence per
        # input -- confirm before using num_return_sequences > 1.
        predictions = self.tokenizer.decode(
            model_outputs["output_ids"].squeeze(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return [
            self._postrocess_qa(qa_pair)
            for qa_pair in self._split_to_list(predictions)
        ]

    def _split_to_list(self, output_seq: str) -> list:
        """Split the decoded output into one string per question-answer pair."""
        return output_seq.split(self.special_tokens.separator_different_qa)

    def _postrocess_qa(self, seq: str) -> "dict | None":
        """Split one decoded pair into its question and its answer.

        The split happens at the FIRST occurrence of the question/answer
        separator, so an answer that itself contains the separator is kept
        intact instead of raising ``ValueError`` on unpacking.

        Returns:
            ``{"question": ..., "answer": ...}``, or ``None`` (after printing
            a message) when the separator is missing from *seq*.
        """
        separator = self.special_tokens.separator_output_question_answer
        question, found, answer = seq.partition(separator)
        if not found:
            print("invalid format: no separator between question and answer found...")
            return None
        return {"question": question, "answer": answer}
54
+
55
+
56
if __name__ == "__main__":
    # Manual smoke test: run the pipeline once on a single sentence and once
    # on a two-sentence batch with beam search, then print both results.
    demo_pipeline = QADiscourse_Pipeline("RonEliav/QA_discourse")
    single_result = demo_pipeline("I don't like chocolate, but I like cookies.")
    batch_sentences = [
        "I don't like chocolate, but I like cookies.",
        "I dived in the sea easily",
    ]
    batch_result = demo_pipeline(batch_sentences, num_beams=10)
    print(single_result)
    print(batch_result)
63
+