import os
os.system("pip install -q gradio torch transformers")

import gradio as gr
import torch
import random
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('RandomNameAnd6/DharGPT-Small')

# Read real titles from file, one title per line.
# Each line is stripped and blank lines are dropped. Without this, an empty string
# could be shown as the "real" option or compared against generated titles.
# The encoding is explicit so the file is not decoded with a platform-dependent
# default.
with open('dhar_mann_titles.txt', 'r', encoding='utf-8') as file:
    dhar_mann_titles = [line.strip() for line in file if line.strip()]

def levenshtein_distance(s1, s2):
    """
    Return the minimum number of single-character edits turning one string into the other.

    Allowed edits are insertion, deletion and substitution, each costing 1. Uses the
    classic dynamic-programming recurrence but keeps only one previous row in memory.

    Parameters:
    - s1 (str): First string.
    - s2 (str): Second string.

    Returns:
    - int: The edit (Levenshtein) distance between s1 and s2.
    """
    # Iterate over the longer string so each DP row is sized by the shorter one.
    if len(s1) >= len(s2):
        longer, shorter = s1, s2
    else:
        longer, shorter = s2, s1

    if not shorter:
        return len(longer)

    prev = list(range(len(shorter) + 1))
    for row_idx, ch_long in enumerate(longer, start=1):
        curr = [row_idx]
        for col_idx, ch_short in enumerate(shorter, start=1):
            cost = 0 if ch_long == ch_short else 1
            curr.append(min(
                prev[col_idx] + 1,             # insertion
                curr[col_idx - 1] + 1,         # deletion
                prev[col_idx - 1] + cost,      # substitution (or match)
            ))
        prev = curr

    return prev[-1]

def string_similarity_index(original_text, comparison_text, threshold=0.75):
    """
    Decide whether two strings are similar, based on normalised Levenshtein distance.

    The similarity score is ``1 - distance / max(len(original_text), len(comparison_text))``.
    It is 1.0 for identical strings and 0.0 when no character can be kept.

    Parameters:
    - original_text (str): The original text.
    - comparison_text (str): The text to compare for similarity.
    - threshold (float): Minimum similarity score (0 to 1) at which the comparison
      text is treated as non-original.

    Returns:
    - bool: True if the similarity score is at or above the threshold, False otherwise.
      Two empty strings are identical and get a score of 1.0.
    """
    # The maximum possible distance is the length of the longer string.
    max_distance = max(len(original_text), len(comparison_text))

    if max_distance == 0:
        # Both strings are empty: they are identical. Handled explicitly because
        # dividing by max_distance would raise ZeroDivisionError.
        similarity_score = 1.0
    else:
        distance = levenshtein_distance(original_text, comparison_text)
        similarity_score = 1 - distance / max_distance

    return similarity_score >= threshold

def clean_title(input_string):
    """
    Normalise a title before similarity comparison.

    Removes one trailing " | Dhar Mann" or " | Dhar Mann Studios" suffix (whichever
    matches), then keeps only the text before the first comma, if there is one.

    Parameters:
    - input_string (str): The raw title.

    Returns:
    - str: The shortened title.
    """
    for suffix in (" | Dhar Mann", " | Dhar Mann Studios"):
        if input_string.endswith(suffix):
            input_string = input_string[:-len(suffix)]
            break  # at most one suffix is removed

    # partition() returns the whole string as the head when no comma is present.
    head, _sep, _rest = input_string.partition(",")
    return head

