Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import os | |
| import glob | |
| import base64 | |
| import time | |
| import shutil | |
| import zipfile | |
| import re | |
| import logging | |
| import asyncio | |
| import random # Added for ModelBuilder jokes | |
| from io import BytesIO | |
| from datetime import datetime | |
| import pytz | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| import fitz | |
| import requests | |
| import aiofiles # Added for async file operations | |
| from PIL import Image | |
| from diffusers import StableDiffusionPipeline | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel | |
# --- OpenAI Setup (for GPT related features) ---
# Credentials come from the environment once, when the script runs; a missing
# variable leaves the corresponding module-level setting as None.
import openai

openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORG_ID")
# --- Logging ---
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

# Records emitted on `logger` during this script run; rendered in the sidebar.
log_records = []


class LogCaptureHandler(logging.Handler):
    """Collect every record emitted on ``logger`` into ``log_records``.

    The handler formats each record itself, so ``record.asctime`` and
    ``record.message`` are always populated. This does not depend on some
    other handler (e.g. the root handler from ``basicConfig``) having
    formatted the record first.
    """

    def emit(self, record):
        try:
            self.format(record)
        except Exception:
            # Standard logging behaviour: report the problem, don't crash the caller.
            self.handleError(record)
            return
        log_records.append(record)


# The script body runs again on every Streamlit rerun, but logger objects are
# process-global. Remove capture handlers left by earlier runs before adding a
# fresh one. They are matched by class name because each run defines a new
# class object. Otherwise handlers accumulate and every log call fans out
# into stale lists.
for _old_handler in list(logger.handlers):
    if type(_old_handler).__name__ == "LogCaptureHandler":
        logger.removeHandler(_old_handler)
_capture_handler = LogCaptureHandler()
_capture_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logger.addHandler(_capture_handler)
# --- Streamlit Page Config ---
# Entries for the app's "hamburger" menu.
_PAGE_MENU_ITEMS = {
    'Get Help': 'https://huggingface.co/awacke1',
    'Report a Bug': 'https://huggingface.co/spaces/awacke1',
    'About': "AI Vision & SFT Titans: PDFs, OCR, Image Gen, Line Drawings, Custom Diffusion, and SFT on CPU! 🌌",
}

# Page-level settings: title/icon in the browser tab, wide layout, sidebar open.
# Kept ahead of every other st.* output call in the script.
st.set_page_config(
    page_title="AI Vision & SFT Titans 🚀",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items=_PAGE_MENU_ITEMS,
)
# --- Session State Defaults ---
# Seed each per-session key only when it is absent, so values set on earlier
# reruns survive. The literal is re-evaluated on every run, so each missing
# key receives its own fresh list/dict.
for _state_key, _state_default in {
    'history': [],
    'builder': None,
    'model_loaded': False,
    'processing': {},
    'asset_checkboxes': {},
    'downloaded_pdfs': {},
    'unique_counter': 0,
    'selected_model_type': "Causal LM",
    'selected_model': "None",
    'cam0_file': None,
    'cam1_file': None,
}.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Model & Diffusion DataClasses ---
@dataclass
class ModelConfig:
    """Settings for a causal-LM model built in the "Build Titan" tab.

    Constructed with keyword arguments (``name``, ``base_model``, ``size``,
    ``domain``). ``model_path`` is read as an attribute by callers, so it is a
    property rather than a method.
    """

    name: str  # user-chosen name; also the directory name under models/
    base_model: str  # id/path passed to the model loader
    size: str
    domain: Optional[str] = None
    model_type: str = "causal_lm"

    @property
    def model_path(self):
        """Directory the model is saved to: ``models/<name>``."""
        return f"models/{self.name}"
@dataclass
class DiffusionConfig:
    """Settings for a diffusion model built in the "Build Titan" tab.

    Constructed with keyword arguments (``name``, ``base_model``, ``size``,
    ``domain``). ``model_path`` is read as an attribute by callers, so it is a
    property rather than a method.
    """

    name: str  # user-chosen name; also the directory name under diffusion_models/
    base_model: str  # id/path passed to the pipeline loader
    size: str
    domain: Optional[str] = None

    @property
    def model_path(self):
        """Directory the pipeline is saved to: ``diffusion_models/<name>``."""
        return f"diffusion_models/{self.name}"
