import re
import json
from huggingface_hub import InferenceClient
import gradio as gr

# Shared Hugging Face inference client, bound to the Mixtral-8x7B-Instruct model.
# It is created once at import time and reused by generate_learning_content()
# for every request.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Function to format the input into a strict JSON-based prompt
def format_prompt(topic, description, difficulty):
    """Build the instruction prompt sent to the model.

    The prompt has three parts: an introduction that embeds the caller's
    values, a JSON skeleton showing the required structure, and a numbered
    list of output rules.

    Args:
        topic: Subject of the learning content; interpolated as-is.
        description: Extra context for the topic; interpolated as-is.
        difficulty: Difficulty label (e.g. "High", "Medium", "Low").

    Returns:
        The complete prompt as a single string ending with a newline.
    """
    introduction = "".join([
        f"You are an expert educator. Generate a structured, highly engaging, and educational JSON object on the topic '{topic}'. ",
        f"Use the following description as context: '{description}'. ",
        f"The content must be suitable for a '{difficulty}' difficulty level and strictly adhere to the following JSON structure:",
    ])

    # Skeleton of the JSON object the model is asked to fill in.
    schema_lines = [
        '{',
        '  "title": "[A descriptive and concise title for the topic]",',
        '  "sections": [',
        '    {',
        '      "subheading": "[A clear and concise subheading summarizing the section]",',
        '      "content": "[A detailed, engaging explanation of the section content written in clear, accessible language.]"',
        '    }',
        '  ]',
        '}',
    ]

    # Numbered rules; rule 6 embeds the empty-placeholder fallback object.
    rule_lines = [
        '### Strict Output Rules:',
        '1. The output **must be a valid JSON object** and nothing else.',
        '2. All keys and string values must be enclosed in double quotes ("").',
        '3. The `sections` field must be a non-empty list of objects, each containing `subheading` and `content`.',
        '4. Avoid extra characters, trailing commas, or malformed syntax.',
        '5. Close all brackets and braces properly.',
        '6. If there is insufficient information, return a JSON object with empty placeholders, e.g.,',
        '{',
        '  "title": "",',
        '  "sections": []',
        '}',
        '7. Validate the output to ensure it complies with the required JSON structure.',
    ]

    # Parts are separated by a blank line; the prompt ends with one newline.
    return "\n\n".join([
        introduction,
        "\n".join(schema_lines),
        "\n".join(rule_lines),
    ]) + "\n"




# Function to clean and format the AI output
def _extract_json_span(text):
    """Return the slice of *text* from its first '{' to its last '}', or '' if none."""
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end < start:
        return ''
    return text[start:end + 1]


def _close_unbalanced(text):
    """Append the closers missing from *text* in correct nesting order.

    Brackets and braces inside double-quoted strings are ignored, and the
    missing closers are emitted innermost-first so that a truncated
    ``{"a": [{"b": 1}`` becomes ``{"a": [{"b": 1}]}``.
    """
    pending = []  # closers still owed, outermost first
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == '{':
            pending.append('}')
        elif ch == '[':
            pending.append(']')
        elif ch in '}]' and pending and pending[-1] == ch:
            pending.pop()
    return text + ''.join(reversed(pending))


def _repair_json_text(text, unescape_quotes):
    """Apply best-effort textual repairs to model output that failed to parse.

    Args:
        text: Raw model output.
        unescape_quotes: When true, every ``\\"`` is turned into ``"``. This
            rescues output whose structural quotes were all escaped, but it
            corrupts legitimately escaped quotes, so it is tried last.

    Returns:
        The repaired candidate text (possibly still invalid JSON).
    """
    repaired = re.sub(r'[^\x00-\x7F]+', '', text)  # Remove non-ASCII characters
    repaired = re.sub(r'`|<s>|</s>', '', repaired)  # Remove code fences / special tokens
    repaired = _extract_json_span(repaired)  # Keep only '{' ... '}'
    repaired = re.sub(r'\s+', ' ', repaired).strip()  # Normalize whitespace
    if unescape_quotes:
        repaired = repaired.replace('\\"', '"')
    repaired = re.sub(r',\s*(\}|\])', r'\1', repaired)  # Remove trailing commas
    # A dangling '"sections": ,' becomes an empty array.
    repaired = re.sub(r'"sections":\s*,', '"sections": []', repaired)
    repaired = _close_unbalanced(repaired)
    # Insert the comma missing between adjacent objects: '} {' -> '}, {'.
    repaired = re.sub(r'(\})(\s*\{)', r'\1,\2', repaired)
    return repaired


def _candidate_texts(raw):
    """Yield parse candidates for *raw*, least invasive first."""
    yield _extract_json_span(raw)
    yield _repair_json_text(raw, unescape_quotes=False)
    yield _repair_json_text(raw, unescape_quotes=True)


def _validate_learning_content(data):
    """Check the parsed object's shape, raising ValueError on violations.

    A ``sections`` value that is not a list is replaced in place by ``[]``.
    """
    if not isinstance(data, dict):
        raise ValueError("Top-level JSON value must be an object.")
    if "title" not in data or "sections" not in data:
        raise ValueError("Missing required keys: 'title' or 'sections'.")
    if not isinstance(data["sections"], list):
        # If 'sections' is not a list, replace it with an empty list
        data["sections"] = []
        return
    for section in data["sections"]:
        # Require a real object: a bare string would otherwise be checked by
        # substring and a number would raise TypeError.
        if not isinstance(section, dict) or "subheading" not in section or "content" not in section:
            raise ValueError("Each section must contain 'subheading' and 'content'.")


