File size: 7,293 Bytes
a0aa11b
 
b7f8795
02664ae
 
d4befbe
bfd4b1b
d4befbe
 
d66e57e
ff15af7
b7bdedd
69c35ae
b7bdedd
4aa54d6
b7bdedd
50e2461
4aa54d6
8908f31
b7bdedd
4aa54d6
50e2461
4aa54d6
b7bdedd
 
 
 
 
 
 
 
 
 
 
 
ff15af7
d66e57e
02664ae
8908f31
64c0450
b7bdedd
d4befbe
b2c3e78
dfdfff2
f5c3e8d
dfdfff2
d4befbe
f5c3e8d
dc6c5d5
3668272
 
 
dc6c5d5
b2c3e78
dc6c5d5
 
c2cc992
 
 
 
dfdfff2
c2cc992
dfdfff2
 
 
 
 
 
 
 
4aa54d6
c2cc992
 
 
 
f5c3e8d
dfdfff2
c2cc992
dfdfff2
cc97dfc
 
d4befbe
c2cc992
 
 
 
 
 
dc6c5d5
d4befbe
dc6c5d5
b2c3e78
dfdfff2
d4befbe
 
 
 
 
4617aa5
c2cc992
d4befbe
9ea1f3e
d4befbe
 
 
d66e57e
b7f8795
02664ae
b7f8795
02664ae
b7f8795
 
 
 
 
520db8f
02664ae
d4befbe
d66e57e
b7f8795
d66e57e
b7f8795
d66e57e
d4befbe
b7f8795
d4befbe
4617aa5
d4befbe
9bdb1f8
d4befbe
 
d66e57e
 
 
d4befbe
d66e57e
 
 
 
 
 
 
b7f8795
d66e57e
 
 
 
 
 
02664ae
d66e57e
d4befbe
520db8f
d66e57e
 
520db8f
d66e57e
 
 
 
 
 
02664ae
d66e57e
 
a0aa11b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import re
import json
from huggingface_hub import InferenceClient
import gradio as gr

# Module-level Hugging Face inference client, constructed once at import time
# for the "mistralai/Mixtral-8x7B-Instruct-v0.1" model id and reused by
# generate_learning_content() for every text-generation request.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Function to format the input into a strict JSON-based prompt
def format_prompt(topic, description, difficulty):
    """
    Build the instruction prompt that asks the model for a strict JSON document.

    The prompt has three parts: an opening paragraph that embeds the caller's
    topic, description and difficulty; a JSON skeleton showing the expected
    shape; and a numbered list of output rules.

    Args:
        topic: Subject of the learning content; interpolated verbatim.
        description: Extra context for the topic; interpolated verbatim.
        difficulty: Difficulty label; interpolated verbatim.

    Returns:
        str: The complete prompt text, ending with a newline.
    """
    # Only the opening paragraph depends on the arguments.
    intro = (
        f"You are an expert educator. Generate a structured, highly engaging, and educational JSON object on the topic '{topic}'. "
        f"Use the following description as context: '{description}'. "
        f"The content must be suitable for a '{difficulty}' difficulty level and strictly adhere to the following JSON structure:"
    )

    # Skeleton of the JSON document the model is expected to fill in.
    schema_lines = [
        '{',
        '  "title": "[A descriptive and concise title for the topic]",',
        '  "sections": [',
        '    {',
        '      "subheading": "[A clear and concise subheading summarizing the section]",',
        '      "content": "[A detailed, engaging explanation of the section content written in clear, accessible language.]"',
        '    }',
        '  ]',
        '}',
    ]

    # Numbered rules, including the empty-placeholder fallback document.
    rule_lines = [
        '### Strict Output Rules:',
        '1. The output **must be a valid JSON object** and nothing else.',
        '2. All keys and string values must be enclosed in double quotes ("").',
        '3. The `sections` field must be a non-empty list of objects, each containing `subheading` and `content`.',
        '4. Avoid extra characters, trailing commas, or malformed syntax.',
        '5. Close all brackets and braces properly.',
        '6. If there is insufficient information, return a JSON object with empty placeholders, e.g.,',
        '{',
        '  "title": "",',
        '  "sections": []',
        '}',
        '7. Validate the output to ensure it complies with the required JSON structure.',
    ]

    schema_text = "\n".join(schema_lines)
    rules_text = "\n".join(rule_lines)
    return intro + "\n\n" + schema_text + "\n\n" + rules_text + "\n"