# --- Model Builders ---
class ModelBuilder:
    """Load, hold and save a causal language model together with its tokenizer."""

    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None
        # Messages picked at random for the "model loaded" notification.
        self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂",
                      "Training complete! Time for a binary coffee break. ☕"]

    def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
        """Load model and tokenizer from *model_path* and move the model to GPU if available.

        *config*, when given, is stored on the builder. Returns ``self`` so
        calls can be chained.
        """
        with st.spinner(f"Loading {model_path}... ⏳"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            if self.tokenizer.pad_token is None:
                # Fall back to EOS as the padding token when the tokenizer defines none.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if config is not None:
                self.config = config
            self.model.to("cuda" if torch.cuda.is_available() else "cpu")
            st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
        return self

    def save_model(self, path: str):
        """Save model and tokenizer into directory *path*, creating it (and its parents) as needed.

        Raises:
            RuntimeError: if no model has been loaded yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("No model loaded; call load_model() first.")
        with st.spinner("Saving model... 💾"):
            # Create the target itself: os.path.dirname() is "" for a bare name
            # such as "my-model", and os.makedirs("") raises.
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
            st.success(f"Model saved at {path}! ✅")
class DiffusionBuilder:
    """Load, hold, save and run a Stable Diffusion pipeline on the CPU in float32."""

    def __init__(self):
        self.config = None
        self.pipeline = None

    def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
        """Load the pipeline from *model_path* onto the CPU; store *config* if given. Returns ``self``."""
        with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
            self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
            if config is not None:
                self.config = config
            st.success("Diffusion model loaded! 🎨")
        return self

    def save_model(self, path: str):
        """Save the pipeline into directory *path*, creating it (and its parents) as needed.

        Raises:
            RuntimeError: if no pipeline has been loaded yet.
        """
        if self.pipeline is None:
            raise RuntimeError("No diffusion pipeline loaded; call load_model() first.")
        with st.spinner("Saving diffusion model... 💾"):
            # Create the target itself: os.path.dirname() is "" for a bare name,
            # and os.makedirs("") raises.
            os.makedirs(path, exist_ok=True)
            self.pipeline.save_pretrained(path)
            st.success(f"Diffusion model saved at {path}! ✅")

    def generate(self, prompt: str, num_inference_steps: int = 20):
        """Return the first image the pipeline produces for *prompt*.

        *num_inference_steps* defaults to 20, matching the previous fixed value.
        """
        return self.pipeline(prompt, num_inference_steps=num_inference_steps).images[0]
# --- Utility Functions ---
def generate_filename(sequence, ext="png"):
    """Build ``"<sequence>_<DDMMYYYYHHMMSS>.<ext>"`` from the current local time.

    The stamp has one-second resolution: two calls with the same *sequence*
    inside the same second return the same name.
    """
    stamp = time.strftime("%d%m%Y%H%M%S")
    return "{}_{}.{}".format(sequence, stamp, ext)
def pdf_url_to_filename(url):
    """Turn *url* into a local ``.pdf`` file name.

    Every character that is not allowed in file names on common file systems
    (``< > : " / \\ | ? *``) is replaced by an underscore.
    """
    forbidden = '<>:"/\\|?*'
    cleaned = url.translate({ord(ch): "_" for ch in forbidden})
    return cleaned + ".pdf"
def get_download_link(file_path, mime_type="application/pdf", label="Download"):
    """Return an HTML ``<a>`` tag that downloads *file_path* from an inline ``data:`` URI.

    The whole file is read and base64-encoded into the link. The output is
    rendered with ``unsafe_allow_html=True``, and file names can derive from
    user-supplied URLs, so the file name and *label* are HTML-escaped.
    """
    import html

    with open(file_path, "rb") as fh:
        payload = base64.b64encode(fh.read()).decode("ascii")
    download_name = html.escape(os.path.basename(file_path), quote=True)
    safe_label = html.escape(label, quote=True)
    return f'<a href="data:{mime_type};base64,{payload}" download="{download_name}">{safe_label}</a>'
def zip_directory(directory_path, zip_path):
    """Write every file under *directory_path* into a deflate-compressed zip at *zip_path*.

    Archive names are relative to the parent of *directory_path*, so the
    directory's own name is the top-level folder inside the archive.
    """
    archive_root = os.path.dirname(directory_path)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for folder, _subdirs, names in os.walk(directory_path):
            for name in names:
                full_path = os.path.join(folder, name)
                archive.write(full_path, os.path.relpath(full_path, archive_root))
def get_model_files(model_type="causal_lm"):
    """List saved model directories for *model_type*.

    ``"causal_lm"`` searches ``models/``; any other value searches
    ``diffusion_models/``. Returns ``["None"]`` as a placeholder when no
    directory is found.
    """
    root = "models" if model_type == "causal_lm" else "diffusion_models"
    found = list(filter(os.path.isdir, glob.glob(f"{root}/*")))
    return found or ["None"]
def get_gallery_files(file_types=("png", "pdf")):
    """Return the sorted, de-duplicated names of files in the working directory
    whose extension is one of *file_types*.

    The default is an immutable tuple, so it cannot be shared and mutated
    across calls. Any iterable of extensions (e.g. a list) is accepted.
    """
    return sorted({path for ext in file_types for path in glob.glob(f"*.{ext}")})
def get_pdf_files():
    """Return the names of the ``*.pdf`` files in the working directory, sorted."""
    pdfs = glob.glob("*.pdf")
    pdfs.sort()
    return pdfs
def download_pdf(url, output_path):
    """Stream the document at *url* into *output_path*.

    Returns True when the server answered HTTP 200 and the whole body was
    written, False otherwise; every failure is logged. If a download fails
    after the output file was opened, the partial file is removed. Callers
    skip URLs whose target file already exists, so a truncated file must not
    be left behind as if it were complete.
    """
    opened_output = False
    try:
        # Using the response as a context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=10) as response:
            if response.status_code != 200:
                logger.error(f"Failed to download {url}: HTTP {response.status_code}")
                return False
            opened_output = True
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return True
    except (requests.RequestException, OSError) as e:
        logger.error(f"Failed to download {url}: {e}")
        if opened_output:
            try:
                os.remove(output_path)
            except OSError:
                pass  # nothing was created, or it is already gone
        return False
# --- Original PDF Snapshot & OCR Functions ---
def _snapshot_plan(mode, page_count):
    """Return the ``(page_index, filename_prefix)`` pairs to render for snapshot *mode*."""
    if mode == "single":
        return [(0, "single")]
    if mode == "twopage":
        return [(i, f"twopage_{i}") for i in range(min(2, page_count))]
    if mode == "allpages":
        return [(i, f"page_{i}") for i in range(page_count)]
    return []


async def process_pdf_snapshot(pdf_path, mode="single"):
    """Render pages of *pdf_path* to PNG files at 2x zoom and return their names.

    Modes:
        ``"single"``   first page, saved as ``single_<ts>.png``
        ``"twopage"``  first two pages, ``twopage_<i>_<ts>.png``
        ``"allpages"`` every page, ``page_<i>_<ts>.png``
    Any other mode renders nothing. If anything fails, the error is shown in
    the UI and an empty list is returned. The function is ``async`` only
    because callers drive it with ``asyncio.run``; the work is synchronous.
    """
    start_time = time.time()
    status = st.empty()
    status.text(f"Processing PDF Snapshot ({mode})... (0s)")
    try:
        output_files = []
        # The context manager closes the document even when rendering raises.
        with fitz.open(pdf_path) as doc:
            for page_index, prefix in _snapshot_plan(mode, len(doc)):
                pix = doc[page_index].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                output_file = generate_filename(prefix, "png")
                pix.save(output_file)
                output_files.append(output_file)
        elapsed = int(time.time() - start_time)
        status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
        update_gallery()
        return output_files
    except Exception as e:
        status.error(f"Failed to process PDF: {str(e)}")
        return []
@st.cache_resource
def _load_got_ocr():
    """Load the GOT-OCR2_0 tokenizer and model once and keep them in Streamlit's resource cache.

    NOTE(review): ``trust_remote_code=True`` executes code downloaded from the
    model repository.
    """
    tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
    model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
    return tokenizer, model


async def process_ocr(image, output_file):
    """Run GOT-OCR2_0 on a PIL *image*, write the text to *output_file* (UTF-8) and return it.

    The model is loaded once and cached, instead of being reloaded for every
    image or page.
    """
    import tempfile

    start_time = time.time()
    status = st.empty()
    status.text("Processing GOT-OCR2_0... (0s)")
    tokenizer, model = _load_got_ocr()
    # model.chat is given an image file path, so write the image to a unique file
    # in the system temp dir. The cwd is globbed by the gallery, and a
    # per-second name could collide. The file is removed even if OCR fails.
    fd, temp_file = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(temp_file)
        result = model.chat(tokenizer, temp_file, ocr_type='ocr')
    finally:
        os.remove(temp_file)
    elapsed = int(time.time() - start_time)
    status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
    async with aiofiles.open(output_file, "w", encoding="utf-8") as f:
        await f.write(result)
    update_gallery()
    return result
@st.cache_resource
def _load_default_diffusion_pipeline():
    """Load the fallback text-to-image pipeline once and keep it in Streamlit's resource cache."""
    return StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cpu")


async def process_image_gen(prompt, output_file):
    """Generate an image for *prompt* (20 inference steps), save it to *output_file* and return it.

    Uses the pipeline of the diffusion builder stored in the session when one
    has been loaded. Otherwise it uses the cached default small Stable
    Diffusion pipeline.
    """
    start_time = time.time()
    status = st.empty()
    status.text("Processing Image Gen... (0s)")
    builder = st.session_state.get('builder')
    # Duck-typed on purpose: classes defined in the main Streamlit script are
    # re-created on each rerun. isinstance(builder, DiffusionBuilder) is
    # therefore False for a builder stored during an earlier run. ModelBuilder
    # has no `pipeline` attribute, so this still tells the two apart.
    session_pipeline = getattr(builder, 'pipeline', None)
    pipeline = session_pipeline if session_pipeline else _load_default_diffusion_pipeline()
    gen_image = pipeline(prompt, num_inference_steps=20).images[0]
    elapsed = int(time.time() - start_time)
    status.text(f"Image Gen completed in {elapsed}s!")
    gen_image.save(output_file)
    update_gallery()
    return gen_image
# --- New Function: Process an image (PIL) with a custom prompt using GPT ---
def process_image_with_prompt(image, prompt, model="o3-mini-high"):
    """Send a PIL *image* plus a text *prompt* to the OpenAI chat API and return the reply text.

    The image is PNG-encoded and embedded as a base64 ``data:`` URL. If the
    API call (or reading its reply) raises, an ``"Error processing image with
    GPT: ..."`` string is returned instead of raising.

    NOTE(review): ``openai.ChatCompletion`` belongs to the pre-1.0 openai SDK.
    Confirm the pinned SDK version, and that *model* accepts image input.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded}"}},
    ]
    conversation = [{"role": "user", "content": user_content}]
    try:
        reply = openai.ChatCompletion.create(model=model, messages=conversation)
        return reply.choices[0].message.content
    except Exception as e:
        return f"Error processing image with GPT: {str(e)}"
# --- Sidebar Setup ---
st.sidebar.subheader("Gallery Settings")
# Seed the number of gallery assets shown; update_gallery() reads this value.
if 'gallery_size' not in st.session_state:
    st.session_state['gallery_size'] = 2 # Default value
# Copy the slider's value into 'gallery_size' on every run. The slider's
# arguments (including the default read back from session state) are left
# exactly as they are: Streamlit may derive widget identity from them.
st.session_state['gallery_size'] = st.sidebar.slider(
    "Gallery Size",
    1, 10, st.session_state['gallery_size'],
    key="gallery_size_slider" # Unique key for the slider
)
# --- Updated Gallery Function ---
# Number of update_gallery() calls made so far in the current script run.
# Streamlit re-executes this script on every rerun, which resets this
# module-level counter. The Nth call of a run therefore gets the same widget
# keys on every rerun.
_gallery_render_count = 0


def update_gallery():
    """Render the asset gallery (PNG and PDF files in the working directory) in the sidebar.

    Up to ``st.session_state['gallery_size']`` assets are shown in two
    columns. Each asset gets:
    - a preview;
    - a "Use for SFT/Input" checkbox whose value is mirrored into
      ``st.session_state['asset_checkboxes']``;
    - a download link;
    - a delete button.

    The function is called several times per run, and each call draws another
    copy of the gallery. Widget keys combine the file name with this call's
    index within the run. That keeps keys unique inside a run (no duplicate
    widget IDs) yet identical across reruns, so checkbox toggles and
    "Zap It!" clicks are actually received. A session counter that grows
    forever would give widgets fresh keys on every rerun and drop all
    interaction.
    """
    global _gallery_render_count
    render_index = _gallery_render_count
    _gallery_render_count += 1

    all_files = get_gallery_files()
    if not all_files:
        return
    st.sidebar.subheader("Asset Gallery 📸📖")
    cols = st.sidebar.columns(2)
    for idx, file in enumerate(all_files[:st.session_state['gallery_size']]):
        with cols[idx % 2]:
            _render_gallery_preview(file)
            st.session_state['asset_checkboxes'][file] = st.checkbox(
                "Use for SFT/Input",
                value=st.session_state['asset_checkboxes'].get(file, False),
                key=f"asset_{file}_{render_index}",
            )
            mime_type = "image/png" if file.endswith('.png') else "application/pdf"
            st.markdown(get_download_link(file, mime_type, "Snag It! 📥"), unsafe_allow_html=True)
            if st.button("Zap It! 🗑️", key=f"delete_{file}_{render_index}"):
                os.remove(file)
                st.session_state['asset_checkboxes'].pop(file, None)
                st.sidebar.success(f"Asset {os.path.basename(file)} vaporized! 💨")
                st.rerun()


def _render_gallery_preview(file):
    """Show *file*: a PNG as-is, or a PDF's first page rendered at half scale."""
    caption = os.path.basename(file)
    if file.endswith('.png'):
        # st.image accepts a path, so no PIL file handle is left open.
        st.image(file, caption=caption, use_container_width=True)
        return
    # The context manager closes the PDF even if rendering raises.
    with fitz.open(file) as doc:
        if len(doc) == 0:
            st.caption(f"{caption} (empty PDF)")
            return
        pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    st.image(img, caption=caption, use_container_width=True)
# Call update_gallery() once initially
# Runs on every script execution, before the tabs are created, so the gallery
# sits directly below the "Gallery Settings" slider in the sidebar.
update_gallery()
# --- Sidebar Logs & History ---
# Format each captured record with an explicit Formatter. The `asctime` and
# `message` attributes exist on a LogRecord only after some Formatter has
# processed it, and nothing guarantees another handler did so.
_sidebar_log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
st.sidebar.subheader("Action Logs 📜")
with st.sidebar:
    for record in log_records:
        st.write(_sidebar_log_formatter.format(record))
st.sidebar.subheader("History 📜")
with st.sidebar:
    for entry in st.session_state['history']:
        st.write(entry)
# --- Create Tabs ---
# Eight tabs; the unpacking below must list the tab handles in the same order
# as these labels.
tabs = st.tabs([
    "Camera Snap 📷",
    "Download PDFs 📥",
    "Test OCR 🔍",
    "Build Titan 🌱",
    "Test Image Gen 🎨",
    "PDF Process 📄",
    "Image Process 🖼️",
    "MD Gallery 📚"
])
(tab_camera, tab_download, tab_ocr, tab_build, tab_imggen, tab_pdf_process, tab_image_process, tab_md_gallery) = tabs
# === Tab: Camera Snap ===
def _save_camera_snapshot(cam_index, snapshot):
    """Save (when new) and display the capture from camera *cam_index*.

    *snapshot* is the value returned by ``st.camera_input``. A new
    ``cam<N>_<timestamp>.png`` is written only in two cases: the captured
    bytes differ from the last capture saved for this camera, or that file no
    longer exists. Saving a new file also deletes the camera's previous file,
    replaces its history entry, logs the save and refreshes the gallery.

    ``st.camera_input`` returns the same capture on every rerun. Saving
    unconditionally would create a new file name on every interaction
    anywhere in the app.
    """
    import hashlib

    file_key = f'cam{cam_index}_file'
    digest_key = f'cam{cam_index}_digest'
    data = snapshot.getvalue()
    digest = hashlib.sha256(data).hexdigest()
    previous = st.session_state.get(file_key)
    previous_exists = bool(previous) and os.path.exists(previous)
    is_new = not previous_exists or st.session_state.get(digest_key) != digest
    if is_new:
        if previous_exists:
            os.remove(previous)
        filename = generate_filename(f"cam{cam_index}")
        with open(filename, "wb") as f:
            f.write(data)
        st.session_state[file_key] = filename
        st.session_state[digest_key] = digest
        entry = f"Snapshot from Cam {cam_index}: {filename}"
        if entry not in st.session_state['history']:
            prefix = f"Snapshot from Cam {cam_index}:"
            st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith(prefix)] + [entry]
        logger.info(f"Saved snapshot from Camera {cam_index}: {filename}")
    st.image(st.session_state[file_key], caption=f"Camera {cam_index}", use_container_width=True)
    if is_new:
        update_gallery()


with tab_camera:
    st.header("Camera Snap 📷")
    st.subheader("Single Capture")
    cols = st.columns(2)
    with cols[0]:
        cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
        if cam0_img:
            _save_camera_snapshot(0, cam0_img)
    with cols[1]:
        cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
        if cam1_img:
            _save_camera_snapshot(1, cam1_img)
# === Tab: Download PDFs ===
# Snapshot-mode labels shown in the UI -> process_pdf_snapshot() mode names.
_SNAPSHOT_MODES = {
    "Single Page (High-Res)": "single",
    "Two Pages (High-Res)": "twopage",
    "All Pages (High-Res)": "allpages",
}

with tab_download:
    st.header("Download PDFs 📥")
    if st.button("Examples 📚"):
        example_urls = [
            "https://arxiv.org/pdf/2308.03892",
            "https://arxiv.org/pdf/1912.01703",
            "https://arxiv.org/pdf/2408.11039",
            "https://arxiv.org/pdf/2109.10282",
            "https://arxiv.org/pdf/2112.10752",
            "https://arxiv.org/pdf/2308.11236",
            "https://arxiv.org/pdf/1706.03762",
            "https://arxiv.org/pdf/2006.11239",
            "https://arxiv.org/pdf/2305.11207",
            "https://arxiv.org/pdf/2106.09685",
            "https://arxiv.org/pdf/2005.11401",
            "https://arxiv.org/pdf/2106.10504"
        ]
        st.session_state['pdf_urls'] = "\n".join(example_urls)
    url_input = st.text_area("Enter PDF URLs (one per line)", value=st.session_state.get('pdf_urls', ""), height=200)
    if st.button("Robo-Download 🤖"):
        # Strip every line (including stray "\r") and drop blank ones, so that
        # whitespace-only lines are not fetched and the progress total counts
        # real URLs only.
        urls = [line.strip() for line in url_input.splitlines() if line.strip()]
        progress_bar = st.progress(0)
        status_text = st.empty()
        total_urls = len(urls)
        # A set for O(1) lookups. It grows as files arrive, so a URL listed
        # twice is fetched only once.
        existing_pdfs = set(get_pdf_files())
        for idx, url in enumerate(urls):
            output_path = pdf_url_to_filename(url)
            status_text.text(f"Fetching {idx + 1}/{total_urls}: {os.path.basename(output_path)}...")
            if output_path not in existing_pdfs:
                if download_pdf(url, output_path):
                    existing_pdfs.add(output_path)
                    st.session_state['downloaded_pdfs'][url] = output_path
                    logger.info(f"Downloaded PDF from {url} to {output_path}")
                    entry = f"Downloaded PDF: {output_path}"
                    if entry not in st.session_state['history']:
                        st.session_state['history'].append(entry)
                    st.session_state['asset_checkboxes'][output_path] = True
                else:
                    st.error(f"Failed to nab {url} 😿")
            else:
                st.info(f"Already got {os.path.basename(output_path)}! Skipping... 🐾")
                st.session_state['downloaded_pdfs'][url] = output_path
            progress_bar.progress((idx + 1) / total_urls)
        status_text.text("Robo-Download complete! 🚀")
        update_gallery()
    mode = st.selectbox("Snapshot Mode", list(_SNAPSHOT_MODES), key="download_mode")
    if st.button("Snapshot Selected 📸"):
        selected_pdfs = [path for path in get_pdf_files() if st.session_state['asset_checkboxes'].get(path, False)]
        if selected_pdfs:
            mode_key = _SNAPSHOT_MODES[mode]
            for pdf_path in selected_pdfs:
                snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
                for snapshot in snapshots:
                    # st.image accepts a path, so no PIL file handle is left open.
                    st.image(snapshot, caption=snapshot, use_container_width=True)
                    st.session_state['asset_checkboxes'][snapshot] = True
            update_gallery()
        else:
            st.warning("No PDFs selected for snapshotting! Check some boxes in the sidebar.")
# === Tab: Test OCR ===
def _ocr_source_image(path):
    """Return a PIL image for *path*: the PNG itself, or a PDF's first page rendered at 2x zoom."""
    if path.endswith('.png'):
        image = Image.open(path)
        image.load()  # read pixel data eagerly instead of lazily from the open file
        return image
    # The context manager closes the PDF even if rendering raises.
    with fitz.open(path) as doc:
        pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)