def clean_and_format_learning_content(output):
    """Parse, repair and validate the model's JSON output for learning content.

    The text between the first '{' and the last '}' is parsed as-is first, so
    valid JSON (including escaped quotes and non-ASCII text) is returned
    untouched. Only if that fails are textual repairs applied: stripping
    non-ASCII characters and stray tokens, removing trailing commas, closing
    unbalanced brackets in nesting order, inserting missing commas between
    objects and, as a last resort, unescaping quotes.

    Args:
        output: Raw text produced by the model. ``None`` is treated as empty.

    Returns:
        The parsed dict with ``title`` and ``sections`` keys on success.
        Otherwise a dict with ``error``, ``details`` (the exception message)
        and ``output`` (the last candidate text that was tried).
    """
    raw = "" if output is None else str(output)
    cleaned_output = ""
    try:
        last_error = None
        for cleaned_output in _candidate_texts(raw):
            try:
                json_output = json.loads(cleaned_output)
                break
            except json.JSONDecodeError as exc:
                last_error = exc
        else:
            # Every candidate failed to parse; report the final failure.
            raise last_error

        _validate_learning_content(json_output)
        return json_output

    except (json.JSONDecodeError, ValueError) as e:
        # Provide detailed error information for debugging
        return {
            "error": "Failed to parse or validate output as JSON",
            "details": str(e),
            "output": cleaned_output
        }


# Function to generate learning content
def generate_learning_content(topic, description, difficulty, temperature=0.9, max_new_tokens=2000, top_p=0.95, repetition_penalty=1.2):
    """Request learning content from the model and return the cleaned result.

    Args:
        topic: Subject of the content, forwarded to format_prompt().
        description: Context for the topic, forwarded to format_prompt().
        difficulty: Difficulty label, forwarded to format_prompt().
        temperature: Sampling temperature; coerced to float and raised to at
            least 0.01.
        max_new_tokens: Generation length limit, passed through unchanged.
        top_p: Nucleus-sampling value; coerced to float.
        repetition_penalty: Passed through unchanged.

    Returns:
        Whatever clean_and_format_learning_content() returns for the
        concatenated streamed text.
    """
    sampling_options = {
        "temperature": max(float(temperature), 1e-2),  # floor keeps temperature above zero
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    # Stream the completion token by token for the formatted prompt.
    token_stream = client.text_generation(
        format_prompt(topic, description, difficulty),
        stream=True,
        details=True,
        return_full_text=False,
        **sampling_options,
    )

    # Stitch the streamed token texts back into one string.
    generated_text = "".join(chunk.token.text for chunk in token_stream)

    return clean_and_format_learning_content(generated_text)

# Define the Gradio interface.
# Components are created inside the Blocks context in the order written here:
# inputs, sampling sliders, the output box, and the submit button.
with gr.Blocks(theme="ocean") as demo:
    gr.HTML("<h1><center>Learning Content Generator</center></h1>")
    
    # Input fields for topic, description, and difficulty.
    # These three values are interpolated into the prompt by format_prompt().
    topic_input = gr.Textbox(label="Topic", placeholder="Enter the topic for learning content.")
    description_input = gr.Textbox(label="Description", placeholder="Enter a brief description of the topic.")
    difficulty_input = gr.Dropdown(
        label="Difficulty Level", 
        choices=["High", "Medium", "Low"], 
        value="Medium", 
        interactive=True
    )
    
    # Sliders for model parameters.
    # Temperature may be set to 0.0 here; generate_learning_content() raises it to 0.01.
    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Temperature")
    # The slider value (default 512) is always passed, so the handler's own
    # default of 2000 is not used from this UI.
    # NOTE(review): 1048 is not reachable in steps of 64 from 128 (1024 is) - possibly a typo; confirm.
    tokens_slider = gr.Slider(minimum=128, maximum=1048, step=64, value=512, label="Max new tokens")
    top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top-p (nucleus sampling)")
    repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty")

    # Output field for generated learning content.
    # NOTE(review): the handler returns a dict while this component is a Textbox;
    # presumably it is displayed via its string form - confirm, or serialize to JSON.
    output = gr.Textbox(label="Generated Learning Content", lines=15)

    # Button to generate content
    submit_button = gr.Button("Generate Learning Content")

    # Define the click event to call the generate function.
    # The inputs list is positional: it must match the parameter order of
    # generate_learning_content(topic, description, difficulty, temperature,
    # max_new_tokens, top_p, repetition_penalty).
    submit_button.click(
        fn=generate_learning_content,
        inputs=[topic_input, description_input, difficulty_input, temperature_slider, tokens_slider, top_p_slider, repetition_penalty_slider],
        outputs=output,
    )

# Launch the app only when this file is run as a script, so importing the
# module does not start the server.
if __name__ == "__main__":
    demo.launch()