awacke1 commited on
Commit
990b411
·
verified ·
1 Parent(s): 3ba8ca7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +304 -22
app.py CHANGED
@@ -35,24 +35,182 @@ st.set_page_config(
35
  }
36
  )
37
 
38
- # [Previous sections like ModelConfig, SFTDataset, ModelBuilder, Utility Functions remain unchanged...]
 
 
 
 
 
 
 
 
 
 
39
 
40
- # Cargo Travel Time Tool with Refined Docstring
41
- from smolagents import tool
 
 
 
 
42
 
43
- @tool
44
- def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
45
- """
46
- Calculate cargo plane travel time between two coordinates using the great-circle distance.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- Args:
49
- origin_coords (Tuple[float, float]): The latitude and longitude of the starting point in degrees, e.g., (42.3601, -71.0589).
50
- destination_coords (Tuple[float, float]): The latitude and longitude of the destination in degrees, e.g., (40.7128, -74.0060).
51
- cruising_speed_kmh (float, optional): The cruising speed of the cargo plane in kilometers per hour. Defaults to 750.0.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- Returns:
54
- float: The estimated travel time in hours, rounded to two decimal places, including takeoff and landing time.
55
- """
 
 
 
 
 
 
 
 
56
  def to_radians(degrees: float) -> float:
57
  return degrees * (math.pi / 180)
58
  lat1, lon1 = map(to_radians, origin_coords)
@@ -63,14 +221,14 @@ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_
63
  a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
64
  c = 2 * math.asin(math.sqrt(a))
65
  distance = EARTH_RADIUS_KM * c
66
- actual_distance = distance * 1.1 # 10% buffer for real-world routes
67
- flight_time = (actual_distance / cruising_speed_kmh) + 1.0 # Add 1 hour for takeoff/landing
68
  return round(flight_time, 2)
69
 
70
  # Main App
71
  st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
72
 
73
- # Sidebar with Galleries (unchanged)
74
  st.sidebar.header("Galleries & Shenanigans 🎨")
75
  st.sidebar.subheader("Image Gallery 📸")
76
  img_files = get_gallery_files(["png", "jpg", "jpeg"])
@@ -98,10 +256,134 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
98
  st.session_state['model_loaded'] = True
99
  st.rerun()
100
 
101
- # Main UI with Tabs (only Tab 4 shown here for brevity)
102
  tab1, tab2, tab3, tab4 = st.tabs(["Build Tiny Titan 🌱", "Fine-Tune Titan 🔧", "Test Titan 🧪", "Agentic RAG Party 🌐"])
103
 
104
- # [Tab 1, Tab 2, Tab 3 remain unchanged...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  with tab4:
107
  st.header("Agentic RAG Party 🌐 (Party Like It’s 2099!)")
@@ -112,12 +394,12 @@ with tab4:
112
  from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool
113
  from transformers import AutoModelForCausalLM
114
 
115
- # Load the model
116
  with st.spinner("Loading SmolLM-135M... ⏳ (Titan’s suiting up!)"):
117
  model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
118
  st.write("Model loaded! 🦸‍♂️ (Ready to party!)")
119
 
120
- # Initialize agent with proper tools
121
  agent = CodeAgent(
122
  model=model,
123
  tools=[DuckDuckGoSearchTool(), VisitWebpageTool(), calculate_cargo_travel_time],
@@ -144,4 +426,4 @@ Add a random superhero catchphrase to each entry for fun!
144
  except TypeError as e:
145
  st.error(f"Agent setup failed: {str(e)} (Looks like the Titans need a tune-up!)")
146
  except Exception as e:
147
- st.error(f"Error running demo: {str(e)} (Even Batman has off days!)")
 
35
  }
36
  )
37
 
38
+ # Model Configuration Class
39
+ @dataclass
40
+ class ModelConfig:
41
+ name: str
42
+ base_model: str
43
+ size: str
44
+ domain: Optional[str] = None
45
+
46
+ @property
47
+ def model_path(self):
48
+ return f"models/{self.name}"
49
 
50
+ # Custom Dataset for SFT
51
+ class SFTDataset(Dataset):
52
+ def __init__(self, data, tokenizer, max_length=128):
53
+ self.data = data
54
+ self.tokenizer = tokenizer
55
+ self.max_length = max_length
56
 