def generate_ai_title():
    """
    Sample a title from the model that is not too close to any real title.

    Sampling repeats until the result is non-empty and ``string_similarity_index``
    finds no real title in ``dhar_mann_titles`` to be similar. Both sides are passed
    through ``clean_title`` before comparing.

    Returns:
    - str: The generated title with surrounding whitespace stripped.
    """
    while True:
        inputs = tokenizer(["<|startoftext|>"], return_tensors="pt")
        outputs = model.generate(
            **inputs,
            max_new_tokens=48,
            use_cache=True,
            temperature=0.85,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens and drop special tokens
        # (<|endoftext|>). The previous fixed slice [15:-13] assumed the output
        # always ended in "<|endoftext|>". When generation stopped at
        # max_new_tokens instead, it cut off 13 real characters.
        prompt_length = inputs["input_ids"].shape[1]
        generated_title = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        if not generated_title:
            # An empty sample is useless as a quiz option; sample again.
            continue

        # The candidate does not change inside the loop below; clean it once.
        cleaned_generated = clean_title(generated_title)
        for title in dhar_mann_titles:
            title = title.strip()  # Remove any extra whitespace characters like newlines
            if string_similarity_index(cleaned_generated, clean_title(title)):
                print(f"Regenerating! Generated title was: \"{generated_title}\", and the real title was \"{title}\"")
                break
        else:
            # No real title was too similar.
            return generated_title

def check_answer(user_choice, real_index, score):
    """
    Score one guess and choose which follow-up button to show.

    Parameters:
    - user_choice: The selected radio label ("Option 1" or "Option 2").
    - real_index: Slot holding the real title (0 for Option 1, 1 for Option 2).
    - score: The score before this guess.

    Returns:
    - tuple: (message, new score, Continue button update, Restart button update).
      A correct guess adds 1 and shows Continue; any other guess resets the
      score to 0 and shows Restart.
    """
    winning_pairs = (("Option 1", 0), ("Option 2", 1))
    if (user_choice, real_index) not in winning_pairs:
        score = 0
        return f"Incorrect. Your score has been reset to: {score}", score, gr.update(visible=False), gr.update(visible=True)

    score += 1
    return f"Correct! Your current score is: {score}", score, gr.update(visible=True), gr.update(visible=False)

def update_options():
    """
    Deal a new round: one real title and one generated title in random order.

    Returns:
    - tuple: (text for Option 1, text for Option 2, real_index), where real_index
      is 0 when the real title is Option 1 and 1 when it is Option 2.
    """
    real_index = random.choice([0, 1])
    real_title = random.choice(dhar_mann_titles).strip()
    ai_title = generate_ai_title()

    # Put the real title in the slot named by real_index.
    first, second = (real_title, ai_title) if real_index == 0 else (ai_title, real_title)
    return first, second, real_index

def create_interface():
    """
    Build the "Real or AI" quiz UI.

    Each round shows one real title and one generated title; the player picks the
    one they think is real. A correct pick increments the score and reveals
    "Continue". A wrong pick resets the score to 0 and reveals "Restart".

    Returns:
    - gr.Blocks: The assembled (not yet launched) Gradio app.
    """
    with gr.Blocks() as demo:
        # Per-session state: the running score, and which slot (0 or 1) holds the
        # real title.
        score = gr.State(0)
        real_index_state = gr.State(0)

        score_display = gr.Markdown("## Real or AI - Dhar Mann\n**Current Score: 0**")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Option 1")
                option1_box = gr.Markdown("")
            with gr.Column():
                gr.Markdown("### Option 2")
                option2_box = gr.Markdown("")

        with gr.Row():
            choice = gr.Radio(["Option 1", "Option 2"], label="Which one do you think is real?")

        submit_button = gr.Button("Submit")
        result_text = gr.Markdown("")
        continue_button = gr.Button("Continue", visible=False)
        restart_button = gr.Button("Restart", visible=False)

        def on_submit(user_choice, real_index, score):
            if not user_choice:
                # Nothing selected: ask for a choice instead of treating it as a
                # wrong answer and wiping the score.
                return "Please select an option before submitting.", score, gr.update(), gr.update(), gr.update()
            result, new_score, continue_visibility, restart_visibility = check_answer(user_choice, real_index, score)
            # Hide Submit until the next round. Otherwise the same correct answer
            # could be submitted repeatedly to inflate the score.
            return result, new_score, continue_visibility, restart_visibility, gr.update(visible=False)

        def on_continue(score):
            option1, option2, real_index = update_options()
            new_score_display = f"## Real or AI - Dhar Mann\n**Current Score: {score}**"
            return (
                option1,
                option2,
                real_index,
                new_score_display,
                gr.update(value=None),      # clear the radio selection
                "",                         # clear the previous result message
                gr.update(visible=False),   # Continue
                gr.update(visible=False),   # Restart
                gr.update(visible=True),    # Submit is available again
            )

        def on_restart():
            return on_continue(0)

        round_outputs = [
            option1_box, option2_box, real_index_state, score_display, choice,
            result_text, continue_button, restart_button, submit_button,
        ]

        submit_button.click(
            on_submit,
            inputs=[choice, real_index_state, score],
            outputs=[result_text, score, continue_button, restart_button, submit_button],
        )
        continue_button.click(on_continue, inputs=score, outputs=round_outputs)
        restart_button.click(on_restart, outputs=round_outputs)

        # Deal a fresh first round on every page load. Previously the first pair
        # was generated once while building the UI, so every visitor saw the same
        # opening question until the server restarted.
        demo.load(on_restart, outputs=round_outputs)

    return demo

# Build the app at module level so `demo` stays importable.
demo = create_interface()

# Only start the server when this file is run as a script (e.g. `python app.py`).
# Importing the module no longer has the side effect of launching a server.
if __name__ == "__main__":
    demo.launch()