roneliav committed
Commit 0030578 · 1 Parent(s): 269c1e7

Upload pipeline.py

Files changed (1)
  1. pipeline.py +73 -0
pipeline.py ADDED
@@ -0,0 +1,73 @@
+ from argparse import Namespace
+ from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
+
+ def get_markers_for_model():
+     # Marker strings the model was trained with: one separates distinct QA
+     # pairs, the other separates a question from its answer.
+     special_tokens_constants = Namespace()
+     special_tokens_constants.separator_different_qa = "&&&"
+     special_tokens_constants.separator_output_question_answer = "SSEEPP"
+     special_tokens_constants.source_prefix = "qa: "
+     return special_tokens_constants
+
+ def load_trained_model(name_or_path):
+     tokenizer = AutoTokenizer.from_pretrained(name_or_path)
+     model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
+     return model, tokenizer
+
+
+ class QADiscourse_Pipeline(Text2TextGenerationPipeline):
+     def __init__(self, model_repo: str, **kwargs):
+         model, tokenizer = load_trained_model(model_repo)
+         # Forward remaining kwargs instead of silently dropping them
+         super().__init__(model, tokenizer, framework="pt", **kwargs)
+         self.special_tokens = get_markers_for_model()
+
+     def preprocess(self, inputs):
+         if isinstance(inputs, str):
+             processed_inputs = self._preprocess_string(inputs)
+         elif hasattr(inputs, "__iter__"):
+             processed_inputs = [self._preprocess_string(s) for s in inputs]
+         else:
+             raise ValueError("inputs must be str or Iterable[str]")
+         # Now pass to super().preprocess for tokenization
+         return super().preprocess(processed_inputs)
+
+     def _preprocess_string(self, seq: str) -> str:
+         # Prepend the task prefix the model was trained with
+         return self.special_tokens.source_prefix + seq
+
+     def _forward(self, *args, **kwargs):
+         outputs = super()._forward(*args, **kwargs)
+         return outputs
+
+     def postprocess(self, model_outputs):
+         predictions = self.tokenizer.decode(model_outputs["output_ids"].squeeze(), skip_special_tokens=True, clean_up_tokenization_spaces=False)
+         separated_qas = self._split_to_list(predictions)
+         qas = []
+         for qa_pair in separated_qas:
+             post_process = self._postprocess_qa(qa_pair)
+             if post_process is not None:  # skip predictions that aren't a valid QA
+                 qas.append(post_process)
+         return qas
+
+     def _split_to_list(self, output_seq: str) -> list:
+         # One generated sequence may hold several QA pairs
+         return output_seq.split(self.special_tokens.separator_different_qa)
+
+     def _postprocess_qa(self, seq: str) -> dict:
+         # Split into question and answer; discard predictions without the separator
+         if self.special_tokens.separator_output_question_answer in seq:
+             question, answer = seq.split(self.special_tokens.separator_output_question_answer, 1)
+         else:
+             return None
+         return {"question": question, "answer": answer}
+
+
+ if __name__ == "__main__":
+     pipe = QADiscourse_Pipeline("RonEliav/QA_discourse")
+     res1 = pipe("I don't like chocolate, but I like cookies.")
+     res2 = pipe(["I don't like chocolate, but I like cookies.",
+                  "I dived in the sea easily"], num_beams=10)
+     print(res1)
+     print(res2)
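
For reference, a minimal sketch of how the uploaded pipeline.py could be consumed from another project: download it from the Hub and import it dynamically. The repo id RonEliav/QA_discourse comes from the __main__ block above; the hf_hub_download/importlib loading pattern itself is an assumption for illustration, not part of this commit.

# Sketch: fetch the committed pipeline.py from the Hub and use it.
# Assumes huggingface_hub is installed; this loader is illustrative only.
import importlib.util
from huggingface_hub import hf_hub_download

# Download the file added by this commit (repo id taken from __main__ above)
path = hf_hub_download(repo_id="RonEliav/QA_discourse", filename="pipeline.py")

# Import the downloaded file as a module
spec = importlib.util.spec_from_file_location("qa_pipeline", path)
qa_pipeline = importlib.util.module_from_spec(spec)
spec.loader.exec_module(qa_pipeline)

# Each input yields a list of {"question": ..., "answer": ...} dicts
pipe = qa_pipeline.QADiscourse_Pipeline("RonEliav/QA_discourse")
print(pipe("I don't like chocolate, but I like cookies."))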