57
+ def __len__(self):
58
+ return len(self.data)
59
+
60
+ def __getitem__(self, idx):
61
+ prompt = self.data[idx]["prompt"]
62
+ response = self.data[idx]["response"]
63
+
64
+ full_text = f"{prompt} {response}"
65
+ full_encoding = self.tokenizer(
66
+ full_text,
67
+ max_length=self.max_length,
68
+ padding="max_length",
69
+ truncation=True,
70
+ return_tensors="pt"
71
+ )
72
+
73
+ prompt_encoding = self.tokenizer(
74
+ prompt,
75
+ max_length=self.max_length,
76
+ padding=False,
77
+ truncation=True,
78
+ return_tensors="pt"
79
+ )
80
+
81
+ input_ids = full_encoding["input_ids"].squeeze()
82
+ attention_mask = full_encoding["attention_mask"].squeeze()
83
+ labels = input_ids.clone()
84
+
85
+ prompt_len = prompt_encoding["input_ids"].shape[1]
86
+ if prompt_len < self.max_length:
87
+ labels[:prompt_len] = -100
88
+
89
+ return {
90
+ "input_ids": input_ids,
91
+ "attention_mask": attention_mask,
92
+ "labels": labels
93
+ }
94
+
95
+ # Model Builder Class with Easter Egg Jokes
96
+ class ModelBuilder:
97
+ def __init__(self):
98
+ self.config = None
99
+ self.model = None
100
+ self.tokenizer = None
101
+ self.sft_data = None
102
+ self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
103
+
104
+ def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
105
+ with st.spinner(f"Loading {model_path}... ⏳ (Patience, young padawan!)"):
106
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
107
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
108
+ if self.tokenizer.pad_token is None:
109
+ self.tokenizer.pad_token = self.tokenizer.eos_token
110
+ if config:
111
+ self.config = config
112
+ st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
113
+ return self
114
+
115
+ def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
116
+ self.sft_data = []
117
+ with open(csv_path, "r") as f:
118
+ reader = csv.DictReader(f)
119
+ for row in reader:
120
+ self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
121
+
122
+ dataset = SFTDataset(self.sft_data, self.tokenizer)
123
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
124
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
125
+
126
+ self.model.train()
127
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
+ self.model.to(device)
129
+ for epoch in range(epochs):
130
+ with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️ (The AI is lifting weights!)"):
131
+ total_loss = 0
132
+ for batch in dataloader:
133
+ optimizer.zero_grad()
134
+ input_ids = batch["input_ids"].to(device)
135
+ attention_mask = batch["attention_mask"].to(device)
136
+ labels = batch["labels"].to(device)
137
+
138
+ assert input_ids.shape[0] == labels.shape[0], f"Batch size mismatch: input_ids {input_ids.shape}, labels {labels.shape}"
139
+
140
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
141
+ loss = outputs.loss
142
+ loss.backward()
143
+ optimizer.step()
144
+ total_loss += loss.item()
145
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
146
+ st.success(f"SFT Fine-tuning completed! 🎉 {random.choice(self.jokes)}")
147
+ return self
148
+
149
+ def save_model(self, path: str):
150
+ with st.spinner("Saving model... 💾 (Packing the AI’s suitcase!)"):
151
+ os.makedirs(os.path.dirname(path), exist_ok=True)
152
+ self.model.save_pretrained(path)
153
+ self.tokenizer.save_pretrained(path)
154
+ st.success(f"Model saved at {path}! ✅ May the force be with it.")
155
 
156
+ def evaluate(self, prompt: str, status_container=None):
157
+ """Evaluate with feedback"""
158
+ self.model.eval()
159
+ if status_container:
160
+ status_container.write("Preparing to evaluate... 🧠 (Titan’s warming up its circuits!)")
161
+ logger.info(f"Evaluating prompt: {prompt}")
162
+
163
+ try:
164
+ with torch.no_grad():
165
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
166
+ if status_container:
167
+ status_container.write(f"Tokenized input shape: {inputs['input_ids'].shape} 📏")
168
+
169
+ outputs = self.model.generate(
170
+ **inputs,
171
+ max_new_tokens=50,
172
+ do_sample=True,
173
+ top_p=0.95,
174
+ temperature=0.7
175
+ )
176
+ if status_container:
177
+ status_container.write("Generation complete! Decoding response... 🗣")
178
+
179
+ result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
180
+ logger.info(f"Generated response: {result}")
181
+ return result
182
+ except Exception as e:
183
+ logger.error(f"Evaluation error: {str(e)}")
184
+ if status_container:
185
+ status_container.error(f"Oops! Something broke: {str(e)} 💥 (Titan tripped over a wire!)")
186
+ return f"Error: {str(e)}"
187
+
188
+ # Utility Functions with Wit
189
+ def get_download_link(file_path, mime_type="text/plain", label="Download"):
190
+ with open(file_path, 'rb') as f:
191
+ data = f.read()
192
+ b64 = base64.b64encode(data).decode()
193
+ return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥 (Grab it before it runs away!)</a>'
194
+
195
+ def zip_directory(directory_path, zip_path):
196
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
197
+ for root, _, files in os.walk(directory_path):
198
+ for file in files:
199
+ file_path = os.path.join(root, file)
200
+ arcname = os.path.relpath(file_path, os.path.dirname(directory_path))
201
+ zipf.write(file_path, arcname)
202
 
