Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
#!/usr/bin/env python3 | |
# What is working: | |
# Img Gen, PDF Download, | |
# Next: Get multiple PDF upload 2 workflow pages by image through image fly wheel of AI. | |
import os | |
import glob | |
import base64 | |
import time | |
import shutil | |
import streamlit as st | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel | |
from diffusers import StableDiffusionPipeline | |
from torch.utils.data import Dataset, DataLoader | |
import csv | |
import fitz | |
import requests | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
import logging | |
import asyncio | |
import aiofiles | |
from io import BytesIO | |
from dataclasses import dataclass | |
from typing import Optional, Tuple | |
import zipfile | |
import math | |
import random | |
import re | |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Records captured during the current script run; the sidebar "Action Logs" panel renders them.
log_records = []


class LogCaptureHandler(logging.Handler):
    """Logging handler that appends every record it receives to ``log_records``.

    Records are formatted on capture, so ``record.asctime`` and ``record.message``
    exist even when no other handler has formatted the record yet.
    """

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)
        self.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))

    def emit(self, record):
        try:
            # Formatter.format() populates record.message and record.asctime as a side effect.
            self.format(record)
        except Exception:
            self.handleError(record)
            return
        log_records.append(record)


# Streamlit re-executes this script on every interaction, but loggers live for the whole
# process. Drop capture handlers installed by earlier runs so they do not accumulate.
for _stale_handler in [h for h in logger.handlers if type(h).__name__ == "LogCaptureHandler"]:
    logger.removeHandler(_stale_handler)
logger.addHandler(LogCaptureHandler())
# Browser-tab title/icon, wide layout, and the links shown in Streamlit's app menu.
_MENU_ITEMS = {
    'Get Help': 'https://huggingface.co/awacke1',
    'Report a Bug': 'https://huggingface.co/spaces/awacke1',
    'About': "AI Vision & SFT Titans: PDFs, OCR, Image Gen, Line Drawings, Custom Diffusion, and SFT on CPU! 🌌",
}
st.set_page_config(
    page_title="AI Vision & SFT Titans 🚀",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items=_MENU_ITEMS,
)
# Seed per-session state on the first run only; values already present survive reruns.
_SESSION_DEFAULTS = {
    'history': [],
    'builder': None,
    'model_loaded': False,
    'processing': {},
    'asset_checkboxes': {},
    'downloaded_pdfs': {},
    'unique_counter': 0,
    'selected_model_type': "Causal LM",
    'selected_model': "None",
    'cam0_file': None,
    'cam1_file': None,
}
for _state_key, _default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
@dataclass
class ModelConfig:
    """Metadata for a causal-LM checkpoint built or loaded by the app.

    Call sites construct it by keyword, e.g.
    ``ModelConfig(name=..., base_model=..., size=..., domain=...)``, and read
    ``config.model_path`` as an attribute. That requires ``@dataclass`` (for the
    generated ``__init__``) and ``@property``.
    """

    name: str  # directory name under models/
    base_model: str  # hub id the model came from; the sidebar loader passes "unknown"
    size: str
    domain: Optional[str] = None
    model_type: str = "causal_lm"

    @property
    def model_path(self):
        """Local directory where this model is saved: ``models/<name>``."""
        return f"models/{self.name}"
@dataclass
class DiffusionConfig:
    """Metadata for a Stable Diffusion pipeline built or loaded by the app.

    Constructed by keyword at the call sites and ``model_path`` is read as an
    attribute, hence ``@dataclass`` and ``@property``.
    """

    name: str  # directory name under diffusion_models/
    base_model: str  # hub id the pipeline came from; the sidebar loader passes "unknown"
    size: str
    domain: Optional[str] = None

    @property
    def model_path(self):
        """Local directory where this pipeline is saved: ``diffusion_models/<name>``."""
        return f"diffusion_models/{self.name}"
