File size: 2,906 Bytes
30aebe0 e4d6f89 30aebe0 e4d6f89 30aebe0 e4d6f89 30aebe0 e4d6f89 30aebe0 e4d6f89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from argparse import Namespace
from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
def get_markers_for_model():
    """Return the special marker strings this model family was trained with.

    The namespace carries:
      - separator_different_qa: delimiter between distinct QA pairs ("&&&")
      - separator_output_question_answer: delimiter between a question and
        its answer inside one pair ("SSEEPP")
      - source_prefix: task prefix prepended to every input ("qa: ")
    """
    return Namespace(
        separator_different_qa="&&&",
        separator_output_question_answer="SSEEPP",
        source_prefix="qa: ",
    )
def load_trained_model(name_or_path):
    """Load a seq2seq model and its matching tokenizer.

    ``name_or_path`` may be a Hugging Face hub id or a local checkpoint
    directory; returns the ``(model, tokenizer)`` pair.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
    return loaded_model, loaded_tokenizer
class QADiscourse_Pipeline(Text2TextGenerationPipeline):
    """Text2text pipeline producing discourse question-answer pairs.

    Each input sentence is prefixed with the task marker (``"qa: "``),
    generated text is decoded, split into candidate QA pairs on the
    ``"&&&"`` marker, and each candidate is split into question/answer
    on the ``"SSEEPP"`` marker. The pipeline output is a list of
    ``{"question": ..., "answer": ...}`` dicts per input.
    """

    def __init__(self, model_repo: str, **kwargs):
        """Load model and tokenizer from *model_repo* and set up markers."""
        model, tokenizer = load_trained_model(model_repo)
        super().__init__(model, tokenizer, framework="pt")
        self.special_tokens = get_markers_for_model()

    def preprocess(self, inputs):
        """Prefix input string(s) with the source marker, then tokenize.

        Accepts a single string or any iterable of strings; raises
        ``ValueError`` for anything else.
        """
        if isinstance(inputs, str):
            processed_inputs = self._preprocess_string(inputs)
        elif hasattr(inputs, "__iter__"):
            processed_inputs = [self._preprocess_string(s) for s in inputs]
        else:
            raise ValueError("inputs must be str or Iterable[str]")
        # Delegate tokenization to the parent pipeline.
        return super().preprocess(processed_inputs)

    def _preprocess_string(self, seq: str) -> str:
        # Prepend the task prefix the model was trained with ("qa: ").
        # (Removed a leftover debug print that echoed every input.)
        return self.special_tokens.source_prefix + seq

    def _forward(self, *args, **kwargs):
        outputs = super()._forward(*args, **kwargs)
        return outputs

    def postprocess(self, model_outputs):
        """Decode generated ids into a list of {"question", "answer"} dicts.

        Candidates lacking the question/answer separator are dropped.
        """
        predictions = self.tokenizer.decode(
            model_outputs["output_ids"].squeeze(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        qas = []
        for qa_pair in self._split_to_list(predictions):
            parsed = self._postprocess_qa(qa_pair)
            if parsed is not None:  # skip predictions that aren't a valid QA
                qas.append(parsed)
        return qas

    def _split_to_list(self, output_seq: str) -> list:
        """Split a decoded sequence into individual QA candidate strings."""
        return output_seq.split(self.special_tokens.separator_different_qa)

    def _postprocess_qa(self, seq: str):
        """Split one candidate into question and answer.

        Returns ``{"question": ..., "answer": ...}``, or ``None`` when the
        separator is absent (i.e. the prediction is not a valid QA).
        """
        sep = self.special_tokens.separator_output_question_answer
        if sep not in seq:
            return None
        # maxsplit=1: a stray separator inside the answer text must not
        # make the two-name unpacking raise ValueError.
        question, answer = seq.split(sep, 1)
        return {"question": question, "answer": answer}
if __name__ == "__main__":
pipe = QADiscourse_Pipeline("RonEliav/QA_discourse")
res1 = pipe("I don't like chocolate, but I like cookies.")
res2 = pipe(["I don't like chocolate, but I like cookies.",
"I dived in the sea easily"], num_beams=10)
print(res1)
print(res2) |