# Function to clean and format the AI output
def clean_and_format_learning_content(output):
    """
    Cleans, validates, and repairs JSON output for learning content.

    The raw model text is converted into candidate JSON strings using
    progressively more aggressive repairs, least destructive first:

    1. keep Unicode, keep escaped quotes;
    2. keep Unicode, un-escape ``\\"`` (for output whose quotes were all escaped);
    3. strip non-ASCII characters, keep escaped quotes;
    4. strip non-ASCII characters and un-escape quotes.

    The first candidate that parses as JSON is validated and returned, so
    well-formed output (including escaped quotes and non-ASCII text inside
    strings) is no longer damaged by repairs it does not need.

    Args:
        output: Raw text produced by the model.

    Returns:
        dict: On success, the parsed object containing at least ``"title"``
        and ``"sections"`` (a non-list ``"sections"`` is replaced by ``[]``).
        On failure, a dict with the keys ``"error"``, ``"details"`` and
        ``"output"`` (the last cleaned candidate that was tried).
    """
    if not isinstance(output, str):
        # Report bad input the same way as unparsable text instead of letting
        # re.sub raise a TypeError into the caller.
        return _learning_content_error(
            f"Expected model output to be a string, got {type(output).__name__}.", ""
        )

    candidate = ""
    parse_error = None
    for ascii_only in (False, True):
        for unescape_quotes in (False, True):
            candidate = _repair_json_text(
                output, ascii_only=ascii_only, unescape_quotes=unescape_quotes
            )
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError as e:
                parse_error = e
                continue
            # The first candidate that parses is authoritative: a structural
            # problem is reported as-is rather than masked by harsher repairs.
            try:
                return _validate_learning_content(parsed)
            except ValueError as e:
                return _learning_content_error(e, candidate)

    return _learning_content_error(parse_error, candidate)


def _learning_content_error(details, cleaned_output):
    """Build the error report returned when parsing or validation fails."""
    return {
        "error": "Failed to parse or validate output as JSON",
        "details": str(details),
        "output": cleaned_output,
    }


def _repair_json_text(raw, *, ascii_only, unescape_quotes):
    """Extract the outermost ``{...}`` span from *raw* and apply textual JSON repairs."""
    text = raw
    if ascii_only:
        text = re.sub(r'[^\x00-\x7F]+', '', text)  # Drop non-ASCII characters
    text = re.sub(r'`|<s>|</s>|◀|▶', '', text)  # Remove fences / special tokens

    # Keep only the span from the first '{' to the last '}'.
    start = text.find('{')
    end = text.rfind('}')
    text = text[start:end + 1] if start != -1 and end > start else ''

    text = re.sub(r'\s+', ' ', text).strip()  # Normalize whitespace
    if unescape_quotes:
        text = text.replace('\\"', '"')  # Output whose quotes were all escaped
    text = re.sub(r',\s*(\}|\])', r'\1', text)  # Remove trailing commas
    # A "sections" key with no value becomes an empty array.
    text = re.sub(r'"sections":\s*,', '"sections": []', text)
    text = _close_open_containers(text)
    # Insert the comma missing between adjacent objects ("}{" -> "},{").
    return re.sub(r'(\})(\s*{)', r'\1,\2', text)