class ModelBuilder:
    """Hold a causal language model and its tokenizer; load/save with Streamlit feedback."""

    def __init__(self):
        self.config = None  # ModelConfig of the loaded model, when one was supplied
        self.model = None
        self.tokenizer = None
        self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]

    def load_model(self, model_path: str, config: Optional["ModelConfig"] = None):
        """Load model and tokenizer from *model_path* (hub id or local directory).

        The model is moved to CUDA when available, otherwise to CPU. A truthy
        *config* replaces ``self.config``. Returns ``self`` for chaining.
        """
        with st.spinner(f"Loading {model_path}... ⏳"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            if self.tokenizer.pad_token is None:
                # Tokenizers without a pad token reuse EOS so padding has a token to use.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if config:
                self.config = config
            self.model.to("cuda" if torch.cuda.is_available() else "cpu")
            st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
        return self

    def save_model(self, path: str):
        """Save model and tokenizer into directory *path*, creating it if needed.

        Raises:
            ValueError: if no model/tokenizer has been loaded yet.
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("No model loaded; call load_model() before save_model().")
        with st.spinner("Saving model... 💾"):
            # Create *path* itself. os.path.dirname() is "" for a bare directory name,
            # and os.makedirs("") raises FileNotFoundError.
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
            st.success(f"Model saved at {path}! ✅")
class DiffusionBuilder:
    """Hold a Stable Diffusion pipeline (float32, CPU); load/save/generate with Streamlit feedback."""

    def __init__(self):
        self.config = None  # DiffusionConfig of the loaded pipeline, when one was supplied
        self.pipeline = None

    def load_model(self, model_path: str, config: Optional["DiffusionConfig"] = None):
        """Load a pipeline from *model_path* onto CPU; a truthy *config* replaces ``self.config``.

        Returns ``self`` for chaining.
        """
        with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
            self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
            if config:
                self.config = config
            st.success("Diffusion model loaded! 🎨")
        return self

    def save_model(self, path: str):
        """Save the pipeline into directory *path*, creating it if needed.

        Raises:
            ValueError: if no pipeline has been loaded yet.
        """
        if self.pipeline is None:
            raise ValueError("No diffusion pipeline loaded; call load_model() before save_model().")
        with st.spinner("Saving diffusion model... 💾"):
            # Create *path* itself; os.makedirs(os.path.dirname("name")) would be makedirs("") and raise.
            os.makedirs(path, exist_ok=True)
            self.pipeline.save_pretrained(path)
            st.success(f"Diffusion model saved at {path}! ✅")

    def generate(self, prompt: str, num_inference_steps: int = 20):
        """Generate one image for *prompt* and return the first image of the pipeline output.

        Raises:
            ValueError: if no pipeline has been loaded yet.
        """
        if self.pipeline is None:
            raise ValueError("No diffusion pipeline loaded; call load_model() first.")
        return self.pipeline(prompt, num_inference_steps=num_inference_steps).images[0]
def generate_filename(sequence, ext="png"):
    """Return a fresh file name ``<sequence>_<DDMMYYYYHHMMSS>.<ext>`` in the working directory.

    The timestamp has one-second resolution. Several files with the same *sequence*
    created within one second (e.g. "single" snapshots of several PDFs in one batch)
    would otherwise get the same name and overwrite each other. When the name is
    already taken on disk, ``_1``, ``_2``, ... is appended before the extension.
    """
    timestamp = time.strftime("%d%m%Y%H%M%S")
    stem = f"{sequence}_{timestamp}"
    candidate = f"{stem}.{ext}"
    suffix = 1
    while os.path.exists(candidate):
        candidate = f"{stem}_{suffix}.{ext}"
        suffix += 1
    return candidate
def pdf_url_to_filename(url):
    """Map a PDF URL to a local file name (``<sanitized url>.pdf``).

    Surrounding whitespace (e.g. a pasted trailing ``\\r``) is stripped. Characters
    that are illegal in Windows file names (``<>:"/\\|?*`` and ASCII control
    characters) are replaced with ``_``.
    """
    safe_name = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', url.strip())
    return f"{safe_name}.pdf"
def get_download_link(file_path, mime_type="application/pdf", label="Download"):
    """Return an HTML ``<a>`` tag that downloads *file_path* via an inline base64 data URI.

    The entire file is read and embedded in the returned string; the suggested
    download name is the file's basename.
    """
    with open(file_path, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode()
    download_name = os.path.basename(file_path)
    href = f"data:{mime_type};base64,{encoded}"
    return f'<a href="{href}" download="{download_name}">{label}</a>'
def zip_directory(directory_path, zip_path):
    """Write every file under *directory_path* (recursively) into a deflated zip at *zip_path*.

    Archive names are relative to the parent of *directory_path*, so entries are
    stored under the directory's own name as the top-level folder.
    """
    archive_root = os.path.dirname(directory_path)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for folder, _subdirs, names in os.walk(directory_path):
            for name in names:
                full_path = os.path.join(folder, name)
                archive.write(full_path, os.path.relpath(full_path, archive_root))
def get_model_files(model_type="causal_lm"):
    """List saved model directories for one model family.

    *model_type* may be the internal id ("causal_lm") or the UI label
    ("Causal LM"); any other value (e.g. "Diffusion") selects ``diffusion_models/``.
    The sidebar passes the UI label, which previously never matched "causal_lm".

    Returns:
        Directory paths sorted by name, or ``["None"]`` when there are none, so
        selectboxes always have an option.
    """
    normalized = str(model_type).strip().lower().replace(" ", "_")
    pattern = "models/*" if normalized == "causal_lm" else "diffusion_models/*"
    dirs = sorted(d for d in glob.glob(pattern) if os.path.isdir(d))
    return dirs if dirs else ["None"]
def get_gallery_files(file_types=("png", "pdf")):
    """Return the sorted, de-duplicated gallery assets in the working directory.

    Args:
        file_types: iterable of extensions without the leading dot. The default
            is an immutable tuple, which avoids the shared-mutable-default pitfall.
    """
    return sorted({path for ext in file_types for path in glob.glob(f"*.{ext}")})
def get_pdf_files():
    """Return the names of all ``*.pdf`` files in the working directory, in sorted order."""
    pdf_names = glob.glob("*.pdf")
    pdf_names.sort()
    return pdf_names
def download_pdf(url, output_path):
    """Stream *url* into *output_path*; return True on success, False otherwise.

    The body is written to ``<output_path>.part`` and renamed onto *output_path*
    only once the transfer completes. A failed or interrupted download therefore
    never leaves a truncated PDF behind, which the Robo-Download loop would later
    skip as "already got". Failures (non-200 status, network errors, local I/O
    errors) are logged.
    """
    tmp_path = f"{output_path}.part"
    try:
        # Using the response as a context manager releases the streamed connection on every path.
        with requests.get(url, stream=True, timeout=10) as response:
            if response.status_code != 200:
                logger.error(f"Failed to download {url}: HTTP {response.status_code}")
                return False
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_path, output_path)
        return True
    except (requests.RequestException, OSError) as e:
        logger.error(f"Failed to download {url}: {e}")
        return False
    finally:
        # Only present if something failed after the temporary file was created.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def _snapshot_targets(mode, page_count):
    """Return ``(page_index, filename_prefix)`` pairs to render for a snapshot *mode*."""
    if mode == "single":
        return [(0, "single")]
    if mode == "twopage":
        return [(i, f"twopage_{i}") for i in range(min(2, page_count))]
    if mode == "allpages":
        return [(i, f"page_{i}") for i in range(page_count)]
    return []


async def process_pdf_snapshot(pdf_path, mode="single"):
    """Render pages of *pdf_path* to PNG files at 2x zoom and return their file names.

    Args:
        pdf_path: PDF file to render.
        mode: "single" (first page), "twopage" (first two pages) or "allpages"
            (every page); any other value renders nothing.

    Returns:
        PNG file names written to the working directory, in page order. On failure
        the error is shown in a Streamlit status placeholder, logged, and an empty
        list is returned.
    """
    start_time = time.time()
    status = st.empty()
    status.text(f"Processing PDF Snapshot ({mode})... (0s)")
    try:
        output_files = []
        # The context manager closes the document even if rendering raises part-way.
        with fitz.open(pdf_path) as doc:
            for page_index, prefix in _snapshot_targets(mode, len(doc)):
                pix = doc[page_index].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                output_file = generate_filename(prefix, "png")
                pix.save(output_file)
                output_files.append(output_file)
        elapsed = int(time.time() - start_time)
        status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
        update_gallery()
        return output_files
    except Exception as e:
        logger.error(f"Failed to snapshot {pdf_path}: {e}")
        status.error(f"Failed to process PDF: {str(e)}")
        return []
@st.cache_resource(show_spinner=False)
def _load_got_ocr():
    """Load the GOT-OCR2_0 tokenizer and model once; Streamlit reuses them across calls and reruns."""
    tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
    # NOTE(review): trust_remote_code executes model code downloaded from the Hub.
    model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
    return tokenizer, model


async def process_ocr(image, output_file):
    """Run GOT-OCR2_0 on a PIL *image*, save the text to *output_file* (UTF-8), and return it.

    Progress is shown in a Streamlit placeholder, and the sidebar gallery is refreshed afterwards.
    """
    import tempfile  # local import: only this function needs temporary files

    start_time = time.time()
    status = st.empty()
    status.text("Processing GOT-OCR2_0... (0s)")
    # Cached, so batch OCR ("OCR All Assets" / "OCR All Pages") does not reload the model per image.
    tokenizer, model = _load_got_ocr()
    # GOT-OCR2_0's chat() is given an image path. Write a uniquely named temp PNG outside the
    # working directory (where the gallery globs *.png) and always remove it, even if chat() raises.
    fd, temp_file = tempfile.mkstemp(prefix="temp_", suffix=".png")
    os.close(fd)
    try:
        image.save(temp_file)
        result = model.chat(tokenizer, temp_file, ocr_type='ocr')
    finally:
        os.remove(temp_file)
    elapsed = int(time.time() - start_time)
    status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
    async with aiofiles.open(output_file, "w", encoding="utf-8") as f:
        await f.write(result)
    update_gallery()
    return result
@st.cache_resource(show_spinner=False)
def _load_default_diffusion_pipeline():
    """Load the fallback text-to-image pipeline once; Streamlit reuses it across calls and reruns."""
    return StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cpu")


async def process_image_gen(prompt, output_file):
    """Generate an image for *prompt* (20 steps), save it to *output_file*, and return it.

    Uses the pipeline of a loaded DiffusionBuilder from session state when present.
    Otherwise it uses the cached default "OFA-Sys/small-stable-diffusion-v0" pipeline
    on CPU. This is text-to-image only: no reference image is passed to the pipeline.
    """
    start_time = time.time()
    status = st.empty()
    status.text("Processing Image Gen... (0s)")
    builder = st.session_state['builder']
    if isinstance(builder, DiffusionBuilder) and builder.pipeline is not None:
        pipeline = builder.pipeline
    else:
        # Cached: building the fallback pipeline on every click is slow on CPU.
        pipeline = _load_default_diffusion_pipeline()
    gen_image = pipeline(prompt, num_inference_steps=20).images[0]
    elapsed = int(time.time() - start_time)
    status.text(f"Image Gen completed in {elapsed}s!")
    gen_image.save(output_file)
    update_gallery()
    return gen_image
st.title("AI Vision & SFT Titans 🚀")
# Sidebar
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"], key="sidebar_model_type", index=0 if st.session_state['selected_model_type'] == "Causal LM" else 1)
# get_model_files() selects models/ only for its internal id "causal_lm"; passing the raw UI
# label "Causal LM" always fell through to diffusion_models/.
model_dirs = get_model_files("causal_lm" if model_type == "Causal LM" else "diffusion")
if model_dirs and st.session_state['selected_model'] == "None" and "None" not in model_dirs:
    st.session_state['selected_model'] = model_dirs[0]
selected_model = st.sidebar.selectbox(
    "Select Saved Model",
    model_dirs,
    key="sidebar_model_select",
    index=model_dirs.index(st.session_state['selected_model']) if st.session_state['selected_model'] in model_dirs else 0,
)
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
    builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
    config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
    builder.load_model(selected_model, config)
    st.session_state['builder'] = builder
    st.session_state['model_loaded'] = True
    st.rerun()

st.sidebar.header("Captured Files 📜")
cols = st.sidebar.columns(2)
with cols[0]:
    if st.button("Zip All 🤐"):
        zip_path = f"all_assets_{int(time.time())}.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for file in get_gallery_files():
                zipf.write(file, os.path.basename(file))
        st.sidebar.markdown(get_download_link(zip_path, "application/zip", "Download All Assets"), unsafe_allow_html=True)
with cols[1]:
    if st.button("Zap All! 🗑️"):
        for file in get_gallery_files():
            # A file may disappear between listing and deletion; that is not an error here.
            try:
                os.remove(file)
            except FileNotFoundError:
                pass
        st.session_state['asset_checkboxes'].clear()
        st.session_state['downloaded_pdfs'].clear()
        st.session_state['cam0_file'] = None
        st.session_state['cam1_file'] = None
        st.sidebar.success("All assets vaporized! 💨")
        st.rerun()
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2)
# Number of update_gallery() calls so far in the current script run. Streamlit re-executes the
# script on every interaction, so this restarts at 0 each run. Widget keys built from it are
# unique within a run yet identical from one run to the next.
_gallery_render_count = 0


def _gallery_sync_checkbox(file, checkbox_key):
    """on_change callback: copy a gallery checkbox's new value into the shared asset selection."""
    st.session_state['asset_checkboxes'][file] = st.session_state[checkbox_key]


def _gallery_delete_asset(file):
    """on_click callback for "Zap It!": delete *file* and drop every session reference to it."""
    if os.path.exists(file):
        os.remove(file)
    st.session_state['asset_checkboxes'].pop(file, None)
    if file.endswith('.pdf'):
        url_key = next((k for k, v in st.session_state['downloaded_pdfs'].items() if v == file), None)
        if url_key:
            del st.session_state['downloaded_pdfs'][url_key]
    if file == st.session_state['cam0_file']:
        st.session_state['cam0_file'] = None
    if file == st.session_state['cam1_file']:
        st.session_state['cam1_file'] = None
    st.sidebar.success(f"Asset {os.path.basename(file)} vaporized! 💨")


def _gallery_preview(file):
    """Return a PIL preview: the PNG itself, or a PDF's first page rendered at half scale."""
    if file.endswith('.png'):
        return Image.open(file)
    with fitz.open(file) as doc:
        pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)


