from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from tqdm.auto import tqdm


def handle_long_text(
    input_text: str,
    model: AutoModelForSeq2SeqLM,
    tokenizer: AutoTokenizer,
    max_length: int = 128,
    stride: int = 128,
    batch_length: int = 2048,
    min_batch_length: int = 512,
    **generate_kwargs,
) -> str:
    """
    Maneja textos largos dividiéndolos en segmentos y generando resúmenes para cada uno.

    Args:
        input_text (str): Texto completo a resumir.
        model: Modelo de resumen abstractivo.
        tokenizer: Tokenizador asociado al modelo.
        max_length (int): Longitud máxima del resumen generado por segmento.
        stride (int): Cantidad de tokens que se superponen entre segmentos.
        batch_length (int): Longitud máxima de tokens por segmento.
        min_batch_length (int): Longitud mínima permitida por segmento.
        generate_kwargs: Parámetros adicionales para el modelo de generación.

    Returns:
        str: Resumen final concatenado de todos los segmentos.
    """
    # Validar parámetros de longitud
    if batch_length < min_batch_length:
        batch_length = min_batch_length

    # Tokenizar texto completo en segmentos
    encoded_input = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=batch_length,
        truncation=True,
        stride=stride,
        return_overflowing_tokens=True,
        add_special_tokens=True,
    )
    
    # Obtener IDs y máscaras de atención
    input_ids = encoded_input["input_ids"]
    attention_masks = encoded_input["attention_mask"]

    # Progresión para múltiples segmentos
    summaries = []
    pbar = tqdm(total=len(input_ids), desc="Procesando segmentos")

    for ids, mask in zip(input_ids, attention_masks):
        # Enviar al dispositivo correcto (CPU/GPU)
        ids = ids.unsqueeze(0).to(model.device)
        mask = mask.unsqueeze(0).to(model.device)

        # Generar resumen para el segmento actual
        outputs = model.generate(
            input_ids=ids,
            attention_mask=mask,
            max_length=max_length,
            no_repeat_ngram_size=3,
            num_beams=4,
            early_stopping=True,
            **generate_kwargs,
        )
        # Decodificar resumen generado
        summary = tokenizer.decode(
            outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        summaries.append(summary)
        pbar.update()

    pbar.close()

    # Concatenar resúmenes y devolver el texto final
    final_summary = " ".join(summaries)
    return final_summary