def _write_ocr_markdown(md_output_file, full_text):
    """Write combined OCR markdown as UTF-8 and show a download link.

    Explicit UTF-8 is used because OCR text may contain non-ASCII characters.
    """
    with open(md_output_file, "w", encoding="utf-8") as f:
        f.write(full_text)
    st.success(f"Full OCR saved to {md_output_file}")
    st.markdown(get_download_link(md_output_file, "text/markdown", "Download Full OCR Markdown"), unsafe_allow_html=True)


with tab_ocr:
    st.header("Test OCR 🔍")
    all_files = get_gallery_files()
    if all_files:
        if st.button("OCR All Assets 🚀"):
            full_text = "# OCR Results\n\n"
            for file in all_files:
                image = _ocr_source_image(file)
                output_file = generate_filename(f"ocr_{os.path.basename(file)}", "txt")
                result = asyncio.run(process_ocr(image, output_file))
                full_text += f"## {os.path.basename(file)}\n\n{result}\n\n"
                entry = f"OCR Test: {file} -> {output_file}"
                if entry not in st.session_state['history']:
                    st.session_state['history'].append(entry)
            _write_ocr_markdown(f"full_ocr_{int(time.time())}.md", full_text)
        selected_file = st.selectbox("Select Image or PDF", all_files, key="ocr_select")
        if selected_file:
            image = _ocr_source_image(selected_file)
            st.image(image, caption="Input Image", use_container_width=True)
            if st.button("Run OCR 🚀", key="ocr_run"):
                output_file = generate_filename("ocr_output", "txt")
                st.session_state['processing']['ocr'] = True
                try:
                    result = asyncio.run(process_ocr(image, output_file))
                finally:
                    # Clear the flag even when OCR fails.
                    st.session_state['processing']['ocr'] = False
                entry = f"OCR Test: {selected_file} -> {output_file}"
                if entry not in st.session_state['history']:
                    st.session_state['history'].append(entry)
                st.text_area("OCR Result", result, height=200, key="ocr_result")
                st.success(f"OCR output saved to {output_file}")
            if selected_file.endswith('.pdf') and st.button("OCR All Pages 🚀", key="ocr_all_pages"):
                full_text = f"# OCR Results for {os.path.basename(selected_file)}\n\n"
                # The context manager closes the PDF when the loop ends or raises.
                with fitz.open(selected_file) as doc:
                    for i, page in enumerate(doc):
                        pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                        page_image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                        output_file = generate_filename(f"ocr_page_{i}", "txt")
                        result = asyncio.run(process_ocr(page_image, output_file))
                        full_text += f"## Page {i + 1}\n\n{result}\n\n"
                        entry = f"OCR Test: {selected_file} Page {i + 1} -> {output_file}"
                        if entry not in st.session_state['history']:
                            st.session_state['history'].append(entry)
                _write_ocr_markdown(f"full_ocr_{os.path.basename(selected_file)}_{int(time.time())}.md", full_text)
    else:
        st.warning("No assets in gallery yet. Use Camera Snap or Download PDFs!")
