from transformers import AutoModelForSeq2SeqLM, AutoTokenizer import torch from tqdm.auto import tqdm def handle_long_text( input_text: str, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, max_length: int = 128, stride: int = 128, batch_length: int = 2048, min_batch_length: int = 512, **generate_kwargs, ) -> str: """ Maneja textos largos dividiéndolos en segmentos y generando resúmenes para cada uno. Args: input_text (str): Texto completo a resumir. model: Modelo de resumen abstractivo. tokenizer: Tokenizador asociado al modelo. max_length (int): Longitud máxima del resumen generado por segmento. stride (int): Cantidad de tokens que se superponen entre segmentos. batch_length (int): Longitud máxima de tokens por segmento. min_batch_length (int): Longitud mínima permitida por segmento. generate_kwargs: Parámetros adicionales para el modelo de generación. Returns: str: Resumen final concatenado de todos los segmentos. """ # Validar parámetros de longitud if batch_length < min_batch_length: batch_length = min_batch_length # Tokenizar texto completo en segmentos encoded_input = tokenizer( input_text, return_tensors="pt", max_length=batch_length, truncation=True, stride=stride, return_overflowing_tokens=True, add_special_tokens=True, ) # Obtener IDs y máscaras de atención input_ids = encoded_input["input_ids"] attention_masks = encoded_input["attention_mask"] # Progresión para múltiples segmentos summaries = [] pbar = tqdm(total=len(input_ids), desc="Procesando segmentos") for ids, mask in zip(input_ids, attention_masks): # Enviar al dispositivo correcto (CPU/GPU) ids = ids.unsqueeze(0).to(model.device) mask = mask.unsqueeze(0).to(model.device) # Generar resumen para el segmento actual outputs = model.generate( input_ids=ids, attention_mask=mask, max_length=max_length, no_repeat_ngram_size=3, num_beams=4, early_stopping=True, **generate_kwargs, ) # Decodificar resumen generado summary = tokenizer.decode( outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True ) summaries.append(summary) pbar.update() pbar.close() # Concatenar resúmenes y devolver el texto final final_summary = " ".join(summaries) return final_summary