def update_gallery():
    """Render the sidebar asset gallery (first ``gallery_size * 2`` assets) with per-asset controls.

    Each asset gets a preview, a "Use for SFT/Input" checkbox mirrored into
    ``st.session_state['asset_checkboxes']``, a download link and a "Zap It!" button.

    Widget keys combine the file name with how many times the gallery has been rendered
    in this run. A key that changes on every rerun makes Streamlit drop the click that
    triggered the rerun, so keys must stay stable. Toggling and deleting go through
    callbacks, so they take effect even when the widget that was clicked is not
    rendered again.
    """
    global _gallery_render_count
    render_index = _gallery_render_count
    _gallery_render_count += 1
    all_files = get_gallery_files()
    if not all_files:
        return
    st.sidebar.subheader("Asset Gallery 📸📖")
    cols = st.sidebar.columns(2)
    for idx, file in enumerate(all_files[:gallery_size * 2]):
        with cols[idx % 2]:
            try:
                st.image(_gallery_preview(file), caption=os.path.basename(file), use_container_width=True)
            except Exception as e:  # e.g. a corrupt PDF must not break the whole page
                st.warning(f"Cannot preview {os.path.basename(file)}: {e}")
            checkbox_key = f"asset_{file}_{render_index}"
            # Seed the widget from the shared selection before it is instantiated in this run.
            # No value= is passed alongside, which Streamlit would flag as conflicting.
            st.session_state[checkbox_key] = st.session_state['asset_checkboxes'].get(file, False)
            st.session_state['asset_checkboxes'][file] = st.checkbox(
                "Use for SFT/Input",
                key=checkbox_key,
                on_change=_gallery_sync_checkbox,
                args=(file, checkbox_key),
            )
            mime_type = "image/png" if file.endswith('.png') else "application/pdf"
            st.markdown(get_download_link(file, mime_type, "Snag It! 📥"), unsafe_allow_html=True)
            # Deletion runs in the callback, before the automatic rerun, so no st.rerun() is needed.
            st.button("Zap It! 🗑️", key=f"delete_{file}_{render_index}", on_click=_gallery_delete_asset, args=(file,))