203
+ def get_model_files():
204
+ return [d for d in glob.glob("models/*") if os.path.isdir(d)]
205
+
206
+ def get_gallery_files(file_types):
207
+ files = []
208
+ for ext in file_types:
209
+ files.extend(glob.glob(f"*.{ext}"))
210
+ return sorted(files)
211
+
212
+ # Cargo Travel Time Tool
213
+ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
214
  def to_radians(degrees: float) -> float:
215
  return degrees * (math.pi / 180)
216
  lat1, lon1 = map(to_radians, origin_coords)
 
221
  a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
222
  c = 2 * math.asin(math.sqrt(a))
223
  distance = EARTH_RADIUS_KM * c
224
+ actual_distance = distance * 1.1
225
+ flight_time = (actual_distance / cruising_speed_kmh) + 1.0
226
  return round(flight_time, 2)
227
 
228
  # Main App
229
  st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
230
 
231
+ # Sidebar with Galleries
232
  st.sidebar.header("Galleries & Shenanigans 🎨")
233
  st.sidebar.subheader("Image Gallery 📸")
234
  img_files = get_gallery_files(["png", "jpg", "jpeg"])
 
256
  st.session_state['model_loaded'] = True
257
  st.rerun()
258
 
259
+ # Main UI with Tabs
260
  tab1, tab2, tab3, tab4 = st.tabs(["Build Tiny Titan 🌱", "Fine-Tune Titan 🔧", "Test Titan 🧪", "Agentic RAG Party 🌐"])
261
 