# === Tab: Build Titan ===
with tab_build:
    st.header("Build Titan 🌱")
    model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
    base_model = st.selectbox("Select Tiny Model",
                              ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
                              ["OFA-Sys/small-stable-diffusion-v0", "stabilityai/stable-diffusion-2-base"])
    # Fix the suggested name once per session. This text input has no key, and
    # Streamlit derives a key-less widget's identity from its arguments. A
    # default that changed every second would re-create the widget on the
    # button's rerun and discard the name the user typed.
    if 'build_default_model_name' not in st.session_state:
        st.session_state['build_default_model_name'] = f"tiny-titan-{int(time.time())}"
    model_name = st.text_input("Model Name", st.session_state['build_default_model_name'])
    domain = st.text_input("Target Domain", "general")
    if st.button("Download Model ⬇️"):
        if not model_name.strip():
            # An empty name would make the save path the bare models/ directory.
            st.error("Please enter a model name.")
        else:
            if model_type == "Causal LM":
                config_cls, builder_cls = ModelConfig, ModelBuilder
            else:
                config_cls, builder_cls = DiffusionConfig, DiffusionBuilder
            try:
                config = config_cls(name=model_name, base_model=base_model, size="small", domain=domain)
                builder = builder_cls()
                builder.load_model(base_model, config)
                builder.save_model(config.model_path)
            except Exception as e:
                logger.error(f"Failed to build {model_type} model {model_name}: {e}")
                st.error(f"Failed to build {model_type} model: {e}")
            else:
                st.session_state['builder'] = builder
                st.session_state['model_loaded'] = True
                st.session_state['selected_model_type'] = model_type
                st.session_state['selected_model'] = config.model_path
                entry = f"Built {model_type} model: {model_name}"
                if entry not in st.session_state['history']:
                    st.session_state['history'].append(entry)
                st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
                # Kept outside the try: st.rerun() works by raising a control-flow exception.
                st.rerun()
