# LearningContent / app.py
# Author: abhiimanyu
# Commit b7bdedd (verified): "Update app.py"
import re
import json
from huggingface_hub import InferenceClient
import gradio as gr
# Single Hugging Face inference client, bound to the Mixtral-8x7B-Instruct
# model; the generation request in generate_learning_content is sent through it.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
# Function to format the input into a strict JSON-based prompt
def format_prompt(topic, description, difficulty):
prompt = (
f"You are an expert educator. Generate a structured, highly engaging, and educational JSON object on the topic '{topic}'. "
f"Use the following description as context: '{description}'. "
f"The content must be suitable for a '{difficulty}' difficulty level and strictly adhere to the following JSON structure:\n\n"
f"{{\n"
f" \"title\": \"[A descriptive and concise title for the topic]\",\n"
f" \"sections\": [\n"
f" {{\n"
f" \"subheading\": \"[A clear and concise subheading summarizing the section]\",\n"
f" \"content\": \"[A detailed, engaging explanation of the section content written in clear, accessible language.]\"\n"
f" }}\n"
f" ]\n"
f"}}\n\n"
f"### Strict Output Rules:\n"
f"1. The output **must be a valid JSON object** and nothing else.\n"
f"2. All keys and string values must be enclosed in double quotes (\"\").\n"
f"3. The `sections` field must be a non-empty list of objects, each containing `subheading` and `content`.\n"
f"4. Avoid extra characters, trailing commas, or malformed syntax.\n"
f"5. Close all brackets and braces properly.\n"
f"6. If there is insufficient information, return a JSON object with empty placeholders, e.g.,\n"
f"{{\n"
f" \"title\": \"\",\n"
f" \"sections\": []\n"
f"}}\n"
f"7. Validate the output to ensure it complies with the required JSON structure.\n"
)
return prompt
# Function to clean and format the AI output
def clean_and_format_learning_content(output):
"""
Cleans, validates, and repairs JSON output for learning content.
"""
try:
# Step 1: Clean raw output
cleaned_output = re.sub(r'[^\x00-\x7F]+', '', output) # Remove non-ASCII characters
cleaned_output = re.sub(r'`|<s>|</s>|◀|▶', '', cleaned_output) # Remove extraneous symbols
cleaned_output = re.sub(r'^[^{]*', '', cleaned_output) # Remove text before the first '{'
cleaned_output = re.sub(r'[^}]*$', '', cleaned_output) # Remove text after the last '}'
cleaned_output = re.sub(r'\s+', ' ', cleaned_output).strip() # Normalize whitespace
cleaned_output = cleaned_output.replace('\\"', '"') # Fix improperly escaped quotes
cleaned_output = re.sub(r',\s*(\}|\])', r'\1', cleaned_output) # Remove trailing commas
# Step 2: Fix invalid 'sections' fields
# Replace invalid sections (e.g., sections:) with an empty array
if re.search(r'"sections":\s*,', cleaned_output):
cleaned_output = re.sub(r'"sections":\s*,', '"sections": []', cleaned_output)
# Fix unbalanced brackets or braces
open_braces = cleaned_output.count('{')
close_braces = cleaned_output.count('}')
open_brackets = cleaned_output.count('[')
close_brackets = cleaned_output.count(']')
if open_braces > close_braces:
cleaned_output += '}' * (open_braces - close_braces)
if open_brackets > close_brackets:
cleaned_output += ']' * (open_brackets - close_brackets)
# Fix commas between objects in arrays
cleaned_output = re.sub(r'(\})(\s*{)', r'\1,\2', cleaned_output)
# Step 3: Attempt to parse JSON
json_output = json.loads(cleaned_output)
# Step 4: Validate JSON structure
required_keys = ["title", "sections"]
if "title" not in json_output or "sections" not in json_output:
raise ValueError("Missing required keys: 'title' or 'sections'.")
if not isinstance(json_output["sections"], list):
# If 'sections' is not a list, replace it with an empty list
json_output["sections"] = []
else:
for section in json_output["sections"]:
if "subheading" not in section or "content" not in section:
raise ValueError("Each section must contain 'subheading' and 'content'.")
return json_output
except (json.JSONDecodeError, ValueError) as e:
# Provide detailed error information for debugging
return {
"error": "Failed to parse or validate output as JSON",
"details": str(e),
"output": cleaned_output
}
# Function to generate learning content
def generate_learning_content(topic, description, difficulty, temperature=0.9, max_new_tokens=2000, top_p=0.95, repetition_penalty=1.2):
"""
Generates learning content and validates the output.
"""
temperature = max(float(temperature), 1e-2) # Ensure minimum temperature
top_p = float(top_p)
generate_kwargs = dict(
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=True,
seed=42,
)
# Format the prompt
formatted_prompt = format_prompt(topic, description, difficulty)
# Stream the output from the model
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
raw_output = ""
for response in stream:
raw_output += response.token.text
# Clean and validate the raw output
return clean_and_format_learning_content(raw_output)
# Define the Gradio interface
with gr.Blocks(theme="ocean") as demo:
gr.HTML("<h1><center>Learning Content Generator</center></h1>")
# Input fields for topic, description, and difficulty
topic_input = gr.Textbox(label="Topic", placeholder="Enter the topic for learning content.")
description_input = gr.Textbox(label="Description", placeholder="Enter a brief description of the topic.")
difficulty_input = gr.Dropdown(
label="Difficulty Level",
choices=["High", "Medium", "Low"],
value="Medium",
interactive=True
)
# Sliders for model parameters
temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Temperature")
tokens_slider = gr.Slider(minimum=128, maximum=1048, step=64, value=512, label="Max new tokens")
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top-p (nucleus sampling)")
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty")
# Output field for generated learning content
output = gr.Textbox(label="Generated Learning Content", lines=15)
# Button to generate content
submit_button = gr.Button("Generate Learning Content")
# Define the click event to call the generate function
submit_button.click(
fn=generate_learning_content,
inputs=[topic_input, description_input, difficulty_input, temperature_slider, tokens_slider, top_p_slider, repetition_penalty_slider],
outputs=output,
)
# Launch the app
if __name__ == "__main__":
demo.launch()