# Spaces:
# Sleeping
# Sleeping
import json
import re

import gradio as gr
from huggingface_hub import InferenceClient
# Initialize the Hugging Face Inference API client once at import time; the
# same client object is used by every generation request the app handles.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
client = InferenceClient(MODEL_ID)
# Function to format the input into a strict JSON-based prompt
def format_prompt(topic, description, difficulty):
    """Build the instruction prompt that asks the model for JSON learning content.

    The instruction is wrapped in Mixtral-Instruct's ``<s>[INST] ... [/INST]``
    markers, the format the model was fine-tuned on. Without them, the strict
    output rules below are followed much less reliably.

    Args:
        topic: Subject of the learning content; quoted verbatim into the prompt.
        description: Free-text context for the topic; quoted verbatim.
        difficulty: Difficulty label (the UI offers "High", "Medium", "Low").

    Returns:
        The complete prompt string to send to the text-generation endpoint.
    """
    # Only the first three pieces interpolate values; the JSON examples are plain
    # strings, so their braces need no ``{{``/``}}`` escaping.
    instruction = (
        f"You are an expert educator. Generate a structured, highly engaging, and educational JSON object on the topic '{topic}'. "
        f"Use the following description as context: '{description}'. "
        f"The content must be suitable for a '{difficulty}' difficulty level and strictly adhere to the following JSON structure:\n\n"
        "{\n"
        '  "title": "[A descriptive and concise title for the topic]",\n'
        '  "sections": [\n'
        "    {\n"
        '      "subheading": "[A clear and concise subheading summarizing the section]",\n'
        '      "content": "[A detailed, engaging explanation of the section content written in clear, accessible language.]"\n'
        "    }\n"
        "  ]\n"
        "}\n\n"
        "### Strict Output Rules:\n"
        "1. The output **must be a valid JSON object** and nothing else.\n"
        '2. All keys and string values must be enclosed in double quotes ("").\n'
        "3. The `sections` field must be a non-empty list of objects, each containing `subheading` and `content`.\n"
        "4. Avoid extra characters, trailing commas, or malformed syntax.\n"
        "5. Close all brackets and braces properly.\n"
        "6. If there is insufficient information, return a JSON object with empty placeholders, e.g.,\n"
        "{\n"
        '  "title": "",\n'
        '  "sections": []\n'
        "}\n"
        "7. Validate the output to ensure it complies with the required JSON structure."
    )
    return f"<s>[INST] {instruction} [/INST]"
def _extract_json_span(text):
    """Return *text* from its first '{' to its last '}' (model special tokens removed), or ''."""
    # Model control tokens / stray arrow glyphs are never part of the content.
    text = re.sub(r"<s>|</s>|◀|▶", "", text)
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return ""
    return text[start:end + 1].strip()


def _close_open_structures(text):
    """Append the closers for an unfinished string and any open brackets/braces, innermost first.

    Brackets that appear inside JSON string values are ignored, so text such as
    ``"a {b"`` does not count as an open brace.
    """
    expected_closers = []
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            expected_closers.append("}")
        elif ch == "[":
            expected_closers.append("]")
        elif ch in "}]" and expected_closers and expected_closers[-1] == ch:
            expected_closers.pop()
    suffix = '"' if in_string else ""
    return text + suffix + "".join(reversed(expected_closers))


def _repair_json_text(text):
    """Apply heuristic fixes for common LLM JSON mistakes and return the repaired text."""
    # Raw newlines/tabs inside strings are invalid JSON; collapse all whitespace.
    text = re.sub(r"\s+", " ", text).strip()
    # Whole object emitted with escaped quotes, e.g. {\"title\": ...}: unescape it.
    # This is done only in that case, so legitimate \" inside strings survive.
    if re.match(r'\{\s*\\"', text):
        text = text.replace('\\"', '"')
    # "sections": , -> "sections": [],  (keep the separator for the next key)
    text = re.sub(r'"sections":\s*,', '"sections": [],', text)
    # Missing comma between adjacent objects: } { -> }, {
    text = re.sub(r"\}(\s*)\{", r"},\1{", text)
    # Trailing commas before a closer.
    text = re.sub(r",\s*([}\]])", r"\1", text)
    # Close whatever a truncated generation left open, in correct nesting order.
    return _close_open_structures(text)


def _validate_learning_content(data):
    """Check *data* against the learning-content schema; raise ValueError if it does not match.

    A non-list ``sections`` value is leniently replaced by ``[]`` in place, but
    every list entry must be an object with ``subheading`` and ``content`` keys.
    """
    if not isinstance(data, dict) or "title" not in data or "sections" not in data:
        raise ValueError("Missing required keys: 'title' or 'sections'.")
    if not isinstance(data["sections"], list):
        data["sections"] = []
        return
    for section in data["sections"]:
        if not isinstance(section, dict) or "subheading" not in section or "content" not in section:
            raise ValueError("Each section must contain 'subheading' and 'content'.")