update_gallery()

st.sidebar.subheader("Action Logs 📜")
# A container (not st.empty(), which holds one element and keeps only the last write)
# so every log line stays visible.
log_container = st.sidebar.container()
# Format records here instead of reading record.asctime / record.message, which only
# exist after some handler has already formatted the record.
_log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
with log_container:
    for record in log_records:
        st.write(_log_formatter.format(record))

st.sidebar.subheader("History 📜")
history_container = st.sidebar.container()  # container, not st.empty(), so all entries show
with history_container:
    for entry in st.session_state['history'][-gallery_size * 2:]:
        st.write(entry)

tab1, tab2, tab3, tab4 = st.tabs([
    "Camera Snap 📷", "Download PDFs 📥", "Test OCR 🔍", "Build Titan 🌱"
])
def _capture_camera_snapshot(cam_index):
    """Render camera *cam_index*'s input widget and keep its latest capture on disk.

    st.camera_input returns the same photo on every rerun until a new one is taken.
    The capture is therefore saved only when its bytes differ from the last photo saved
    for this camera (tracked by SHA-256 digest in session state), not re-saved under a
    fresh timestamped name on every interaction. The saved file, if it still exists,
    is displayed.
    """
    import hashlib  # local import: only used to fingerprint captures

    file_key = f"cam{cam_index}_file"
    digest_key = f"cam{cam_index}_digest"
    captured = st.camera_input(f"Take a picture - Cam {cam_index}", key=f"cam{cam_index}")
    if captured:
        data = captured.getvalue()
        digest = hashlib.sha256(data).hexdigest()
        if digest != st.session_state.get(digest_key):
            filename = generate_filename(f"cam{cam_index}")
            previous = st.session_state[file_key]
            if previous and os.path.exists(previous):
                os.remove(previous)
            with open(filename, "wb") as f:
                f.write(data)
            st.session_state[file_key] = filename
            st.session_state[digest_key] = digest
            prefix = f"Snapshot from Cam {cam_index}:"
            entry = f"{prefix} {filename}"
            if entry not in st.session_state['history']:
                # Keep only the newest history entry per camera.
                st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith(prefix)] + [entry]
            logger.info(f"Saved snapshot from Camera {cam_index}: {filename}")
            update_gallery()
    current = st.session_state[file_key]
    if current and os.path.exists(current):
        st.image(current, caption=f"Camera {cam_index}", use_container_width=True)


