import streamlit as st
import requests
import clipboard
import random
import clipboard
import json
import itertools
import pickle
from huggingface_hub import InferenceClient

# Import NLTK and download its 'punkt' data; nltk.word_tokenize is used in
# process_element to tokenize both strings for the BLEU calculation.
# NOTE(review): newer NLTK releases may look for 'punkt_tab' rather than
# 'punkt' -- confirm against the installed version.
import nltk
from nltk.translate.bleu_score import sentence_bleu
nltk.download('punkt')

# Import the ROUGE metric implementation and create one module-level scorer
# that process_element reuses for every comparison.
from rouge import Rouge
rouge = Rouge()

from datasets import load_dataset
# Open the "CC-MAIN-2024-10" configuration of FineWeb's train split in
# streaming mode; the samples shown in the UI are later taken from this
# stream with itertools.islice.
# Use name="sample-10BT" to use the 10BT sample.
fw = load_dataset("HuggingFaceFW/fineweb", name="CC-MAIN-2024-10", split="train", streaming=True)

# Define helper functions for character-level accuracy and precision.
def char_accuracy(true_output, model_output):
    """Return the share of positions at which the two strings agree.

    Characters are compared index by index over the overlapping prefix.
    The denominator is the length of the longer string, so every surplus
    character on either side counts as a miss. Two empty inputs are
    treated as a perfect match (1.0).
    """
    longest = max(len(true_output), len(model_output))
    if longest <= 0:
        return 1.0

    agreeing = 0
    for expected_char, produced_char in zip(true_output, model_output):
        if expected_char == produced_char:
            agreeing += 1
    return agreeing / longest

def char_precision(true_output, model_output):
    """Return position-wise matches divided by the model output's length.

    Only the overlapping prefix of the two strings is compared. An empty
    model output yields 0.0 instead of dividing by zero.
    """
    produced_length = len(model_output)
    if produced_length == 0:
        return 0.0

    agreeing = 0
    for expected_char, produced_char in zip(true_output, model_output):
        if expected_char == produced_char:
            agreeing += 1
    return agreeing / produced_length


# Seed the per-session defaults: the position of the sample currently being
# shown and the initial text of the input box. Keys that already exist are
# left untouched so their values survive script reruns.
for _state_key, _initial_value in (('current_index', 0), ('user_input', '')):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

def _extract_back_payload(model_output):
    """Return the text enclosed by the first <back>...</back> pair.

    A missing opening tag means the payload starts at the beginning of the
    string; a missing closing tag means it runs to the end. The closing tag
    is only searched for after the opening tag, so a stray "</back>" that
    appears earlier in the text is ignored.
    """
    open_tag, close_tag = "<back>", "</back>"
    start = model_output.find(open_tag)
    # str.find returns -1 when the tag is absent; blindly adding the tag
    # length to that would start the slice at an arbitrary offset.
    start = 0 if start == -1 else start + len(open_tag)
    end = model_output.find(close_tag, start)
    if end == -1:
        end = len(model_output)
    return model_output[start:end]