# === Tab: Test Image Gen ===
with tab_imggen:
    st.header("Test Image Gen 🎨")
    all_files = get_gallery_files()
    if all_files:
        selected_file = st.selectbox("Select Image or PDF", all_files, key="gen_select")
        if selected_file:
            if selected_file.endswith('.png'):
                image = Image.open(selected_file)
            else:
                # The context manager closes the PDF after rendering page 1.
                with fitz.open(selected_file) as doc:
                    pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                    image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            # NOTE(review): the reference image is only displayed.
            # process_image_gen() is text-to-image and never receives it.
            st.image(image, caption="Reference Image", use_container_width=True)
            prompt = st.text_area("Prompt", "Generate a neon superhero version of this image", key="gen_prompt")
            if st.button("Run Image Gen 🚀", key="gen_run"):
                output_file = generate_filename("gen_output", "png")
                st.session_state['processing']['gen'] = True
                try:
                    result = asyncio.run(process_image_gen(prompt, output_file))
                finally:
                    # Clear the flag even when generation fails.
                    st.session_state['processing']['gen'] = False
                entry = f"Image Gen Test: {prompt} -> {output_file}"
                if entry not in st.session_state['history']:
                    st.session_state['history'].append(entry)
                st.image(result, caption="Generated Image", use_container_width=True)
                st.success(f"Image saved to {output_file}")
    else:
        st.warning("No images or PDFs in gallery yet. Use Camera Snap or Download PDFs!")
    # No unconditional update_gallery() here. The sidebar gallery is already
    # rendered near the top of the script, and each operation that creates
    # files re-renders it. Calling it here would add a duplicate gallery to
    # the sidebar on every run.
