roneliav commited on
Commit
e4d6f89
·
1 Parent(s): 9e944a8

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +19 -9
pipeline.py CHANGED
@@ -3,8 +3,9 @@ from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, Aut
3
 
4
  def get_markers_for_model():
5
  special_tokens_constants = Namespace()
6
- special_tokens_constants.separator_different_qa = "&&"
7
- special_tokens_constants.separator_output_question_answer = "? "
 
8
  return special_tokens_constants
9
 
10
  def load_trained_model(name_or_path):
@@ -21,10 +22,19 @@ class QADiscourse_Pipeline(Text2TextGenerationPipeline):
21
 
22
 
23
  def preprocess(self, inputs):
24
- # Here, inputs is string or list of strings; apply string postprocessing
25
- return super().preprocess(inputs)
 
 
 
 
 
 
26
 
27
-
 
 
 
28
 
29
  def _forward(self, *args, **kwargs):
30
  outputs = super()._forward(*args, **kwargs)
@@ -36,7 +46,9 @@ class QADiscourse_Pipeline(Text2TextGenerationPipeline):
36
  seperated_qas = self._split_to_list(predictions)
37
  qas = []
38
  for qa_pair in seperated_qas:
39
- qas.append(self._postrocess_qa(qa_pair))
 
 
40
  return qas
41
 
42
  def _split_to_list(self, output_seq: str) -> list:
@@ -48,7 +60,6 @@ class QADiscourse_Pipeline(Text2TextGenerationPipeline):
48
  if self.special_tokens.separator_output_question_answer in seq:
49
  question, answer = seq.split(self.special_tokens.separator_output_question_answer)
50
  else:
51
- print("invalid format: no separator between question and answer found...")
52
  return None
53
  return {"question": question, "answer": answer}
54
 
@@ -59,5 +70,4 @@ if __name__ == "__main__":
59
  res2 = pipe(["I don't like chocolate, but I like cookies.",
60
  "I dived in the sea easily"], num_beams=10)
61
  print(res1)
62
- print(res2)
63
-
 
3
 
4
  def get_markers_for_model():
5
  special_tokens_constants = Namespace()
6
+ special_tokens_constants.separator_different_qa = "&&&"
7
+ special_tokens_constants.separator_output_question_answer = "SSEEPP"
8
+ special_tokens_constants.source_prefix = "qa: "
9
  return special_tokens_constants
10
 
11
  def load_trained_model(name_or_path):
 
22
 
23
 
24
  def preprocess(self, inputs):
25
+ if isinstance(inputs, str):
26
+ processed_inputs = self._preprocess_string(inputs)
27
+ elif hasattr(inputs, "__iter__"):
28
+ processed_inputs = [self._preprocess_string(s) for s in inputs]
29
+ else:
30
+ raise ValueError("inputs must be str or Iterable[str]")
31
+ # Now pass to super.preprocess for tokenization
32
+ return super().preprocess(processed_inputs)
33
 
34
+ def _preprocess_string(self, seq: str) -> str:
35
+ seq = self.special_tokens.source_prefix + seq
36
+ print(seq)
37
+ return seq
38
 
39
  def _forward(self, *args, **kwargs):
40
  outputs = super()._forward(*args, **kwargs)
 
46
  seperated_qas = self._split_to_list(predictions)
47
  qas = []
48
  for qa_pair in seperated_qas:
49
+ post_process = self._postrocess_qa(qa_pair) # if the prediction isn't a valid QA
50
+ if post_process is not None:
51
+ qas.append(post_process)
52
  return qas
53
 
54
  def _split_to_list(self, output_seq: str) -> list:
 
60
  if self.special_tokens.separator_output_question_answer in seq:
61
  question, answer = seq.split(self.special_tokens.separator_output_question_answer)
62
  else:
 
63
  return None
64
  return {"question": question, "answer": answer}
65
 
 
70
  res2 = pipe(["I don't like chocolate, but I like cookies.",
71
  "I dived in the sea easily"], num_beams=10)
72
  print(res1)
73
+ print(res2)