import json
import re

from huggingface_hub import InferenceClient
import gradio as gr

# Initialize HuggingFace client
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")


# Function to format the input into a strict JSON-based prompt
def format_prompt(topic, description, difficulty):
    """Build the instruction prompt sent to the model.

    Args:
        topic: Subject the learning content should cover.
        description: Free-text context interpolated into the prompt.
        difficulty: Difficulty label interpolated into the prompt.

    Returns:
        The prompt string, which describes the required JSON structure
        (``title`` plus a list of ``sections``) and the output rules.
    """
    prompt = (
        f"You are an expert educator. Generate a structured, highly engaging, and educational JSON object on the topic '{topic}'. "
        f"Use the following description as context: '{description}'. "
        f"The content must be suitable for a '{difficulty}' difficulty level and strictly adhere to the following JSON structure:\n\n"
        f"{{\n"
        f" \"title\": \"[A descriptive and concise title for the topic]\",\n"
        f" \"sections\": [\n"
        f" {{\n"
        f" \"subheading\": \"[A clear and concise subheading summarizing the section]\",\n"
        f" \"content\": \"[A detailed, engaging explanation of the section content written in clear, accessible language.]\"\n"
        f" }}\n"
        f" ]\n"
        f"}}\n\n"
        f"### Strict Output Rules:\n"
        f"1. The output **must be a valid JSON object** and nothing else.\n"
        f"2. All keys and string values must be enclosed in double quotes (\"\").\n"
        f"3. The `sections` field must be a non-empty list of objects, each containing `subheading` and `content`.\n"
        f"4. Avoid extra characters, trailing commas, or malformed syntax.\n"
        f"5. Close all brackets and braces properly.\n"
        f"6. If there is insufficient information, return a JSON object with empty placeholders, e.g.,\n"
        f"{{\n"
        f" \"title\": \"\",\n"
        f" \"sections\": []\n"
        f"}}\n"
        f"7. Validate the output to ensure it complies with the required JSON structure.\n"
    )
    return prompt