# === New Tab: PDF Process ===
# Prompt sent with every rendered page.
_PDF_EXTRACT_PROMPT = "Extract the electronic text from image"


def _render_pdf_page(page):
    """Render a PyMuPDF page at 2x zoom and return it as an RGB PIL image."""
    pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
    return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)


def _join_side_by_side(left, right):
    """Return a new RGB image with *left* and *right* pasted next to each other, top-aligned."""
    canvas = Image.new("RGB", (left.width + right.width, max(left.height, right.height)))
    canvas.paste(left, (0, 0))
    canvas.paste(right, (left.width, 0))
    return canvas


with tab_pdf_process:
    st.header("PDF Process")
    st.subheader("Upload PDFs for GPT-based text extraction")
    uploaded_pdfs = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True, key="pdf_process_uploader")
    view_mode = st.selectbox("View Mode", ["Single Page", "Double Page"], key="pdf_view_mode")
    if st.button("Process Uploaded PDFs", key="process_pdfs"):
        if not uploaded_pdfs:
            st.warning("No PDFs uploaded.")
        else:
            combined_text = ""
            # Single Page sends each page separately. Double Page pairs
            # consecutive pages side by side; a trailing odd page goes alone.
            step = 1 if view_mode == "Single Page" else 2
            for pdf_file in uploaded_pdfs:
                try:
                    # Open straight from the uploaded bytes: nothing is written
                    # to disk, so there is no clean-up and no name clash
                    # between concurrent sessions.
                    with fitz.open(stream=pdf_file.getvalue(), filetype="pdf") as doc:
                        st.write(f"Processing {pdf_file.name} with {len(doc)} pages")
                        pages = list(doc)
                        for i in range(0, len(pages), step):
                            if step == 2 and i + 1 < len(pages):
                                img = _join_side_by_side(_render_pdf_page(pages[i]), _render_pdf_page(pages[i + 1]))
                                label = f"Pages {i+1}-{i+2}"
                            else:
                                img = _render_pdf_page(pages[i])
                                label = f"Page {i+1}"
                            st.image(img, caption=f"{pdf_file.name} {label}")
                            gpt_text = process_image_with_prompt(img, _PDF_EXTRACT_PROMPT)
                            combined_text += f"\n## {pdf_file.name} - {label}\n\n{gpt_text}\n"
                except Exception as e:
                    st.error(f"Error processing {pdf_file.name}: {str(e)}")
            output_filename = generate_filename("processed_pdf", "md")
            with open(output_filename, "w", encoding="utf-8") as f:
                f.write(combined_text)
            st.success(f"PDF processing complete. MD file saved as {output_filename}")
            st.markdown(get_download_link(output_filename, "text/markdown", "Download Processed PDF MD"), unsafe_allow_html=True)
