padmanabhbosamia committed (verified)
Commit ac8a885 · 1 Parent(s): 68060e3

Create app.py

Files changed (1):
  1. app.py (+98, -0)
app.py ADDED
@@ -0,0 +1,98 @@
import os
import torch
import gradio as gr
from config import SmolLM2Config
from model import SmolLM2Lightning


def load_model(checkpoint_path):
    """Load the trained model from a Lightning checkpoint."""
    try:
        config = SmolLM2Config("config.yaml")
        model = SmolLM2Lightning.load_from_checkpoint(checkpoint_path, config=config)
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()
            print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("Model loaded on CPU")

        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate text from a prompt with the globally loaded model."""
    try:
        if model is None:
            return "Model not loaded. Please check if the checkpoint exists."

        inputs = model.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=model.config.model.max_position_embeddings,
            padding=True,
        )

        if torch.cuda.is_available():
            # Moving tensors to GPU turns the BatchEncoding into a plain
            # dict, so the tensors are accessed by key below.
            inputs = {k: v.cuda() for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=model.tokenizer.pad_token_id,
                bos_token_id=model.tokenizer.bos_token_id,
                eos_token_id=model.tokenizer.eos_token_id,
            )

        return model.tokenizer.decode(outputs[0], skip_special_tokens=True)

    except Exception as e:
        return f"Error generating text: {e}"


# Load the model once at import time so the Gradio handler can reuse it
print("Loading model...")
checkpoint_path = "checkpoints/smol-lm2-final.ckpt"
if not os.path.exists(checkpoint_path):
    print(f"Warning: Checkpoint not found at {checkpoint_path}")
    print("Please train the model first or specify the correct checkpoint path")
    model = None
else:
    model = load_model(checkpoint_path)

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and adjust generation parameters to create text with SmolLM2",
    examples=[
        ["Explain what machine learning is:", 100, 0.7, 0.9, 50],
        ["Once upon a time", 150, 0.8, 0.9, 40],
        ["The best way to learn programming is", 120, 0.7, 0.9, 50],
    ],
)

if __name__ == "__main__":
    print("Starting Gradio interface...")
    # Simple launch configuration
    demo.launch(
        server_port=7860,
        share=True,
    )
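
Once the app is running, the same handler can also be exercised programmatically. Below is a minimal sketch using gradio_client, assuming the server above is reachable at its default local address on port 7860 and that the Interface exposes Gradio's default /predict endpoint; both are assumptions about the runtime setup, not part of this commit.

from gradio_client import Client

# Assumed: app.py above is already running locally on port 7860
client = Client("http://127.0.0.1:7860")

# Positional arguments mirror the Interface inputs in order:
# prompt, max_length, temperature, top_p, top_k
result = client.predict(
    "Explain what machine learning is:",
    100,   # Max Length
    0.7,   # Temperature
    0.9,   # Top-p
    50,    # Top-k
    api_name="/predict",
)
print(result)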