with tab1:
    st.header("Camera Snap 📷")
    st.subheader("Single Capture")
    cols = st.columns(2)
    with cols[0]:
        _capture_camera_snapshot(0)
    with cols[1]:
        _capture_camera_snapshot(1)
with tab2:
    st.header("Download PDFs 📥")
    if st.button("Examples 📚"):
        example_urls = [
            "https://arxiv.org/pdf/2308.03892",
            "https://arxiv.org/pdf/1912.01703",
            "https://arxiv.org/pdf/2408.11039",
            "https://arxiv.org/pdf/2109.10282",
            "https://arxiv.org/pdf/2112.10752",
            "https://arxiv.org/pdf/2308.11236",
            "https://arxiv.org/pdf/1706.03762",
            "https://arxiv.org/pdf/2006.11239",
            "https://arxiv.org/pdf/2305.11207",
            "https://arxiv.org/pdf/2106.09685",
            "https://arxiv.org/pdf/2005.11401",
            "https://arxiv.org/pdf/2106.10504"
        ]
        st.session_state['pdf_urls'] = "\n".join(example_urls)
    url_input = st.text_area("Enter PDF URLs (one per line)", value=st.session_state.get('pdf_urls', ""), height=200)
    if st.button("Robo-Download 🤖"):
        # Strip each line (pasted text often carries "\r" or stray spaces) and drop blanks,
        # so the progress total counts only real URLs.
        urls = [line.strip() for line in url_input.splitlines() if line.strip()]
        progress_bar = st.progress(0)
        status_text = st.empty()
        total_urls = len(urls)
        for idx, url in enumerate(urls):
            output_path = pdf_url_to_filename(url)
            status_text.text(f"Fetching {idx + 1}/{total_urls}: {os.path.basename(output_path)}...")
            # Check the disk on every iteration rather than a listing taken before the loop,
            # so a URL that appears twice is fetched only once.
            if not os.path.exists(output_path):
                if download_pdf(url, output_path):
                    st.session_state['downloaded_pdfs'][url] = output_path
                    logger.info(f"Downloaded PDF from {url} to {output_path}")
                    entry = f"Downloaded PDF: {output_path}"
                    if entry not in st.session_state['history']:
                        st.session_state['history'].append(entry)
                    st.session_state['asset_checkboxes'][output_path] = True  # Auto-check the box
                else:
                    st.error(f"Failed to nab {url} 😿")
            else:
                st.info(f"Already got {os.path.basename(output_path)}! Skipping... 🐾")
                st.session_state['downloaded_pdfs'][url] = output_path
            progress_bar.progress((idx + 1) / total_urls)
        status_text.text("Robo-Download complete! 🚀")
        update_gallery()
    mode = st.selectbox("Snapshot Mode", ["Single Page (High-Res)", "Two Pages (High-Res)", "All Pages (High-Res)"], key="download_mode")
    if st.button("Snapshot Selected 📸"):
        mode_key = {"Single Page (High-Res)": "single", "Two Pages (High-Res)": "twopage", "All Pages (High-Res)": "allpages"}[mode]
        selected_pdfs = [path for path in get_gallery_files() if path.endswith('.pdf') and st.session_state['asset_checkboxes'].get(path, False)]
        if selected_pdfs:
            for pdf_path in selected_pdfs:
                snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
                for snapshot in snapshots:
                    st.image(snapshot, caption=snapshot, use_container_width=True)
                    st.session_state['asset_checkboxes'][snapshot] = True  # Auto-check new snapshots
            update_gallery()
        else:
            st.warning("No PDFs selected for snapshotting! Check some boxes in the sidebar gallery.")