def process_element(true_output, model_output):
    """Score one model answer against the expected reversed text.

    The answer is first reduced to the content of its <back>...</back>
    tags (falling back to the untagged text when a tag is missing). BLEU,
    ROUGE, character accuracy and character precision are then computed,
    printed, and appended to the module-level lists ``bleu``, ``rouges``,
    ``acc`` and ``pres``.

    Args:
        true_output: The reference string (the correctly reversed input).
        model_output: The raw answer, expected to contain <back> tags.

    Returns:
        None. Results are recorded via the module-level lists.
    """
    model_output = _extract_back_payload(model_output)
    print("Model output:", model_output)

    # Tokenize both outputs for BLEU calculation.
    reference_tokens = nltk.word_tokenize(true_output)
    candidate_tokens = nltk.word_tokenize(model_output)

    # Compute BLEU score (using the single reference).
    bleu_score = sentence_bleu([reference_tokens], candidate_tokens)
    print("BLEU score:", bleu_score)

    # Compute ROUGE scores. The scorer raises ValueError when the hypothesis
    # or reference has no content (e.g. an empty <back></back> answer);
    # record an all-zero entry in the scorer's result layout instead of
    # aborting the whole run.
    try:
        rouge_scores = rouge.get_scores(model_output, true_output)
    except ValueError:
        rouge_scores = [{
            metric: {"r": 0.0, "p": 0.0, "f": 0.0}
            for metric in ("rouge-1", "rouge-2", "rouge-l")
        }]
    print("ROUGE scores:", rouge_scores)

    # Compute character-level accuracy and precision.
    accuracy_metric = char_accuracy(true_output, model_output)
    precision_metric = char_precision(true_output, model_output)
    print("Character Accuracy:", accuracy_metric)
    print("Character Precision:", precision_metric)
    print("-" * 80)

    # The accumulators are only mutated, never rebound, so no `global`
    # declaration is required.
    acc.append(accuracy_metric)
    pres.append(precision_metric)
    bleu.append(bleu_score)
    rouges.append(rouge_scores)
# Take the first 100 documents from the stream (change the count if you need
# more or fewer). Streamlit re-executes this script on every interaction, so
# the samples are kept in session state instead of being re-fetched each time.
if 'samples' not in st.session_state:
    st.session_state.samples = list(itertools.islice(fw, 100))
samples = st.session_state.samples

# NOTE(review): word_threshold is never read anywhere in this script; it
# looks like it was meant to cap the number of words per sample -- confirm
# the intent before wiring it in or removing it.
word_threshold = 100

# Metric accumulators. They live in session state because plain module-level
# lists would be reset to [] on every rerun, leaving nothing to pickle once
# all samples are done. The module-level names alias the persistent lists so
# process_element can keep appending to `acc`, `pres`, `bleu` and `rouges`.
for _metric_key in ('acc_scores', 'pres_scores', 'bleu_scores', 'rouge_scores'):
    if _metric_key not in st.session_state:
        st.session_state[_metric_key] = []
acc = st.session_state['acc_scores']
pres = st.session_state['pres_scores']
bleu = st.session_state['bleu_scores']
rouges = st.session_state['rouge_scores']

if st.session_state.current_index < len(samples):
    # Build the prompt for the current sample: trim surrounding whitespace
    # and drop newlines so the text is a single line.
    current_element = samples[st.session_state.current_index]["text"]
    current_element = current_element.strip().replace("\n", "")
    prompt = "You are a helpful assistant that echoes the user's input, but backwards, do not simply rearrange the words, reverse the user's input down to the character (e.g. reverse Hello World to dlroW olleH). Surround the backwards version of the user's input with <back> </back> tags: "+current_element
    st.code(prompt)

    # The expected answer is the character-level reversal of the sample.
    true_output = current_element[::-1]
    print("True output:", true_output)

    # The widget key includes the sample index: a keyed widget keeps its own
    # value across reruns, so a fixed key would leave the previous answer in
    # the box. A new key per sample gives a fresh, empty field.
    user_input = st.text_input(
        "Enter your input:",
        value=st.session_state.user_input,
        key=f'input_field_{st.session_state.current_index}'
    )

    # Score the answer only when the user explicitly submits it.
    if st.button("Submit"):
        if user_input:
            # Process the current element and user input
            process_element(true_output, user_input)
            # Reset the default text used for the next input field
            st.session_state.user_input = ''
            # Move to the next element
            st.session_state.current_index += 1
            # Rerun the app to update the state
            st.rerun()
        else:
            st.warning("Please enter your input before submitting.")
else:
    st.success("All elements have been processed!")
    # Persist each metric list to its own pickle file.
    with open('accuracy.pkl', 'wb') as file:
        pickle.dump(acc, file)
    with open('precision.pkl', 'wb') as file:
        pickle.dump(pres, file)
    with open('bleu.pkl', 'wb') as file:
        pickle.dump(bleu, file)
    with open('rouge.pkl', 'wb') as file:
        pickle.dump(rouges, file)