# === New Tab: Image Process ===
with tab_image_process:
    st.header("Image Process")
    st.subheader("Upload Images for GPT-based OCR")
    prompt_img = st.text_input("Enter prompt for image processing", "Extract the electronic text from image", key="img_process_prompt")
    uploaded_images = st.file_uploader("Upload image files", type=["png", "jpg", "jpeg"], accept_multiple_files=True, key="image_process_uploader")
    if st.button("Process Uploaded Images", key="process_images"):
        if not uploaded_images:
            st.warning("No images uploaded.")
        else:
            combined_text = ""
            for img_file in uploaded_images:
                try:
                    img = Image.open(img_file)
                    # process_image_with_prompt() PNG-encodes the image, and
                    # Pillow cannot write CMYK (e.g. print JPEGs) as PNG.
                    if img.mode == "CMYK":
                        img = img.convert("RGB")
                    st.image(img, caption=img_file.name)
                    gpt_text = process_image_with_prompt(img, prompt_img)
                    combined_text += f"\n## {img_file.name}\n\n{gpt_text}\n"
                except Exception as e:
                    st.error(f"Error processing image {img_file.name}: {str(e)}")
            output_filename = generate_filename("processed_image", "md")
            with open(output_filename, "w", encoding="utf-8") as f:
                f.write(combined_text)
            st.success(f"Image processing complete. MD file saved as {output_filename}")
            st.markdown(get_download_link(output_filename, "text/markdown", "Download Processed Image MD"), unsafe_allow_html=True)
