Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -53,10 +53,8 @@ st.set_page_config(
|
|
| 53 |
)
|
| 54 |
|
| 55 |
# Initialize st.session_state
|
| 56 |
-
if 'captured_files' not in st.session_state:
|
| 57 |
-
st.session_state['captured_files'] = {'cam0': None, 'cam1': None} # One file per camera
|
| 58 |
if 'history' not in st.session_state:
|
| 59 |
-
st.session_state['history'] =
|
| 60 |
if 'builder' not in st.session_state:
|
| 61 |
st.session_state['builder'] = None
|
| 62 |
if 'model_loaded' not in st.session_state:
|
|
@@ -329,21 +327,9 @@ def get_model_files(model_type="causal_lm"):
|
|
| 329 |
path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
|
| 330 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
| 331 |
|
| 332 |
-
def get_gallery_files(file_types):
    """Return the sorted files in the working directory matching the given extensions.

    Args:
        file_types: Iterable of extensions without the leading dot
            (e.g. ``["png", "txt"]``). A single extension may also be
            passed as a plain string.

    Returns:
        A sorted list of relative file names, each listed once.
    """
    if isinstance(file_types, str):
        # A bare string would otherwise be iterated character by character,
        # globbing "*.p", "*.n", "*.g" instead of "*.png".
        file_types = (file_types,)
    # A set drops duplicates when the same extension is listed more than once.
    matches = {path for ext in file_types for path in glob.glob(f"*.{ext}")}
    return sorted(matches)
|
| 334 |
|
| 335 |
-
def download_pdf(url, output_path):
    """Stream the resource at ``url`` into ``output_path``.

    Args:
        url: Address to fetch with an HTTP GET (10 second timeout).
        output_path: File path the response body is written to, in binary mode.

    Returns:
        True when the server answered 200 and the body was written;
        False on any other status code or on a ``requests`` error.
    """
    try:
        # The context manager releases the streamed connection on every path.
        with requests.get(url, stream=True, timeout=10) as response:
            if response.status_code != 200:
                # Report non-200 answers explicitly instead of falling through.
                logger.error(f"Failed to download {url}: HTTP {response.status_code}")
                return False
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return True
    except requests.RequestException as e:
        logger.error(f"Failed to download {url}: {e}")
        return False