def _ocr_asset_image(path):
    """Open a gallery asset for OCR: a PNG as-is, a PDF as its first page rendered at 2x zoom."""
    if path.endswith('.png'):
        return Image.open(path)
    with fitz.open(path) as doc:  # closes the document even if rendering fails
        pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)


with tab3:
    st.header("Test OCR 🔍")
    all_files = get_gallery_files()
    if all_files:
        if st.button("OCR All Assets 🚀"):
            full_text = "# OCR Results\n\n"
            for file in all_files:
                try:
                    image = _ocr_asset_image(file)
                except Exception as e:  # one unreadable asset should not abort the whole batch
                    st.warning(f"Skipping {os.path.basename(file)}: {e}")
                    continue
                output_file = generate_filename(f"ocr_{os.path.basename(file)}", "txt")
                result = asyncio.run(process_ocr(image, output_file))
                full_text += f"## {os.path.basename(file)}\n\n{result}\n\n"
                entry = f"OCR Test: {file} -> {output_file}"
                if entry not in st.session_state['history']:
                    st.session_state['history'].append(entry)
            md_output_file = f"full_ocr_{int(time.time())}.md"
            # Explicit UTF-8: OCR output is arbitrary Unicode; the platform default encoding may not cope.
            with open(md_output_file, "w", encoding="utf-8") as f:
                f.write(full_text)
            st.success(f"Full OCR saved to {md_output_file}")
            st.markdown(get_download_link(md_output_file, "text/markdown", "Download Full OCR Markdown"), unsafe_allow_html=True)
        selected_file = st.selectbox("Select Image or PDF", all_files, key="ocr_select")
        if selected_file:
            image = _ocr_asset_image(selected_file)
            st.image(image, caption="Input Image", use_container_width=True)
            if st.button("Run OCR 🚀", key="ocr_run"):
                output_file = generate_filename("ocr_output", "txt")
                st.session_state['processing']['ocr'] = True
                try:
                    result = asyncio.run(process_ocr(image, output_file))
                finally:
                    # Reset even if OCR raises, so the flag never sticks at True.
                    st.session_state['processing']['ocr'] = False
                entry = f"OCR Test: {selected_file} -> {output_file}"
                if entry not in st.session_state['history']:
                    st.session_state['history'].append(entry)
                st.text_area("OCR Result", result, height=200, key="ocr_result")
                st.success(f"OCR output saved to {output_file}")
            if selected_file.endswith('.pdf') and st.button("OCR All Pages 🚀", key="ocr_all_pages"):
                full_text = f"# OCR Results for {os.path.basename(selected_file)}\n\n"
                # Context manager guarantees the document is closed once all pages are processed.
                with fitz.open(selected_file) as doc:
                    for i in range(len(doc)):
                        pix = doc[i].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                        output_file = generate_filename(f"ocr_page_{i}", "txt")
                        result = asyncio.run(process_ocr(image, output_file))
                        full_text += f"## Page {i + 1}\n\n{result}\n\n"
                        entry = f"OCR Test: {selected_file} Page {i + 1} -> {output_file}"
                        if entry not in st.session_state['history']:
                            st.session_state['history'].append(entry)
                md_output_file = f"full_ocr_{os.path.basename(selected_file)}_{int(time.time())}.md"
                with open(md_output_file, "w", encoding="utf-8") as f:
                    f.write(full_text)
                st.success(f"Full OCR saved to {md_output_file}")
                st.markdown(get_download_link(md_output_file, "text/markdown", "Download Full OCR Markdown"), unsafe_allow_html=True)
    else:
        st.warning("No assets in gallery yet. Use Camera Snap or Download PDFs!")