# Function to clean and format the AI output
def clean_and_format_learning_content(output):
    """Extract, repair if needed, and validate the JSON learning content in *output*.

    The span from the first ``{`` to the last ``}`` is parsed as-is first
    (raw control characters inside strings are tolerated). As a result,
    well-formed output, including escaped quotes and non-ASCII text, is never
    altered. Only if that parse fails are heuristic repairs applied, then the
    text is parsed again.

    Args:
        output: Raw text produced by the model (``None`` is treated as empty).

    Returns:
        On success, the parsed dict containing ``"title"`` and ``"sections"``.
        On failure, a dict with ``"error"``, ``"details"`` and ``"output"``,
        where ``"output"`` is the cleaned text that was last attempted.
    """
    cleaned_output = ""
    try:
        cleaned_output = _extract_json_span(output or "")
        try:
            json_output = json.loads(cleaned_output, strict=False)
        except json.JSONDecodeError:
            cleaned_output = _repair_json_text(cleaned_output)
            json_output = json.loads(cleaned_output, strict=False)
        _validate_learning_content(json_output)
        return json_output
    except (json.JSONDecodeError, ValueError) as e:
        # Provide detailed error information for debugging
        return {
            "error": "Failed to parse or validate output as JSON",
            "details": str(e),
            "output": cleaned_output,
        }
# Function to generate learning content
def generate_learning_content(topic, description, difficulty, temperature=0.9, max_new_tokens=2000, top_p=0.95, repetition_penalty=1.2):
    """Generate learning content with the model and return it as validated JSON.

    Values coming from the UI sliders are coerced into the types and ranges the
    text-generation backend accepts before the request is sent.

    Args:
        topic: Subject of the learning content.
        description: Free-text context for the topic.
        difficulty: Difficulty label passed through to the prompt.
        temperature: Sampling temperature; clamped to at least 0.01.
        max_new_tokens: Upper bound on generated tokens; converted to int.
        top_p: Nucleus-sampling mass. Values >= 1.0 disable nucleus filtering
            (sent as ``None``); values below 0.01 are raised to 0.01.
        repetition_penalty: Repetition penalty; converted to float.

    Returns:
        The result of ``clean_and_format_learning_content`` for the model
        output: the parsed content dict, or an error dict.
    """
    # Sampling needs a strictly positive temperature.
    temperature = max(float(temperature), 1e-2)
    # The backend only accepts 0 < top_p < 1; 1.0 means "no filtering", which
    # is expressed by omitting the parameter.
    top_p = float(top_p)
    top_p = None if top_p >= 1.0 else max(top_p, 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(topic, description, difficulty)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    # Skip special tokens (e.g. the end-of-sequence marker) so they never reach
    # the JSON parser; join once instead of growing a string with +=.
    raw_output = "".join(chunk.token.text for chunk in stream if not chunk.token.special)

    return clean_and_format_learning_content(raw_output)
# Define the Gradio interface
with gr.Blocks(theme="ocean") as demo:
    gr.HTML("<h1><center>Learning Content Generator</center></h1>")

    # Input fields for topic, description, and difficulty
    topic_input = gr.Textbox(label="Topic", placeholder="Enter the topic for learning content.")
    description_input = gr.Textbox(label="Description", placeholder="Enter a brief description of the topic.")
    difficulty_input = gr.Dropdown(
        label="Difficulty Level",
        choices=["High", "Medium", "Low"],
        value="Medium",
        interactive=True,
    )

    # Sliders for the sampling parameters forwarded to generate_learning_content
    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Temperature")
    # 1024 = 128 + 14 * 64, so the maximum is reachable with this step
    # (the previous 1048 was not on the step grid).
    tokens_slider = gr.Slider(minimum=128, maximum=1024, step=64, value=512, label="Max new tokens")
    top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top-p (nucleus sampling)")
    repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty")

    # generate_learning_content returns a dict (the content, or error details).
    # gr.JSON renders it as real JSON. A Textbox would stringify the dict into a
    # Python repr with single quotes, which is not valid JSON.
    output = gr.JSON(label="Generated Learning Content")

    # Button to generate content
    submit_button = gr.Button("Generate Learning Content")
    submit_button.click(
        fn=generate_learning_content,
        inputs=[topic_input, description_input, difficulty_input, temperature_slider, tokens_slider, top_p_slider, repetition_penalty_slider],
        outputs=output,
    )
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()