def _extract_json_span(text):
    """Return the text from the first '{' through the last '}', or ''."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return ""
    return text[start:end + 1]


def _close_open_brackets(text):
    """Append the closers that *text* is missing, in correct nesting order.

    Brackets inside string literals are ignored, and an unterminated string
    is closed first so that the appended closers are not swallowed by it.
    """
    pending = []  # closers still owed, innermost last
    in_string = False
    escaped = False
    for char in text:
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == "{":
            pending.append("}")
        elif char == "[":
            pending.append("]")
        elif char in "}]" and pending and pending[-1] == char:
            pending.pop()
    if in_string:
        text += '"'
    return text + "".join(reversed(pending))


def _repair_structure(text):
    """Apply the structural repairs for common model mistakes."""
    # Remove trailing commas before a closing brace/bracket.
    text = re.sub(r',\s*(\}|\])', r'\1', text)
    # Give an empty 'sections' value an empty list. The lookahead keeps the
    # following ',' or '}' so the surrounding object stays well-formed.
    text = re.sub(r'"sections":\s*(?=[,}])', '"sections": []', text)
    # Insert the comma the model forgot between adjacent objects.
    text = re.sub(r'(\})(\s*{)', r'\1,\2', text)
    return _close_open_brackets(text)


def _candidate_documents(raw):
    """Yield JSON candidates for *raw*, from lossless to most aggressive."""
    # 1. Untouched JSON span: valid output must never be altered.
    yield _extract_json_span(raw)

    # 2. Structural repairs only; string contents are preserved.
    without_tokens = re.sub(r'<s>|</s>|◀|▶', '', raw)
    yield _repair_structure(_extract_json_span(without_tokens))

    # 3. Last resort: lossy clean-up (drops non-ASCII and backticks,
    #    collapses whitespace, un-escapes quotes) for heavily mangled output.
    legacy = re.sub(r'[^\x00-\x7F]+', '', raw)
    legacy = re.sub(r'`|<s>|</s>|◀|▶', '', legacy)
    legacy = _extract_json_span(legacy)
    legacy = re.sub(r'\s+', ' ', legacy).strip()
    legacy = legacy.replace('\\"', '"')
    yield _repair_structure(legacy)


def _validate_learning_content(data):
    """Check the parsed document against the expected structure.

    Raises:
        ValueError: If required keys are missing or a section is malformed.
    """
    if not isinstance(data, dict):
        raise ValueError("Top-level JSON value must be an object.")
    if "title" not in data or "sections" not in data:
        raise ValueError("Missing required keys: 'title' or 'sections'.")

    sections = data["sections"]
    if not isinstance(sections, list):
        # If 'sections' is not a list, replace it with an empty list
        data["sections"] = []
        return data

    for section in sections:
        if (
            not isinstance(section, dict)
            or "subheading" not in section
            or "content" not in section
        ):
            raise ValueError("Each section must contain 'subheading' and 'content'.")
    return data


# Function to clean and format the AI output
def clean_and_format_learning_content(output):
    """Clean, repair, parse and validate the model's JSON output.

    The raw text is first parsed as-is (after trimming text around the
    outermost braces). Only if that fails are progressively more aggressive
    repairs attempted, so well-formed output is returned unmodified.

    Args:
        output: Raw text produced by the model.

    Returns:
        The parsed dict with ``title`` and ``sections`` on success.
        Otherwise a dict with the keys ``error``, ``details`` and ``output``,
        where ``output`` is the last candidate text that was tried.
    """
    raw = output if isinstance(output, str) else ""
    candidate = ""
    failure = None

    for candidate in _candidate_documents(raw):
        try:
            # strict=False accepts raw newlines/tabs inside string values.
            parsed = json.loads(candidate, strict=False)
        except json.JSONDecodeError as exc:
            failure = exc
            continue
        try:
            return _validate_learning_content(parsed)
        except ValueError as exc:
            # The text parsed, so further syntax repairs will not help.
            failure = exc
            break

    # Provide detailed error information for debugging
    return {
        "error": "Failed to parse or validate output as JSON",
        "details": str(failure),
        "output": candidate,
    }


# Function to generate learning content
def generate_learning_content(topic, description, difficulty, temperature=0.9, max_new_tokens=2000, top_p=0.95, repetition_penalty=1.2):
    """Generate learning content with the model and validate the output.

    Args:
        topic: Subject of the learning content.
        description: Context passed into the prompt.
        difficulty: Difficulty label passed into the prompt.
        temperature: Sampling temperature; clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus sampling probability mass.
        repetition_penalty: Repetition penalty forwarded to the model.

    Returns:
        The dict returned by ``clean_and_format_learning_content``.
    """
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),  # Ensure minimum temperature
        # UI sliders may deliver floats; the token budget must be an integer.
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )

    # Format the prompt
    formatted_prompt = format_prompt(topic, description, difficulty)

    # Stream the output from the model
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    raw_output = "".join(response.token.text for response in stream)

    # Clean and validate the raw output
    return clean_and_format_learning_content(raw_output)


def _generate_for_display(topic, description, difficulty, temperature, max_new_tokens, top_p, repetition_penalty):
    """Run generation and serialise the result as indented JSON for the UI."""
    result = generate_learning_content(
        topic,
        description,
        difficulty,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    # The output component is a Textbox; give it JSON text, not a dict.
    return json.dumps(result, indent=2, ensure_ascii=False)


# Define the Gradio interface
with gr.Blocks(theme="ocean") as demo:
    gr.HTML("<h1><center>Learning Content Generator</center></h1>")

    # Input fields for topic, description, and difficulty
    topic_input = gr.Textbox(label="Topic", placeholder="Enter the topic for learning content.")
    description_input = gr.Textbox(label="Description", placeholder="Enter a brief description of the topic.")
    difficulty_input = gr.Dropdown(
        label="Difficulty Level",
        choices=["High", "Medium", "Low"],
        value="Medium",
        interactive=True
    )

    # Sliders for model parameters
    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Temperature")
    # NOTE(review): maximum=1048 looks like a typo for 1024 - confirm before changing.
    tokens_slider = gr.Slider(minimum=128, maximum=1048, step=64, value=512, label="Max new tokens")
    top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top-p (nucleus sampling)")
    repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty")

    # Output field for generated learning content
    output = gr.Textbox(label="Generated Learning Content", lines=15)

    # Button to generate content
    submit_button = gr.Button("Generate Learning Content")

    # Define the click event to call the generate function
    submit_button.click(
        fn=_generate_for_display,
        inputs=[topic_input, description_input, difficulty_input, temperature_slider, tokens_slider, top_p_slider, repetition_penalty_slider],
        outputs=output,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()