awacke1 commited on
Commit
017755d
Β·
verified Β·
1 Parent(s): 0893521

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +370 -0
app.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import shutil
4
+ import glob
5
+ import base64
6
+ import streamlit as st
7
+ import pandas as pd
8
+ import torch
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import csv
12
+ import time
13
+ from dataclasses import dataclass
14
+ from typing import Optional, Tuple
15
+ import zipfile
16
+ import math
17
+
18
# Page Configuration
# Must run before any other st.* call; sets browser tab title/icon and layout.
st.set_page_config(
    page_title="SFT Model Builder πŸš€",
    page_icon="πŸ€–",
    layout="wide",
    initial_sidebar_state="expanded",
)
25
+
26
# Model Configuration Class
@dataclass
class ModelConfig:
    """Metadata describing a model stored under the local ``models/`` tree."""

    name: str                      # directory name under models/
    base_model: str                # hub id the model was derived from
    size: str                      # informal size label, e.g. "small"
    domain: Optional[str] = None   # optional target-domain tag

    @property
    def model_path(self):
        """Filesystem path where this model's weights live."""
        return "models/" + self.name
37
+
38
# Custom Dataset for SFT
class SFTDataset(Dataset):
    """Prompt/response pairs tokenized for causal-LM supervised fine-tuning.

    Each item is a dict of ``input_ids``, ``attention_mask`` and ``labels``
    (all 1-D tensors of length ``max_length``) where prompt tokens and padding
    are masked out of the loss with -100.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data            # list of {"prompt": str, "response": str}
        self.tokenizer = tokenizer  # HF-style tokenizer with pad_token_id set
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        prompt = self.data[idx]["prompt"]
        response = self.data[idx]["response"]

        # Tokenize the prompt alone only to measure how many tokens to mask.
        prompt_encoding = self.tokenizer(
            prompt,
            max_length=self.max_length // 2,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        full_text = f"{prompt} {response}"
        full_encoding = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # BUG FIX: the model must see the full prompt+response sequence.
        # Previously input_ids/attention_mask came from the prompt-only
        # encoding (length max_length // 2) while labels came from the full
        # encoding (length max_length), so the tensors did not even align
        # and the response tokens were never fed to the model.
        input_ids = full_encoding["input_ids"].squeeze(0)
        attention_mask = full_encoding["attention_mask"].squeeze(0)
        labels = input_ids.clone()

        # Mask prompt tokens so only the response contributes to the loss.
        prompt_len = prompt_encoding["input_ids"].ne(self.tokenizer.pad_token_id).sum().item()
        labels[:prompt_len] = -100  # -100 is ignored by CrossEntropyLoss
        # Also exclude padding positions from the loss.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
81
+
82
# Model Builder Class
class ModelBuilder:
    """Load, SFT fine-tune, save and sample from a causal language model.

    Streamlit calls (spinners, success/progress messages) are embedded in
    every method, so instances are meant to be used inside a running
    Streamlit session (typically stored in ``st.session_state``).
    """

    def __init__(self):
        self.config = None      # ModelConfig for the currently loaded model
        self.model = None       # transformers AutoModelForCausalLM instance
        self.tokenizer = None   # matching AutoTokenizer
        self.sft_data = None    # list of {"prompt", "response"} dicts after fine_tune_sft

    def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
        """Load model + tokenizer from a local path or hub id; returns self."""
        with st.spinner("Loading model... ⏳"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            # Many causal LMs ship without a pad token; reuse EOS for padding.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if config:
                self.config = config
        st.success("Model loaded! βœ…")
        return self

    def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
        """Supervised fine-tune on a CSV with ``prompt``/``response`` columns; returns self."""
        # Read all training pairs into memory (datasets here are tiny).
        self.sft_data = []
        with open(csv_path, "r") as f:
            reader = csv.DictReader(f)
            for row in reader:
                self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})

        dataset = SFTDataset(self.sft_data, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)

        self.model.train()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        for epoch in range(epochs):
            with st.spinner(f"Training epoch {epoch + 1}/{epochs}... βš™οΈ"):
                total_loss = 0
                for batch in dataloader:
                    optimizer.zero_grad()
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    labels = batch["labels"].to(device)
                    # The HF model computes the LM loss itself when labels are given.
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
        st.success("SFT Fine-tuning completed! πŸŽ‰")
        return self

    def save_model(self, path: str):
        """Write model weights and tokenizer files to *path* (created if absent)."""
        with st.spinner("Saving model... πŸ’Ύ"):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
        st.success(f"Model saved at {path}! βœ…")

    def evaluate(self, prompt: str):
        """Generate and return a sampled completion (string) for *prompt*.

        Sampling (do_sample/top_p/temperature) means output is nondeterministic.
        """
        self.model.eval()
        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,
                top_p=0.95,
                temperature=0.7
            )
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
151
+
152
# Utility Functions
def get_download_link(file_path, mime_type="text/plain", label="Download"):
    """Build an HTML anchor that downloads *file_path* as a base64 data URI."""
    with open(file_path, 'rb') as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload).decode()
    filename = os.path.basename(file_path)
    return f'<a href="data:{mime_type};base64,{encoded}" download="{filename}">{label} πŸ“₯</a>'
158
+
159
def zip_directory(directory_path, zip_path):
    """Recursively zip *directory_path*, keeping its top-level folder name in entries."""
    parent = os.path.dirname(directory_path)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for root, _, filenames in os.walk(directory_path):
            for filename in filenames:
                absolute = os.path.join(root, filename)
                # Entry names are relative to the directory's parent so the
                # archive unpacks into a single named folder.
                archive.write(absolute, os.path.relpath(absolute, parent))
166
+
167
def get_model_files():
    """Return paths of saved model directories under ``models/`` (empty if none)."""
    return list(filter(os.path.isdir, glob.glob("models/*")))
169
+
170
# Cargo Travel Time Tool
def calculate_cargo_travel_time(
    origin_coords: Tuple[float, float],
    destination_coords: Tuple[float, float],
    cruising_speed_kmh: float = 750.0
) -> float:
    """Estimate cargo flight time in hours between two (lat, lon) pairs.

    Computes the haversine great-circle distance, inflates it by 10% for
    non-direct routing, and adds one hour of ground/overhead time.
    Returns the result rounded to two decimals.
    """
    lat_a, lon_a = map(math.radians, origin_coords)
    lat_b, lon_b = map(math.radians, destination_coords)
    EARTH_RADIUS_KM = 6371.0
    delta_lon = lon_b - lon_a
    delta_lat = lat_b - lat_a
    # Haversine formula for the central angle between the two points.
    haversine = (math.sin(delta_lat / 2) ** 2
                 + math.cos(lat_a) * math.cos(lat_b) * math.sin(delta_lon / 2) ** 2)
    central_angle = 2 * math.asin(math.sqrt(haversine))
    distance = EARTH_RADIUS_KM * central_angle
    actual_distance = distance * 1.1  # 10% allowance for non-direct routing
    flight_time = (actual_distance / cruising_speed_kmh) + 1.0  # +1 h overhead
    return round(flight_time, 2)
189
+
190
# Main App
st.title("SFT Model Builder πŸ€–πŸš€")

# Sidebar for Model Management: pick a previously saved model directory and
# load it into st.session_state['builder'].
st.sidebar.header("Model Management πŸ—‚οΈ")
model_dirs = get_model_files()
selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)

if selected_model != "None" and st.sidebar.button("Load Model πŸ“‚"):
    if 'builder' not in st.session_state:
        st.session_state['builder'] = ModelBuilder()
    # base_model is unknown for a model loaded from disk; record a placeholder.
    config = ModelConfig(name=os.path.basename(selected_model), base_model="unknown", size="small", domain="general")
    st.session_state['builder'].load_model(selected_model, config)
    st.session_state['model_loaded'] = True
    st.rerun()  # refresh the UI so the tabs see the loaded model

# Main UI with Tabs
tab1, tab2, tab3, tab4 = st.tabs(["Build New Model 🌱", "Fine-Tune Model πŸ”§", "Test Model πŸ§ͺ", "Agentic RAG Demo 🌐"])
208
+
209
with tab1:
    # Tab 1: download a small base model from the HF hub and save a local copy.
    st.header("Build New Model 🌱")
    base_model = st.selectbox(
        "Select Base Model",
        [
            "HuggingFaceTB/SmolLM-135M",  # ~270 MB
            "HuggingFaceTB/SmolLM-360M",  # ~720 MB
            "Qwen/Qwen1.5-0.5B-Chat",  # ~1 GB
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # ~2 GB, slightly over but included
        ],
        help="Choose a tiny, open-source model (<1 GB except TinyLlama)"
    )
    # Default name is timestamped so repeated downloads don't collide.
    model_name = st.text_input("Model Name", f"new-model-{int(time.time())}")
    domain = st.text_input("Target Domain", "general")

    if st.button("Download Model ⬇️"):
        config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
        builder = ModelBuilder()
        builder.load_model(base_model, config)  # pulls from the HF hub
        builder.save_model(config.model_path)   # persists under models/<name>
        st.session_state['builder'] = builder
        st.session_state['model_loaded'] = True
        st.success(f"Model downloaded and saved to {config.model_path}! πŸŽ‰")
        st.rerun()
233
+
234
with tab2:
    # Tab 2: generate/upload a prompt-response CSV and run SFT on the loaded model.
    st.header("Fine-Tune Model πŸ”§")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Please download or load a model first! ⚠️")
    else:
        if st.button("Generate Sample CSV πŸ“"):
            # Tiny demo dataset illustrating the expected CSV schema.
            sample_data = [
                {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
                {"prompt": "Explain machine learning", "response": "Machine learning is a subset of AI where models learn from data."},
                {"prompt": "What is a neural network?", "response": "A neural network is a model inspired by the human brain."},
            ]
            csv_path = f"sft_data_{int(time.time())}.csv"
            with open(csv_path, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
                writer.writeheader()
                writer.writerows(sample_data)
            st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
            st.success(f"Sample CSV generated as {csv_path}! βœ…")

        uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
        if uploaded_csv and st.button("Fine-Tune with Uploaded CSV πŸ”„"):
            # Persist the upload to disk so fine_tune_sft can read it by path.
            csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
            with open(csv_path, "wb") as f:
                f.write(uploaded_csv.read())
            # The fine-tuned model gets its own timestamped name/config so the
            # original download is left untouched on disk.
            new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
            new_config = ModelConfig(
                name=new_model_name,
                base_model=st.session_state['builder'].config.base_model,
                size="small",
                domain=st.session_state['builder'].config.domain
            )
            st.session_state['builder'].config = new_config
            with st.status("Fine-tuning model... ⏳", expanded=True) as status:
                st.session_state['builder'].fine_tune_sft(csv_path)
                st.session_state['builder'].save_model(new_config.model_path)
                status.update(label="Fine-tuning completed! πŸŽ‰", state="complete")

            # Offer the fine-tuned model directory as a single zip download.
            zip_path = f"{new_config.model_path}.zip"
            zip_directory(new_config.model_path, zip_path)
            st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Model"), unsafe_allow_html=True)
            st.rerun()
275
+
276
with tab3:
    # Tab 3: sample from the loaded model and export a standalone demo app.
    st.header("Test Model πŸ§ͺ")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Please download or load a model first! ⚠️")
    else:
        # If SFT ran this session, show generations for the first few training prompts.
        if st.session_state['builder'].sft_data:
            st.write("Testing with SFT Data:")
            for item in st.session_state['builder'].sft_data[:3]:
                prompt = item["prompt"]
                expected = item["response"]
                generated = st.session_state['builder'].evaluate(prompt)
                st.write(f"**Prompt**: {prompt}")
                st.write(f"**Expected**: {expected}")
                st.write(f"**Generated**: {generated}")
                st.write("---")

        test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
        if st.button("Run Test ▢️"):
            result = st.session_state['builder'].evaluate(test_prompt)
            st.write(f"**Generated Response**: {result}")

        if st.button("Export Model Files πŸ“¦"):
            # Write a minimal self-contained Streamlit app, requirements file
            # and README that serve the saved model from its local path.
            config = st.session_state['builder'].config
            app_code = f"""
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{config.model_path}")
tokenizer = AutoTokenizer.from_pretrained("{config.model_path}")