|
| 346 |
-
|
| 347 |
# Mock Search Tool for RAG
|
| 348 |
def mock_search(query: str) -> str:
|
| 349 |
if "superhero" in query.lower():
|
|
@@ -445,9 +431,6 @@ async def process_pdf_snapshot(pdf_path, mode="thumbnail"):
|
|
| 445 |
output_files.append(output_file)
|
| 446 |
elapsed = int(time.time() - start_time)
|
| 447 |
status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
|
| 448 |
-
for file in output_files:
|
| 449 |
-
if file not in st.session_state['captured_files'].values():
|
| 450 |
-
st.session_state['captured_files'][f"pdf_{len(output_files)}"] = file
|
| 451 |
update_gallery()
|
| 452 |
return output_files
|
| 453 |
except Exception as e:
|
|
@@ -465,8 +448,6 @@ async def process_ocr(image, output_file):
|
|
| 465 |
status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
|
| 466 |
async with aiofiles.open(output_file, "w") as f:
|
| 467 |
await f.write(result)
|
| 468 |
-
if output_file not in st.session_state['captured_files'].values():
|
| 469 |
-
st.session_state['captured_files']['ocr'] = output_file
|
| 470 |
update_gallery()
|
| 471 |
return result
|
| 472 |
|
|
@@ -479,8 +460,6 @@ async def process_image_gen(prompt, output_file):
|
|
| 479 |
elapsed = int(time.time() - start_time)
|
| 480 |
status.text(f"Image Gen completed in {elapsed}s!")
|
| 481 |
gen_image.save(output_file)
|
| 482 |
-
if output_file not in st.session_state['captured_files'].values():
|
| 483 |
-
st.session_state['captured_files']['gen'] = output_file
|
| 484 |
update_gallery()
|
| 485 |
return gen_image
|
| 486 |
|
|
@@ -496,8 +475,6 @@ async def process_custom_diffusion(images, output_file, model_name):
|
|
| 496 |
elapsed = int(time.time() - start_time)
|
| 497 |
status.text(f"{model_name} completed in {elapsed}s!")
|
| 498 |
upscaled_image.save(output_file)
|
| 499 |
-
if output_file not in st.session_state['captured_files'].values():
|
| 500 |
-
st.session_state['captured_files']['diffusion'] = output_file
|
| 501 |
update_gallery()
|
| 502 |
return upscaled_image
|
| 503 |
|
|
@@ -506,18 +483,14 @@ st.title("AI Vision & SFT Titans 🚀")
|
|
| 506 |
|
| 507 |
# Sidebar
|
| 508 |
st.sidebar.header("Captured Files 📜")
|
| 509 |
-
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2) # Default to 2
|
| 510 |
def update_gallery():
|
| 511 |
-
media_files = [
|
| 512 |
-
|
| 513 |
-
if valid_files:
|
| 514 |
cols = st.sidebar.columns(2)
|
| 515 |
-
|
| 516 |
-
with cols[
|
| 517 |
-
st.image(Image.open(
|
| 518 |
-
if st.session_state['captured_files']['cam1'] in valid_files:
|
| 519 |
-
with cols[1]:
|
| 520 |
-
st.image(Image.open(st.session_state['captured_files']['cam1']), caption="Camera 1", use_container_width=True)
|
| 521 |
update_gallery()
|
| 522 |
|
| 523 |
st.sidebar.subheader("Model Management 🗂️")
|
|
@@ -541,8 +514,7 @@ with log_container:
|
|
| 541 |
st.sidebar.subheader("History 📜")
|
| 542 |
history_container = st.sidebar.empty()
|
| 543 |
with history_container:
|
| 544 |
-
|
| 545 |
-
for entry in [e for e in valid_history if e]: # Show only non-None entries
|
| 546 |
st.write(entry)
|
| 547 |
|
| 548 |
# Tabs
|
|
@@ -561,8 +533,9 @@ with tab1:
|
|
| 561 |
filename = generate_filename("cam0")
|
| 562 |
with open(filename, "wb") as f:
|
| 563 |
f.write(cam0_img.getvalue())
|
| 564 |
-
|
| 565 |
-
st.session_state['history']
|
|
|
|
| 566 |
st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
|
| 567 |
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
| 568 |
update_gallery()
|
|
@@ -572,8 +545,9 @@ with tab1:
|
|
| 572 |
filename = generate_filename("cam1")
|
| 573 |
with open(filename, "wb") as f:
|
| 574 |
f.write(cam1_img.getvalue())
|
| 575 |
-
|
| 576 |
-
st.session_state['history']
|
|
|
|
| 577 |
st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
|
| 578 |
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
| 579 |
update_gallery()
|
|
@@ -589,7 +563,9 @@ with tab2:
|
|
| 589 |
pdf_path = generate_filename("downloaded", "pdf")
|
| 590 |
if download_pdf(url, pdf_path):
|
| 591 |
logger.info(f"Downloaded PDF from {url} to {pdf_path}")
|
| 592 |
-
|
|
|
|
|
|
|
| 593 |
snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
|
| 594 |
for snapshot in snapshots:
|
| 595 |
st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
|
|
@@ -611,7 +587,9 @@ with tab3:
|
|
| 611 |
builder.save_model(config.model_path)
|
| 612 |
st.session_state['builder'] = builder
|
| 613 |
st.session_state['model_loaded'] = True
|
| 614 |
-
|
|
|
|
|
|
|
| 615 |
st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
|
| 616 |
st.rerun()
|
| 617 |
|
|
@@ -646,13 +624,15 @@ with tab4:
|
|
| 646 |
st.session_state['builder'].save_model(new_config.model_path)
|
| 647 |
zip_path = f"{new_config.model_path}.zip"
|
| 648 |
zip_directory(new_config.model_path, zip_path)
|
| 649 |
-
|
|
|
|
|
|
|
| 650 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
|
| 651 |
st.rerun()
|
| 652 |
elif isinstance(st.session_state['builder'], DiffusionBuilder):
|
| 653 |
-
captured_files =
|
| 654 |
if len(captured_files) >= 2:
|
| 655 |
-
demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_files
|
| 656 |
edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
|
| 657 |
if st.button("Fine-Tune with Dataset 🔄"):
|
| 658 |
images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
|
|
@@ -664,12 +644,14 @@ with tab4:
|
|
| 664 |
st.session_state['builder'].save_model(new_config.model_path)
|
| 665 |
zip_path = f"{new_config.model_path}.zip"
|
| 666 |
zip_directory(new_config.model_path, zip_path)
|
| 667 |
-
|
|
|
|
|
|
|
| 668 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
|
| 669 |
csv_path = f"sft_dataset_{int(time.time())}.csv"
|
| 670 |
with open(csv_path, "w", newline="") as f:
|
| 671 |
writer = csv.writer(f)
|
| 672 |
-
writer.writerow(["image", "text
|
| 673 |
for _, row in edited_data.iterrows():
|
| 674 |
writer.writerow([row["image"], row["text"]])
|
| 675 |
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
|
@@ -696,7 +678,9 @@ with tab5:
|
|
| 696 |
if st.button("Run Test ▶️"):
|
| 697 |
status_container = st.empty()
|
| 698 |
result = st.session_state['builder'].evaluate(test_prompt, status_container)
|
| 699 |
-
|
|
|
|
|
|
|
| 700 |
st.write(f"**Generated Response**: {result}")
|
| 701 |
status_container.empty()
|
| 702 |
elif isinstance(st.session_state['builder'], DiffusionBuilder):
|
|
@@ -705,8 +689,9 @@ with tab5:
|
|
| 705 |
image = st.session_state['builder'].generate(test_prompt)
|
| 706 |
output_file = generate_filename("diffusion_test", "png")
|
| 707 |
image.save(output_file)
|
| 708 |
-
|
| 709 |
-
st.session_state['history']
|
|
|
|
| 710 |
st.image(image, caption="Generated Image")
|
| 711 |
update_gallery()
|
| 712 |
|
|
@@ -720,28 +705,31 @@ with tab6:
|
|
| 720 |
agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
|
| 721 |
task = "Plan a luxury superhero-themed party at Wayne Manor."
|
| 722 |
plan_df = agent.plan_party(task)
|
| 723 |
-
|
|
|
|
|
|
|
| 724 |
st.dataframe(plan_df)
|
| 725 |
elif isinstance(st.session_state['builder'], DiffusionBuilder):
|
| 726 |
if st.button("Run CV RAG Demo 🎉"):
|
| 727 |
agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
|
| 728 |
task = "Generate images for a luxury superhero-themed party."
|
| 729 |
plan_df = agent.plan_party(task)
|
| 730 |
-
|
|
|
|
|
|
|
| 731 |
st.dataframe(plan_df)
|
| 732 |
for _, row in plan_df.iterrows():
|
| 733 |
image = agent.generate(row["Image Idea"])
|
| 734 |
output_file = generate_filename(f"cv_rag_{row['Theme'].lower()}", "png")
|
| 735 |
image.save(output_file)
|
| 736 |
-
st.session_state['captured_files'][f"cv_rag_{row['Theme'].lower()}"] = output_file
|
| 737 |
st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
|
| 738 |
update_gallery()
|
| 739 |
|
| 740 |
with tab7:
|
| 741 |
st.header("Test OCR 🔍")
|
| 742 |
-
captured_files =
|
| 743 |
if captured_files:
|
| 744 |
-
selected_file = st.selectbox("Select Image",
|
| 745 |
if selected_file:
|
| 746 |
image = Image.open(selected_file)
|
| 747 |
st.image(image, caption="Input Image", use_container_width=True)
|
|
@@ -749,7 +737,9 @@ with tab7:
|
|
| 749 |
output_file = generate_filename("ocr_output", "txt")
|
| 750 |
st.session_state['processing']['ocr'] = True
|
| 751 |
result = asyncio.run(process_ocr(image, output_file))
|
| 752 |
-
|
|
|
|
|
|
|
| 753 |
st.text_area("OCR Result", result, height=200, key="ocr_result")
|
| 754 |
st.success(f"OCR output saved to {output_file}")
|
| 755 |
st.session_state['processing']['ocr'] = False
|
|
@@ -758,9 +748,9 @@ with tab7:
|
|
| 758 |
|
| 759 |
with tab8:
|
| 760 |
st.header("Test Image Gen 🎨")
|
| 761 |
-
captured_files =
|
| 762 |
if captured_files:
|
| 763 |
-
selected_file = st.selectbox("Select Image",
|
| 764 |
if selected_file:
|
| 765 |
image = Image.open(selected_file)
|
| 766 |
st.image(image, caption="Reference Image", use_container_width=True)
|
|
@@ -769,7 +759,9 @@ with tab8:
|
|
| 769 |
output_file = generate_filename("gen_output", "png")
|
| 770 |
st.session_state['processing']['gen'] = True
|
| 771 |
result = asyncio.run(process_image_gen(prompt, output_file))
|
| 772 |
-
|
|
|
|
|
|
|
| 773 |
st.image(result, caption="Generated Image", use_container_width=True)
|
| 774 |
st.success(f"Image saved to {output_file}")
|
| 775 |
st.session_state['processing']['gen'] = False
|
|
@@ -779,10 +771,10 @@ with tab8:
|
|
| 779 |
with tab9:
|
| 780 |
st.header("Custom Diffusion 🎨🤓")
|
| 781 |
st.write("Unleash your inner artist with our tiny diffusion models!")
|
| 782 |
-
captured_files =
|
| 783 |
if captured_files:
|
| 784 |
st.subheader("Select Images to Train")
|
| 785 |
-
selected_files = st.multiselect("Pick Images",
|
| 786 |
images = [Image.open(file) for file in selected_files]
|
| 787 |
|
| 788 |
model_options = [
|
|
@@ -803,8 +795,9 @@ with tab9:
|
|
| 803 |
builder.load_model(model_name)
|
| 804 |
result = builder.generate("A superhero scene inspired by captured images")
|
| 805 |
result.save(output_file)
|
| 806 |
-
|
| 807 |
-
st.session_state['history']
|
|
|
|
| 808 |
st.image(result, caption=f"{model_choice} Masterpiece", use_container_width=True)
|
| 809 |
st.success(f"Image saved to {output_file}")
|
| 810 |
st.session_state['processing']['diffusion'] = False
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
# Initialize st.session_state
|
|
|
|
|
|
|
| 56 |
if 'history' not in st.session_state:
|
| 57 |
+
st.session_state['history'] = [] # Flat list for history
|
| 58 |
if 'builder' not in st.session_state:
|
| 59 |
st.session_state['builder'] = None
|
| 60 |
if 'model_loaded' not in st.session_state:
|
|
|
|
| 327 |
path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
|
| 328 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
| 329 |
|
| 330 |
+
def get_gallery_files(file_types=("png",)):
    """Return the sorted files in the working directory matching the given extensions.

    Args:
        file_types: Iterable of extensions without the leading dot
            (e.g. ``["png", "txt"]``). A single extension may also be
            passed as a plain string. Defaults to PNG files only.

    Returns:
        A sorted list of relative file names, each listed once.
    """
    if isinstance(file_types, str):
        # A bare string would otherwise be iterated character by character,
        # globbing "*.p", "*.n", "*.g" instead of "*.png".
        file_types = (file_types,)
    # A set drops duplicates when the same extension is listed more than once.
    matches = {path for ext in file_types for path in glob.glob(f"*.{ext}")}
    return sorted(matches)
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
# Mock Search Tool for RAG
|
| 334 |
def mock_search(query: str) -> str:
|
| 335 |
if "superhero" in query.lower():
|
|
|
|
| 431 |
output_files.append(output_file)
|
| 432 |
elapsed = int(time.time() - start_time)
|
| 433 |
status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
|
|
|
|
|
|
|
|
|
|
| 434 |
update_gallery()
|
| 435 |
return output_files
|
| 436 |
except Exception as e:
|
|
|
|
| 448 |
status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
|
| 449 |
async with aiofiles.open(output_file, "w") as f:
|
| 450 |
await f.write(result)
|
|
|
|
|
|
|
| 451 |
update_gallery()
|
| 452 |
return result
|
| 453 |
|
|
|
|
| 460 |
elapsed = int(time.time() - start_time)
|
| 461 |
status.text(f"Image Gen completed in {elapsed}s!")
|
| 462 |
gen_image.save(output_file)
|
|
|
|
|
|
|
| 463 |
update_gallery()
|
| 464 |
return gen_image
|
| 465 |
|
|
|
|
| 475 |
elapsed = int(time.time() - start_time)
|
| 476 |
status.text(f"{model_name} completed in {elapsed}s!")
|
| 477 |
upscaled_image.save(output_file)
|
|
|
|
|
|
|
| 478 |
update_gallery()
|
| 479 |
return upscaled_image
|
| 480 |
|
|
|
|
| 483 |
|
| 484 |
# Sidebar
|
| 485 |
st.sidebar.header("Captured Files 📜")
|
| 486 |
+
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2) # Default to 2
|
| 487 |
def update_gallery():
    """Render the PNG files found by ``get_gallery_files`` in a two-column sidebar grid.

    At most ``gallery_size * 2`` images are shown; nothing is drawn when
    no PNG files are found.
    """
    png_paths = get_gallery_files(["png"])
    if not png_paths:
        return
    cols = st.sidebar.columns(2)
    shown = png_paths[:gallery_size * 2]  # Limit by gallery size
    for position, path in enumerate(shown):
        # Alternate between the left and right column.
        target = cols[position % 2]
        with target:
            st.image(Image.open(path), caption=os.path.basename(path), use_container_width=True)
|
|
|
|
|
|
|
|
|
|
| 494 |
update_gallery()
|
| 495 |
|
| 496 |
st.sidebar.subheader("Model Management 🗂️")
|
|
|
|
| 514 |
st.sidebar.subheader("History 📜")
|
| 515 |
history_container = st.sidebar.empty()
|
| 516 |
with history_container:
|
| 517 |
+
for entry in st.session_state['history'][-gallery_size * 2:]: # Limit by gallery size
|
|
|
|
| 518 |
st.write(entry)
|
| 519 |
|
| 520 |
# Tabs
|
|
|
|
| 533 |
filename = generate_filename("cam0")
|
| 534 |
with open(filename, "wb") as f:
|
| 535 |
f.write(cam0_img.getvalue())
|
| 536 |
+
entry = f"Snapshot from Cam 0: {filename}"
|
| 537 |
+
if entry not in st.session_state['history']:
|
| 538 |
+
st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 0:")] + [entry]
|
| 539 |
st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
|
| 540 |
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
| 541 |
update_gallery()
|
|
|
|
| 545 |
filename = generate_filename("cam1")
|
| 546 |
with open(filename, "wb") as f:
|
| 547 |
f.write(cam1_img.getvalue())
|
| 548 |
+
entry = f"Snapshot from Cam 1: {filename}"
|
| 549 |
+
if entry not in st.session_state['history']:
|
| 550 |
+
st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 1:")] + [entry]
|
| 551 |
st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
|
| 552 |
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
| 553 |
update_gallery()
|
|
|
|
| 563 |
pdf_path = generate_filename("downloaded", "pdf")
|
| 564 |
if download_pdf(url, pdf_path):
|
| 565 |
logger.info(f"Downloaded PDF from {url} to {pdf_path}")
|
| 566 |
+
entry = f"Downloaded PDF: {pdf_path}"
|
| 567 |
+
if entry not in st.session_state['history']:
|
| 568 |
+
st.session_state['history'].append(entry)
|
| 569 |
snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
|
| 570 |
for snapshot in snapshots:
|
| 571 |
st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
|
|
|
|
| 587 |
builder.save_model(config.model_path)
|
| 588 |
st.session_state['builder'] = builder
|
| 589 |
st.session_state['model_loaded'] = True
|
| 590 |
+
entry = f"Built {model_type} model: {model_name}"
|
| 591 |
+
if entry not in st.session_state['history']:
|
| 592 |
+
st.session_state['history'].append(entry)
|
| 593 |
st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
|
| 594 |
st.rerun()
|
| 595 |
|
|
|
|
| 624 |
st.session_state['builder'].save_model(new_config.model_path)
|
| 625 |
zip_path = f"{new_config.model_path}.zip"
|
| 626 |
zip_directory(new_config.model_path, zip_path)
|
| 627 |
+
entry = f"Fine-tuned Causal LM: {new_model_name}"
|
| 628 |
+
if entry not in st.session_state['history']:
|
| 629 |
+
st.session_state['history'].append(entry)
|
| 630 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
|
| 631 |
st.rerun()
|
| 632 |
elif isinstance(st.session_state['builder'], DiffusionBuilder):
|
| 633 |
+
captured_files = get_gallery_files(["png"])
|
| 634 |
if len(captured_files) >= 2:
|
| 635 |
+
demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_files]
|
| 636 |
edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
|
| 637 |
if st.button("Fine-Tune with Dataset 🔄"):
|
| 638 |
images = [Image.open(row["image"]) for _, row in edited_data.iterrows()]
|
|
|
|
| 644 |
st.session_state['builder'].save_model(new_config.model_path)
|
| 645 |
zip_path = f"{new_config.model_path}.zip"
|
| 646 |
zip_directory(new_config.model_path, zip_path)
|
| 647 |
+
entry = f"Fine-tuned Diffusion: {new_model_name}"
|
| 648 |
+
if entry not in st.session_state['history']:
|
| 649 |
+
st.session_state['history'].append(entry)
|
| 650 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
|
| 651 |
csv_path = f"sft_dataset_{int(time.time())}.csv"
|
| 652 |
# Export the edited dataset as a two-column CSV file.
with open(csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    # Header names match the two fields written for each row below.
    writer.writerow(["image", "text"])
    for _, row in edited_data.iterrows():
        writer.writerow([row["image"], row["text"]])
|
| 657 |
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
|
|
|
| 678 |
if st.button("Run Test ▶️"):
|
| 679 |
status_container = st.empty()
|
| 680 |
result = st.session_state['builder'].evaluate(test_prompt, status_container)
|
| 681 |
+
entry = f"Causal LM Test: {test_prompt} -> {result}"
|
| 682 |
+
if entry not in st.session_state['history']:
|
| 683 |
+
st.session_state['history'].append(entry)
|
| 684 |
st.write(f"**Generated Response**: {result}")
|
| 685 |
status_container.empty()
|
| 686 |
elif isinstance(st.session_state['builder'], DiffusionBuilder):
|
|
|
|
| 689 |
image = st.session_state['builder'].generate(test_prompt)
|
| 690 |
output_file = generate_filename("diffusion_test", "png")
|
| 691 |
image.save(output_file)
|
| 692 |
+
entry = f"Diffusion Test: {test_prompt} -> {output_file}"
|
| 693 |
+
if entry not in st.session_state['history']:
|
| 694 |
+
st.session_state['history'].append(entry)
|
| 695 |
st.image(image, caption="Generated Image")
|
| 696 |
update_gallery()
|
| 697 |
|
|
|
|
| 705 |
agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
|
| 706 |
task = "Plan a luxury superhero-themed party at Wayne Manor."
|
| 707 |
plan_df = agent.plan_party(task)
|
| 708 |
+
entry = f"NLP RAG Demo: Planned party at Wayne Manor"
|
| 709 |
+
if entry not in st.session_state['history']:
|
| 710 |
+
st.session_state['history'].append(entry)
|
| 711 |
st.dataframe(plan_df)
|
| 712 |
elif isinstance(st.session_state['builder'], DiffusionBuilder):
|
| 713 |
if st.button("Run CV RAG Demo 🎉"):
|
| 714 |
agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
|
| 715 |
task = "Generate images for a luxury superhero-themed party."
|
| 716 |
plan_df = agent.plan_party(task)
|
| 717 |
+
entry = f"CV RAG Demo: Generated party images"
|
| 718 |
+
if entry not in st.session_state['history']:
|
| 719 |
+
st.session_state['history'].append(entry)
|
| 720 |
st.dataframe(plan_df)
|
| 721 |
for _, row in plan_df.iterrows():
|
| 722 |
image = agent.generate(row["Image Idea"])
|
| 723 |
output_file = generate_filename(f"cv_rag_{row['Theme'].lower()}", "png")
|
| 724 |
image.save(output_file)
|
|
|
|
| 725 |
st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
|
| 726 |
update_gallery()
|
| 727 |
|
| 728 |
with tab7:
|
| 729 |
st.header("Test OCR 🔍")
|
| 730 |
+
captured_files = get_gallery_files(["png"])
|
| 731 |
if captured_files:
|
| 732 |
+
selected_file = st.selectbox("Select Image", captured_files, key="ocr_select")
|
| 733 |
if selected_file:
|
| 734 |
image = Image.open(selected_file)
|
| 735 |
st.image(image, caption="Input Image", use_container_width=True)
|
|
|
|
| 737 |
output_file = generate_filename("ocr_output", "txt")
|
| 738 |
st.session_state['processing']['ocr'] = True
|
| 739 |
result = asyncio.run(process_ocr(image, output_file))
|
| 740 |
+
entry = f"OCR Test: {selected_file} -> {output_file}"
|
| 741 |
+
if entry not in st.session_state['history']:
|
| 742 |
+
st.session_state['history'].append(entry)
|
| 743 |
st.text_area("OCR Result", result, height=200, key="ocr_result")
|
| 744 |
st.success(f"OCR output saved to {output_file}")
|
| 745 |
st.session_state['processing']['ocr'] = False
|
|
|
|
| 748 |
|
| 749 |
with tab8:
|
| 750 |
st.header("Test Image Gen 🎨")
|
| 751 |
+
captured_files = get_gallery_files(["png"])
|
| 752 |
if captured_files:
|
| 753 |
+
selected_file = st.selectbox("Select Image", captured_files, key="gen_select")
|
| 754 |
if selected_file:
|
| 755 |
image = Image.open(selected_file)
|
| 756 |
st.image(image, caption="Reference Image", use_container_width=True)
|
|
|
|
| 759 |
output_file = generate_filename("gen_output", "png")
|
| 760 |
st.session_state['processing']['gen'] = True
|
| 761 |
result = asyncio.run(process_image_gen(prompt, output_file))
|
| 762 |
+
entry = f"Image Gen Test: {prompt} -> {output_file}"
|
| 763 |
+
if entry not in st.session_state['history']:
|
| 764 |
+
st.session_state['history'].append(entry)
|
| 765 |
st.image(result, caption="Generated Image", use_container_width=True)
|
| 766 |
st.success(f"Image saved to {output_file}")
|
| 767 |
st.session_state['processing']['gen'] = False
|
|
|
|
| 771 |
with tab9:
|
| 772 |
st.header("Custom Diffusion 🎨🤓")
|
| 773 |
st.write("Unleash your inner artist with our tiny diffusion models!")
|
| 774 |
+
captured_files = get_gallery_files(["png"])
|
| 775 |
if captured_files:
|
| 776 |
st.subheader("Select Images to Train")
|
| 777 |
+
selected_files = st.multiselect("Pick Images", captured_files, key="diffusion_select")
|
| 778 |
images = [Image.open(file) for file in selected_files]
|
| 779 |
|
| 780 |
model_options = [
|
|
|
|
| 795 |
builder.load_model(model_name)
|
| 796 |
result = builder.generate("A superhero scene inspired by captured images")
|
| 797 |
result.save(output_file)
|
| 798 |
+
entry = f"Custom Diffusion: {model_choice} -> {output_file}"
|
| 799 |
+
if entry not in st.session_state['history']:
|
| 800 |
+
st.session_state['history'].append(entry)
|
| 801 |
st.image(result, caption=f"{model_choice} Masterpiece", use_container_width=True)
|
| 802 |
st.success(f"Image saved to {output_file}")
|
| 803 |
st.session_state['processing']['diffusion'] = False
|