Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -42,6 +42,14 @@ st.set_page_config(
|
|
| 42 |
}
|
| 43 |
)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# Model Configuration Classes
|
| 46 |
@dataclass
|
| 47 |
class ModelConfig:
|
|
@@ -110,6 +118,7 @@ class ModelBuilder:
|
|
| 110 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 111 |
if config:
|
| 112 |
self.config = config
|
|
|
|
| 113 |
st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
|
| 114 |
return self
|
| 115 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
|
@@ -233,6 +242,15 @@ def get_model_files(model_type="causal_lm"):
|
|
| 233 |
def get_gallery_files(file_types):
|
| 234 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
# Mock Search Tool for RAG
|
| 237 |
def mock_search(query: str) -> str:
|
| 238 |
if "superhero" in query.lower():
|
|
@@ -299,13 +317,7 @@ st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
|
|
| 299 |
# Sidebar Galleries
|
| 300 |
st.sidebar.header("Media Gallery 🎨")
|
| 301 |
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
|
| 302 |
-
|
| 303 |
-
if media_files:
|
| 304 |
-
cols = st.sidebar.columns(2)
|
| 305 |
-
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
| 306 |
-
with cols[idx % 2]:
|
| 307 |
-
st.image(Image.open(file), caption=file, use_column_width=True)
|
| 308 |
-
st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
|
| 309 |
|
| 310 |
st.sidebar.subheader("Model Management 🗂️")
|
| 311 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
|
@@ -350,10 +362,8 @@ with tab2:
|
|
| 350 |
filename = generate_filename(0)
|
| 351 |
with open(filename, "wb") as f:
|
| 352 |
f.write(cam0_img.getvalue())
|
| 353 |
-
st.image(Image.open(filename), caption=filename,
|
| 354 |
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
| 355 |
-
if 'captured_images' not in st.session_state:
|
| 356 |
-
st.session_state['captured_images'] = []
|
| 357 |
st.session_state['captured_images'].append(filename)
|
| 358 |
update_gallery()
|
| 359 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
|
@@ -370,7 +380,7 @@ with tab2:
|
|
| 370 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
| 371 |
update_gallery()
|
| 372 |
for frame in st.session_state['cam0_frames']:
|
| 373 |
-
st.image(Image.open(frame), caption=frame,
|
| 374 |
with cols[1]:
|
| 375 |
st.subheader("Camera 1")
|
| 376 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
|
@@ -378,10 +388,8 @@ with tab2:
|
|
| 378 |
filename = generate_filename(1)
|
| 379 |
with open(filename, "wb") as f:
|
| 380 |
f.write(cam1_img.getvalue())
|
| 381 |
-
st.image(Image.open(filename), caption=filename,
|
| 382 |
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
| 383 |
-
if 'captured_images' not in st.session_state:
|
| 384 |
-
st.session_state['captured_images'] = []
|
| 385 |
st.session_state['captured_images'].append(filename)
|
| 386 |
update_gallery()
|
| 387 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
|
@@ -398,7 +406,7 @@ with tab2:
|
|
| 398 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
| 399 |
update_gallery()
|
| 400 |
for frame in st.session_state['cam1_frames']:
|
| 401 |
-
st.image(Image.open(frame), caption=frame,
|
| 402 |
|
| 403 |
with tab3:
|
| 404 |
st.header("Fine-Tune Titan 🔧")
|
|
@@ -485,4 +493,7 @@ st.sidebar.subheader("Action Logs 📜")
|
|
| 485 |
log_container = st.sidebar.empty()
|
| 486 |
with log_container:
|
| 487 |
for record in log_records:
|
| 488 |
-
st.write(f"{record.asctime} - {record.levelname} - {record.message}")
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
}
|
| 43 |
)
|
| 44 |
|
| 45 |
+
# Initialize st.session_state
|
| 46 |
+
if 'captured_images' not in st.session_state:
|
| 47 |
+
st.session_state['captured_images'] = []
|
| 48 |
+
if 'builder' not in st.session_state:
|
| 49 |
+
st.session_state['builder'] = None
|
| 50 |
+
if 'model_loaded' not in st.session_state:
|
| 51 |
+
st.session_state['model_loaded'] = False
|
| 52 |
+
|
| 53 |
# Model Configuration Classes
|
| 54 |
@dataclass
|
| 55 |
class ModelConfig:
|
|
|
|
| 118 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 119 |
if config:
|
| 120 |
self.config = config
|
| 121 |
+
self.model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 122 |
st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
|
| 123 |
return self
|
| 124 |
def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
|
|
|
|
| 242 |
def get_gallery_files(file_types):
|
| 243 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 244 |
|
| 245 |
+
def update_gallery():
|
| 246 |
+
media_files = get_gallery_files(["png"])
|
| 247 |
+
if media_files:
|
| 248 |
+
cols = st.sidebar.columns(2)
|
| 249 |
+
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
| 250 |
+
with cols[idx % 2]:
|
| 251 |
+
st.image(Image.open(file), caption=file, use_container_width=True)
|
| 252 |
+
st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
|
| 253 |
+
|
| 254 |
# Mock Search Tool for RAG
|
| 255 |
def mock_search(query: str) -> str:
|
| 256 |
if "superhero" in query.lower():
|
|
|
|
| 317 |
# Sidebar Galleries
|
| 318 |
st.sidebar.header("Media Gallery 🎨")
|
| 319 |
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
|
| 320 |
+
update_gallery()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
st.sidebar.subheader("Model Management 🗂️")
|
| 323 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
|
|
|
| 362 |
filename = generate_filename(0)
|
| 363 |
with open(filename, "wb") as f:
|
| 364 |
f.write(cam0_img.getvalue())
|
| 365 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
| 366 |
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
|
|
|
|
|
|
| 367 |
st.session_state['captured_images'].append(filename)
|
| 368 |
update_gallery()
|
| 369 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
|
|
|
| 380 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
| 381 |
update_gallery()
|
| 382 |
for frame in st.session_state['cam0_frames']:
|
| 383 |
+
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
| 384 |
with cols[1]:
|
| 385 |
st.subheader("Camera 1")
|
| 386 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
|
|
|
| 388 |
filename = generate_filename(1)
|
| 389 |
with open(filename, "wb") as f:
|
| 390 |
f.write(cam1_img.getvalue())
|
| 391 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
| 392 |
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
|
|
|
|
|
|
| 393 |
st.session_state['captured_images'].append(filename)
|
| 394 |
update_gallery()
|
| 395 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
|
|
|
| 406 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
| 407 |
update_gallery()
|
| 408 |
for frame in st.session_state['cam1_frames']:
|
| 409 |
+
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
| 410 |
|
| 411 |
with tab3:
|
| 412 |
st.header("Fine-Tune Titan 🔧")
|
|
|
|
| 493 |
log_container = st.sidebar.empty()
|
| 494 |
with log_container:
|
| 495 |
for record in log_records:
|
| 496 |
+
st.write(f"{record.asctime} - {record.levelname} - {record.message}")
|
| 497 |
+
|
| 498 |
+
# Initial Gallery Update
|
| 499 |
+
update_gallery()
|