262
+ with tab1:
263
+ st.header("Build Tiny Titan 🌱 (Assemble Your Mini-Mecha!)")
264
+ base_model = st.selectbox(
265
+ "Select Tiny Model",
266
+ ["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"],
267
+ help="Pick a pint-sized powerhouse (<1 GB)! SmolLM-135M (~270 MB), SmolLM-360M (~720 MB), Qwen1.5-0.5B (~1 GB)"
268
+ )
269
+ model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
270
+ domain = st.text_input("Target Domain", "general")
271
+
272
+ if st.button("Download Model ⬇️"):
273
+ config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
274
+ builder = ModelBuilder()
275
+ builder.load_model(base_model, config)
276
+ builder.save_model(config.model_path)
277
+ st.session_state['builder'] = builder
278
+ st.session_state['model_loaded'] = True
279
+ st.success(f"Model downloaded and saved to {config.model_path}! 🎉 (Tiny but feisty!)")
280
+ st.rerun()
281
+
282
+ with tab2:
283
+ st.header("Fine-Tune Titan 🔧 (Teach Your Titan Some Tricks!)")
284
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
285
+ st.warning("Please build or load a Titan first! ⚠️ (No Titan, no party!)")
286
+ else:
287
+ if st.button("Generate Sample CSV 📝"):
288
+ sample_data = [
289
+ {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human smarts in machines."},
290
+ {"prompt": "Explain machine learning", "response": "Machine learning is AI’s gym where models bulk up on data."},
291
+ {"prompt": "What is a neural network?", "response": "A neural network is a brainy AI mimicking human noggins."},
292
+ ]
293
+ csv_path = f"sft_data_{int(time.time())}.csv"
294
+ with open(csv_path, "w", newline="") as f:
295
+ writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
296
+ writer.writeheader()
297
+ writer.writerows(sample_data)
298
+ st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
299
+ st.success(f"Sample CSV generated as {csv_path}! ✅ (Fresh from the data oven!)")
300
+
301
+ uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
302
+ if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
303
+ csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
304
+ with open(csv_path, "wb") as f:
305
+ f.write(uploaded_csv.read())
306
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
307
+ new_config = ModelConfig(
308
+ name=new_model_name,
309
+ base_model=st.session_state['builder'].config.base_model,
310
+ size="small",
311
+ domain=st.session_state['builder'].config.domain
312
+ )
313
+ st.session_state['builder'].config = new_config
314
+ with st.status("Fine-tuning Titan... ⏳ (Whipping it into shape!)", expanded=True) as status:
315
+ st.session_state['builder'].fine_tune_sft(csv_path)
316
+ st.session_state['builder'].save_model(new_config.model_path)
317
+ status.update(label="Fine-tuning completed! 🎉 (Titan’s ready to rumble!)", state="complete")
318
+
319
+ zip_path = f"{new_config.model_path}.zip"
320
+ zip_directory(new_config.model_path, zip_path)
321
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
322
+ st.rerun()
323
+
324
+ with tab3:
325
+ st.header("Test Titan 🧪 (Put Your Titan to the Test!)")
326
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
327
+ st.warning("Please build or load a Titan first! ⚠️ (No Titan, no test drive!)")
328
+ else:
329
+ if st.session_state['builder'].sft_data:
330
+ st.write("Testing with SFT Data:")
331
+ with st.spinner("Running SFT data tests... ⏳ (Titan’s flexing its brain muscles!)"):
332
+ for item in st.session_state['builder'].sft_data[:3]:
333
+ prompt = item["prompt"]
334
+ expected = item["response"]
335
+ status_container = st.empty()
336
+ generated = st.session_state['builder'].evaluate(prompt, status_container)
337
+ st.write(f"**Prompt**: {prompt}")
338
+ st.write(f"**Expected**: {expected}")
339
+ st.write(f"**Generated**: {generated} (Titan says: '{random.choice(['Bleep bloop!', 'I am groot!', '42!'])}')")
340
+ st.write("---")
341
+ status_container.empty() # Clear status after each test
342
+
343
+ test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
344
+ if st.button("Run Test ▶️"):
345
+ with st.spinner("Testing your prompt... ⏳ (Titan’s pondering deeply!)"):
346
+ status_container = st.empty()
347
+ result = st.session_state['builder'].evaluate(test_prompt, status_container)
348
+ st.write(f"**Generated Response**: {result} (Titan’s wisdom unleashed!)")
349
+ status_container.empty()
350
+
351
+ if st.button("Export Titan Files 📦"):
352
+ config = st.session_state['builder'].config
353
+ app_code = f"""
354
+ import streamlit as st
355
+ from transformers import AutoModelForCausalLM, AutoTokenizer
356
+
357
+ model = AutoModelForCausalLM.from_pretrained("{config.model_path}")
358
+ tokenizer = AutoTokenizer.from_pretrained("{config.model_path}")
359
+
360
+ st.title("Tiny Titan Demo")
361
+ input_text = st.text_area("Enter prompt")
362
+ if st.button("Generate"):
363
+ inputs = tokenizer(input_text, return_tensors="pt")
364
+ outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
365
+ st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
366
+ """
367
+ with open("titan_app.py", "w") as f:
368
+ f.write(app_code)
369
+ reqs = "streamlit\ntorch\ntransformers\n"
370
+ with open("titan_requirements.txt", "w") as f:
371
+ f.write(reqs)
372
+ readme = f"""
373
+ # Tiny Titan Demo
374
+
375
+ ## How to run
376
+ 1. Install requirements: `pip install -r titan_requirements.txt`
377
+ 2. Run the app: `streamlit run titan_app.py`
378
+ 3. Input a prompt and click "Generate". Watch the magic unfold! 🪄
379
+ """
380
+ with open("titan_README.md", "w") as f:
381
+ f.write(readme)
382
+
383
+ st.markdown(get_download_link("titan_app.py", "text/plain", "Download App"), unsafe_allow_html=True)
384
+ st.markdown(get_download_link("titan_requirements.txt", "text/plain", "Download Requirements"), unsafe_allow_html=True)
385
+ st.markdown(get_download_link("titan_README.md", "text/markdown", "Download README"), unsafe_allow_html=True)
386
+ st.success("Titan files exported! ✅ (Ready to conquer the galaxy!)")
387
 
388
  with tab4:
389
  st.header("Agentic RAG Party 🌐 (Party Like It’s 2099!)")
 
394
  from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool
395
  from transformers import AutoModelForCausalLM
396
 
397
+ # Load the model without separate tokenizer for agent
398
  with st.spinner("Loading SmolLM-135M... ⏳ (Titan’s suiting up!)"):
399
  model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
400
  st.write("Model loaded! 🦸‍♂️ (Ready to party!)")
401
 
402
+ # Initialize agent without tokenizer argument
403
  agent = CodeAgent(
404
  model=model,
405
  tools=[DuckDuckGoSearchTool(), VisitWebpageTool(), calculate_cargo_travel_time],
 
426
  except TypeError as e:
427
  st.error(f"Agent setup failed: {str(e)} (Looks like the Titans need a tune-up!)")
428
  except Exception as e:
429
+ st.error(f"Error running demo: {str(e)} (Even Batman has off days!)")