# === New Tab: MD Gallery ===
# Lists the *.md files in the working directory.
# - Each file can be summarized on its own with a fixed prompt.
# - Checked files can be concatenated and sent with a user-editable prompt.
# Every result is shown and also saved as a new timestamped .md file.
# NOTE(review): openai.ChatCompletion is the pre-1.0 openai SDK interface, and
# "o3-mini-high" may not be a valid API model id. Confirm both against the
# deployed SDK version.
with tab_md_gallery:
    st.header("MD Gallery and GPT Processing")
    md_files = sorted(glob.glob("*.md"))
    if md_files:
        st.subheader("Individual File Processing")
        # Two-column grid with one "Process" button per file; keys embed the file name.
        cols = st.columns(2)
        for idx, md_file in enumerate(md_files):
            with cols[idx % 2]:
                st.write(md_file)
                if st.button(f"Process {md_file}", key=f"process_md_{md_file}"):
                    try:
                        with open(md_file, "r", encoding="utf-8") as f:
                            content = f.read()
                        prompt_md = "Summarize this into markdown outline with emojis and number the topics 1..12"
                        messages = [{"role": "user", "content": prompt_md + "\n\n" + content}]
                        response = openai.ChatCompletion.create(model="o3-mini-high", messages=messages)
                        result_text = response.choices[0].message.content
                        st.markdown(result_text)
                        # Save as processed_<original stem>_<timestamp>.md.
                        output_filename = generate_filename(f"processed_{os.path.splitext(md_file)[0]}", "md")
                        with open(output_filename, "w", encoding="utf-8") as f:
                            f.write(result_text)
                        st.markdown(get_download_link(output_filename, "text/markdown", f"Download {output_filename}"), unsafe_allow_html=True)
                    except Exception as e:
                        st.error(f"Error processing {md_file}: {str(e)}")
        st.subheader("Batch Processing")
        st.write("Select MD files to combine and process:")
        # Checkbox state per file, keyed by file name so selections survive reruns.
        selected_md = {}
        for md_file in md_files:
            selected_md[md_file] = st.checkbox(md_file, key=f"checkbox_md_{md_file}")
        batch_prompt = st.text_input("Enter batch processing prompt", "Summarize this into markdown outline with emojis and number the topics 1..12", key="batch_prompt")
        if st.button("Process Selected MD Files", key="process_batch_md"):
            # Concatenate the selected files, each under a "## <file name>" heading.
            combined_content = ""
            for md_file, selected in selected_md.items():
                if selected:
                    try:
                        with open(md_file, "r", encoding="utf-8") as f:
                            combined_content += f"\n## {md_file}\n" + f.read() + "\n"
                    except Exception as e:
                        st.error(f"Error reading {md_file}: {str(e)}")
            if combined_content:
                messages = [{"role": "user", "content": batch_prompt + "\n\n" + combined_content}]
                try:
                    response = openai.ChatCompletion.create(model="o3-mini-high", messages=messages)
                    result_text = response.choices[0].message.content
                    st.markdown(result_text)
                    output_filename = generate_filename("batch_processed_md", "md")
                    with open(output_filename, "w", encoding="utf-8") as f:
                        f.write(result_text)
                    st.success(f"Batch processing complete. MD file saved as {output_filename}")
                    st.markdown(get_download_link(output_filename, "text/markdown", "Download Batch Processed MD"), unsafe_allow_html=True)
                except Exception as e:
                    st.error(f"Error processing batch: {str(e)}")
            else:
                st.warning("No MD files selected.")
    else:
        st.warning("No MD files found.")