import logging
import re
import time
from pathlib import Path

import gradio as gr
import nltk
import transformers
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
from utils import load_examples, truncate_word_count
# Directory containing this file; presumably used further down the file to
# locate bundled resources (e.g. examples) -- not visible here, TODO confirm.
_here = Path(__file__).parent

# Fetch the NLTK "stopwords" corpus at import time (may require network access).
# TODO: find where this requirement originates from.
nltk.download("stopwords")

# Set transformers' logging verbosity to ERROR to reduce console noise.
transformers.logging.set_verbosity_error()

# Configure the root logger with a default stderr handler so that
# logging.warning(...) calls (e.g. input-truncation warnings) are emitted.
logging.basicConfig()
def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module
    Parameters
    ----------
    input_text : str, required, the text to be processed
    max_input_length : int, optional, the maximum length of the input text, default=512
    Returns
    -------
    str of HTML, the interactive HTML form for the model
    """
    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "min_length": 4,
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    st = time.perf_counter()
    history = {}
    clean_text = clean(input_text, lower=False)
    max_input_length = 1024 if model_size == "base" else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)
    if processed["was_truncated"]:
        tr_in = processed["truncated_text"]
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        tr_in = input_text
    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if model_size == "base" else model,
        tokenizer_sm if model_size == "base" else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    sum_text = [f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)]
    sum_scores = [
        f"\n - Section {i}: {round(s['summary_score'],4)}"
        for i, s in enumerate(_summaries)
    ]
    history["Summary Text"] = "
".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)
    history["Input"] = tr_in
    html = ""
    rt = round((time.perf_counter() - st) / 60, 2)
    print(f"Runtime: {rt} minutes")
    html += f"
Runtime: {rt} minutes on CPU
" for name, item in history.items(): html += ( f"