def _close_open_containers(text):
    """Append the closers for still-open ``{``/``[`` in correct nesting order."""
    pending = []  # closers owed, innermost last
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            # Brackets inside string literals are content, not structure.
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == '{':
            pending.append('}')
        elif ch == '[':
            pending.append(']')
        elif ch in '}]' and pending and pending[-1] == ch:
            pending.pop()
    # Innermost container must be closed first, hence the reversal.
    return text + ''.join(reversed(pending))


def _validate_learning_content(content):
    """Check the parsed object has the required shape; raise ValueError otherwise."""
    if "title" not in content or "sections" not in content:
        raise ValueError("Missing required keys: 'title' or 'sections'.")
    sections = content["sections"]
    if not isinstance(sections, list):
        # If 'sections' is not a list, replace it with an empty list
        content["sections"] = []
        return content
    for section in sections:
        # Non-dict entries (numbers, strings, null) are malformed sections too.
        if (
            not isinstance(section, dict)
            or "subheading" not in section
            or "content" not in section
        ):
            raise ValueError("Each section must contain 'subheading' and 'content'.")
    return content


# Function to generate learning content
def generate_learning_content(topic, description, difficulty, temperature=0.9, max_new_tokens=2000, top_p=0.95, repetition_penalty=1.2):
    """
    Request learning content from the model and return the cleaned result.

    Builds the prompt with format_prompt(), streams the completion from the
    module-level inference client, concatenates the streamed token texts and
    passes the result through clean_and_format_learning_content().

    Args:
        topic: Topic forwarded to format_prompt().
        description: Description forwarded to format_prompt().
        difficulty: Difficulty label forwarded to format_prompt().
        temperature: Sampling temperature; coerced to float and floored at 0.01.
        max_new_tokens: Forwarded unchanged to the text-generation call.
        top_p: Nucleus-sampling value; coerced to float.
        repetition_penalty: Forwarded unchanged to the text-generation call.

    Returns:
        Whatever clean_and_format_learning_content() returns for the
        generated text.
    """
    sampling_options = {
        # Floor the temperature at 0.01 so it is never zero.
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    token_stream = client.text_generation(
        format_prompt(topic, description, difficulty),
        stream=True,
        details=True,
        return_full_text=False,
        **sampling_options,
    )

    # Each streamed item carries one token; stitch their texts together.
    generated_text = "".join(chunk.token.text for chunk in token_stream)

    return clean_and_format_learning_content(generated_text)

# Define the Gradio interface
with gr.Blocks(theme="ocean") as demo:
    gr.HTML("<h1><center>Learning Content Generator</center></h1>")

    # Input fields for topic, description, and difficulty
    topic_input = gr.Textbox(label="Topic", placeholder="Enter the topic for learning content.")
    description_input = gr.Textbox(label="Description", placeholder="Enter a brief description of the topic.")
    difficulty_input = gr.Dropdown(
        label="Difficulty Level",
        choices=["High", "Medium", "Low"],
        value="Medium",
        interactive=True,
    )

    # Sliders for model parameters
    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Temperature")
    # NOTE(review): with minimum=128 and step=64 the largest reachable value is
    # 1024, not 1048 -- possibly 1024 or 2048 was intended; confirm before changing.
    tokens_slider = gr.Slider(minimum=128, maximum=1048, step=64, value=512, label="Max new tokens")
    top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top-p (nucleus sampling)")
    repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty")

    # Output field for generated learning content.
    # generate_learning_content returns a dict (parsed content or an error
    # report). A JSON component renders that dict as structured JSON, whereas a
    # Textbox would display its Python str() form (single-quoted repr).
    output = gr.JSON(label="Generated Learning Content")

    # Button to generate content
    submit_button = gr.Button("Generate Learning Content")

    # Wire the button to the generator; the inputs list order matches the
    # positional parameters of generate_learning_content.
    submit_button.click(
        fn=generate_learning_content,
        inputs=[
            topic_input,
            description_input,
            difficulty_input,
            temperature_slider,
            tokens_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=output,
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()