with tab4:
    st.header("Build Titan 🌱")
    model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
    base_model = st.selectbox("Select Tiny Model",
                              ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
                              ["OFA-Sys/small-stable-diffusion-v0", "stabilityai/stable-diffusion-2-base"])
    # The default name embeds a timestamp. Passed as value=, it changed every second, which gave
    # the widget a new identity on each rerun and discarded whatever the user had typed
    # (including on the rerun triggered by "Download Model"). Seed it once in session state instead.
    if 'build_model_name' not in st.session_state:
        st.session_state['build_model_name'] = f"tiny-titan-{int(time.time())}"
    model_name = st.text_input("Model Name", key="build_model_name").strip()
    domain = st.text_input("Target Domain", "general")
    if st.button("Download Model ⬇️"):
        if not model_name:
            # An empty name would save straight into the models/ root directory.
            st.error("Model name must not be empty.")
        else:
            config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain)
            builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
            builder.load_model(base_model, config)
            builder.save_model(config.model_path)
            st.session_state['builder'] = builder
            st.session_state['model_loaded'] = True
            st.session_state['selected_model_type'] = model_type
            st.session_state['selected_model'] = config.model_path
            entry = f"Built {model_type} model: {model_name}"
            if entry not in st.session_state['history']:
                st.session_state['history'].append(entry)
            st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
            st.rerun()