st.title("SFT Model Demo")
input_text = st.text_area("Enter prompt")
if st.button("Generate"):
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
    st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
"""
            with open("sft_app.py", "w") as f:
                f.write(app_code)
            reqs = "streamlit\ntorch\ntransformers\n"
            with open("sft_requirements.txt", "w") as f:
                f.write(reqs)
            readme = f"""
# SFT Model Demo

## How to run
1. Install requirements: `pip install -r sft_requirements.txt`
2. Run the app: `streamlit run sft_app.py`
3. Input a prompt and click "Generate".
"""
            with open("sft_README.md", "w") as f:
                f.write(readme)

            st.markdown(get_download_link("sft_app.py", "text/plain", "Download App"), unsafe_allow_html=True)
            st.markdown(get_download_link("sft_requirements.txt", "text/plain", "Download Requirements"), unsafe_allow_html=True)
            st.markdown(get_download_link("sft_README.md", "text/markdown", "Download README"), unsafe_allow_html=True)
            st.success("Model files exported! βœ…")
333
+
334
with tab4:
    # Tab 4: optional smolagents demo; imported lazily so the rest of the app
    # works without the smolagents package installed.
    st.header("Agentic RAG Demo 🌐")
    st.write("This demo uses tiny models with Agentic RAG to plan a luxury superhero-themed party, enhancing retrieval with DuckDuckGo.")

    if st.button("Run Agentic RAG Demo πŸŽ‰"):
        try:
            from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool

            # Load selected tiny model
            tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
            model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")

            # Define Agentic RAG agent
            # NOTE(review): passing raw transformers model/tokenizer and this
            # kwarg set assumes a particular smolagents version — verify
            # against the installed smolagents CodeAgent API.
            agent = CodeAgent(
                model=model,
                tokenizer=tokenizer,
                tools=[DuckDuckGoSearchTool(), VisitWebpageTool(), calculate_cargo_travel_time],
                additional_authorized_imports=["pandas"],
                planning_interval=5,
                verbosity_level=2,
                max_steps=15,
            )

            task = """
            Plan a luxury superhero-themed party at Wayne Manor (42.3601Β° N, 71.0589Β° W). Search for the latest superhero party trends using DuckDuckGo,
            refine results to include luxury elements (decorations, entertainment, catering), and calculate cargo travel times from key locations
            (e.g., New York, LA, London) to Wayne Manor. Synthesize a complete plan and return it as a pandas dataframe with at least 6 entries
            including locations, travel times, and luxury party ideas.
            """
            with st.spinner("Running Agentic RAG system... ⏳"):
                result = agent.run(task)
            st.write("Agentic RAG Result:")
            st.write(result)
        except ImportError:
            st.error("Please install required packages: `pip install smolagents pandas`")
        except Exception as e:
            st.error(f"Error running demo: {str(e)}")