import torch
import gradio as gr
from googletrans import Translator
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
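
# NOTE: besides torch, gradio, and transformers, this script assumes the
# googletrans package (commonly pinned to the 4.0.0rc1 release) and
# sentencepiece for the T5 tokenizers; the exact versions are an assumption,
# nothing here pins them.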

# End-to-end question generation: the model takes a passage and emits all of
# its questions in a single pass, separated by "<sep>" tokens.
class E2EQGPipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Keep the model on the same device the inputs are sent to.
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.model_type = "t5"
        self.kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    def generate_questions(self, context: str):
        inputs = self._prepare_inputs_for_e2e_qg(context)
        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **self.kwargs,
        )
        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        # The output ends with a trailing "<sep>", so the last split element is
        # empty and gets dropped.
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions

    def _prepare_inputs_for_e2e_qg(self, context):
        source_text = f"generate questions: {context}"
        inputs = self._tokenize([source_text], padding=False)
        return inputs

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512,
    ):
        # Padding is controlled solely through the "padding" argument; the old
        # "pad_to_max_length" flag is deprecated and redundant with it.
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt",
        )
        return inputs

qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')
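
# Illustrative usage of the pipeline (not executed by the app; the sample
# sentence is an assumption for demonstration):
#   qg = E2EQGPipeline(qg_model, qg_tokenizer)
#   qg.generate_questions("The Eiffel Tower was completed in 1889 in Paris.")
#   # -> a list of English questions about the sentence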

def generate_questions(text):
    qg_final_model = E2EQGPipeline(qg_model, qg_tokenizer)
    questions = qg_final_model.generate_questions(text)
    # The QG model outputs English questions; translate them to Spanish for the UI.
    translator = Translator()
    translated_questions = [translator.translate(question, dest='es').text for question in questions]
    return translated_questions

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

def generate_summary(text):
    # BART was fine-tuned without a task prefix ("summarize: " is a T5
    # convention), so the text is encoded as-is.
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
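
# Illustrative call (sample text is an assumption; returns a 50-150 token summary):
#   generate_summary("<a few paragraphs of article text>")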

# QA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt = 'mrm8488/spanish-t5-small-sqac-for-qa'
qa_tokenizer = AutoTokenizer.from_pretrained(ckpt)
qa_model = T5ForConditionalGeneration.from_pretrained(ckpt).to(device)

def generate_question_response(question, context):
    # The SQAC-tuned model expects a "question: ... context: ..." prompt.
    input_text = f'question: {question} context: {context}'
    features = qa_tokenizer([input_text], padding='max_length', truncation=True, max_length=512, return_tensors='pt')
    output = qa_model.generate(
        input_ids=features['input_ids'].to(device),
        attention_mask=features['attention_mask'].to(device),
    )
    return qa_tokenizer.decode(output[0], skip_special_tokens=True)
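
# Illustrative call (the sample strings are assumptions; the model answers in Spanish):
#   generate_question_response('¿Dónde vive Ana?', 'Ana vive en Madrid desde 2010.')
#   # -> e.g. 'en Madrid'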

# Shared state across the Gradio tabs: the text pasted into the summarizer tab
# is stored so the question-generation and chatbot tabs can reuse it.
class SummarizerAndQA:
    def __init__(self):
        self.input_text = ''
        self.question = ''
        self.summary = ''
        self.study_generated_questions = ''
        self.question_response = ''

    def is_text_loaded(self):
        return self.input_text != ''

    def process_summarizer(self, text):
        self.input_text = text
        return generate_summary(text)

    def process_questions(self):
        return generate_questions(self.input_text)

    def process_question_response(self, question, history):
        # "history" is required by the gr.ChatInterface signature but unused:
        # every answer is computed from the stored text alone.
        return generate_question_response(question, self.input_text)

summarizer_and_qa = SummarizerAndQA()

textbox_input = gr.Textbox(label="Pega el texto acá:", placeholder="Texto...", lines=15)  # "Paste the text here"
summary_output = gr.Textbox(label="Resumen", lines=15)  # "Summary"
questions_output = gr.Textbox(label="Preguntas de guía generadas", lines=5)  # "Generated study questions"
# "interactive=" is evaluated once at build time, before any text is loaded, so
# gating the button on is_text_loaded() would leave it disabled permanently.
questions_generate_button = gr.Button("Generate", variant="primary")

summarizer_interface = gr.Interface(fn=summarizer_and_qa.process_summarizer, inputs=[textbox_input], outputs=[summary_output], allow_flagging="never")
questions_interface = gr.Interface(fn=summarizer_and_qa.process_questions, inputs=[], outputs=[questions_output], allow_flagging="never", submit_btn=questions_generate_button)
chatbot_interface = gr.ChatInterface(fn=summarizer_and_qa.process_question_response, type="messages", examples=[], title="Preguntas sobre el texto")
gr.TabbedInterface([summarizer_interface, questions_interface, chatbot_interface], ["Resumidor", "Preguntas de guía", "Chatbot"]).launch()