tab5 = st.tabs(["Test Image Gen 🎨"])[0]
with tab5:
    st.header("Test Image Gen 🎨")
    gen_candidates = get_gallery_files()
    if not gen_candidates:
        st.warning("No images or PDFs in gallery yet. Use Camera Snap or Download PDFs!")
    else:
        selected_file = st.selectbox("Select Image or PDF", gen_candidates, key="gen_select")
        if selected_file:
            # Reference preview: a PNG as-is, otherwise the PDF's first page at 2x zoom.
            if selected_file.endswith('.png'):
                image = Image.open(selected_file)
            else:
                pdf_doc = fitz.open(selected_file)
                first_page_pix = pdf_doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                image = Image.frombytes("RGB", [first_page_pix.width, first_page_pix.height], first_page_pix.samples)
                pdf_doc.close()
            st.image(image, caption="Reference Image", use_container_width=True)
            prompt = st.text_area("Prompt", "Generate a neon superhero version of this image", key="gen_prompt")
            if st.button("Run Image Gen 🚀", key="gen_run"):
                output_file = generate_filename("gen_output", "png")
                st.session_state['processing']['gen'] = True
                # NOTE(review): only the prompt reaches process_image_gen; the reference image is display-only.
                result = asyncio.run(process_image_gen(prompt, output_file))
                entry = f"Image Gen Test: {prompt} -> {output_file}"
                if entry not in st.session_state['history']:
                    st.session_state['history'].append(entry)
                st.image(result, caption="Generated Image", use_container_width=True)
                st.success(f"Image saved to {output_file}")
                st.session_state['processing']['gen'] = False
# Final refresh so files created by any tab during this run appear in the sidebar gallery.
update_gallery()