File size: 3,387 Bytes
77a8694
 
360349c
77a8694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44245fd
77a8694
44245fd
77a8694
44245fd
77a8694
 
44245fd
77a8694
 
44245fd
77a8694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44245fd
 
 
 
77a8694
44245fd
 
77a8694
 
44245fd
77a8694
 
44245fd
77a8694
44245fd
77a8694
44245fd
77a8694
 
44245fd
 
 
77a8694
 
 
44245fd
 
 
77a8694
44245fd
77a8694
 
44245fd
77a8694
 
 
 
44245fd
77a8694
 
44245fd
 
 
 
 
 
 
 
77a8694
44245fd
 
 
77a8694
 
 
 
 
44245fd
 
 
77a8694
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Model configurations
BASE_MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # Base model
ADAPTER_MODEL = "Joash2024/Math-SmolLM2-1.7B"       # Our LoRA adapter

# float16 is only worthwhile (and only reliably supported for every op) on a
# GPU; on CPU-only hosts half-precision matmuls are slow or raise
# "not implemented for 'Half'" on some torch versions, so fall back to float32.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Reuse the EOS token for padding (generate() below also passes
# pad_token_id=eos_token_id explicitly).
tokenizer.pad_token = tokenizer.eos_token

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=MODEL_DTYPE,
)

print("Loading LoRA adapter...")
# Wrap the base model with the fine-tuned LoRA adapter weights, then switch
# to inference mode (disables dropout).
model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
model.eval()

def format_prompt(function: str) -> str:
    """Build the instruction prompt asking the model to differentiate *function*.

    The prompt ends right where the model is expected to continue with the
    derivative.
    """
    prompt_lines = (
        "Given a mathematical function, find its derivative.",
        "",  # blank separator line between instruction and input
        f"Function: {function}",
        "The derivative of this function is:",
    )
    return "\n".join(prompt_lines)

def generate_derivative(function: str, max_length: int = 200) -> str:
    """Generate a derivative for *function* with the fine-tuned model.

    Args:
        function: The function to differentiate, as entered by the user.
        max_length: Maximum total sequence length (prompt tokens plus
            generated tokens) passed to ``model.generate``.

    Returns:
        The text the model generated after the prompt, with surrounding
        whitespace stripped.
    """
    prompt = format_prompt(function)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens; generate() returns prompt + continuation.
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing the decoded full text by
    # len(prompt) characters is fragile: decoding does not always reproduce the
    # prompt string exactly (whitespace normalisation, special tokens), which
    # would cut into the answer or leave prompt fragments behind.
    new_tokens = outputs[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def solve_derivative(function: str) -> str:
    """Generate a derivative for *function* and wrap it in display text.

    Args:
        function: Raw text from the input box (may be empty, whitespace-only,
            or None).

    Returns:
        A prompt to enter a function when the input is blank; otherwise the
        generated derivative embedded in a fixed multi-line template.
    """
    # Normalise None / surrounding whitespace so a blank input box does not
    # trigger a (slow) model call on an empty prompt.
    function = (function or "").strip()
    if not function:
        return "Please enter a function"

    print(f"\nGenerating derivative for: {function}")
    derivative = generate_derivative(function)

    # NOTE: the "steps" below are a fixed template around the model output;
    # nothing here actually checks the derivative.
    output = f"""Generated derivative: {derivative}

Let's verify this step by step:
1. Starting with f(x) = {function}
2. Applying differentiation rules
3. We get f'(x) = {derivative}"""

    return output

# Example inputs shown under the form (LaTeX-style function strings).
_EXAMPLE_FUNCTIONS = (
    "x^2",
    r"\sin{\left(x\right)}",
    "e^x",
    r"\frac{1}{x}",
    "x^3 + 2x",
    r"\cos{\left(x^2\right)}",
    r"\log{\left(x\right)}",
    "x e^{-x}",
)

# Gradio UI: input row with a solve button, an output box, cached examples,
# and the button wiring. Components are laid out in creation order.
with gr.Blocks(title="Mathematics Derivative Solver") as demo:
    for markdown_text in (
        "# Mathematics Derivative Solver",
        "Using our fine-tuned model to solve derivatives",
    ):
        gr.Markdown(markdown_text)

    # Entering two managers in one `with` nests Column inside Row.
    with gr.Row(), gr.Column():
        function_input = gr.Textbox(
            label="Enter a function",
            placeholder="Example: x^2, sin(x), e^x",
        )
        solve_btn = gr.Button("Find Derivative", variant="primary")

    with gr.Row():
        output = gr.Textbox(label="Solution with Steps", lines=6)

    # cache_examples=True runs solve_derivative on every example up front.
    gr.Examples(
        examples=[[expression] for expression in _EXAMPLE_FUNCTIONS],
        inputs=function_input,
        outputs=output,
        fn=solve_derivative,
        cache_examples=True,
    )

    solve_btn.click(fn=solve_derivative, inputs=[function_input], outputs=output)

if __name__ == "__main__":
    demo.launch()