#!/usr/bin/env python3
# This version finally has image capture working, using Streamlit's camera_input (the only approach that worked).
# Now that image inputs are in, the next steps are re-adding the missing LM components and completing the CV diffusion parts.
import os
import glob
import base64
import streamlit as st
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import csv
import time
from dataclasses import dataclass
from typing import Optional, Tuple
import zipfile
import math
from PIL import Image
import random
import logging
import numpy as np

# Logging setup with a custom buffer
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
log_records = []  # Records captured during the current script run (shown in the sidebar)

class LogCaptureHandler(logging.Handler):
    """Keep every record this handler receives in the module-level ``log_records`` list.

    ``emit`` formats the record first so ``record.asctime`` and ``record.message``
    are always populated. The sidebar log view reads those attributes, and
    otherwise they exist only if some other handler happened to format the record.
    """
    def emit(self, record):
        self.format(record)
        log_records.append(record)

# Streamlit re-executes this script on every interaction, but the logger object
# returned by getLogger() persists between runs. Drop capture handlers left by
# earlier runs so each record is captured once, into this run's log_records.
for _old_handler in list(logger.handlers):
    if type(_old_handler).__name__ == LogCaptureHandler.__name__:
        logger.removeHandler(_old_handler)
_capture_handler = LogCaptureHandler()
_capture_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger.addHandler(_capture_handler)

# Page Configuration
_menu_items = {
    'Get Help': 'https://huggingface.co/awacke1',
    'Report a Bug': 'https://huggingface.co/spaces/awacke1',
    'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! ๐ŸŒŒ",
}
st.set_page_config(
    page_title="SFT Tiny Titans ๐Ÿš€",
    page_icon="๐Ÿค–",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items=_menu_items,
)

# Seed st.session_state with defaults for keys that must survive reruns; keys
# that already exist keep their value.
for _state_key, _state_default in (('captured_images', []), ('builder', None), ('model_loaded', False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Model Configuration Classes
@dataclass
class ModelConfig:
    """Settings for a causal-LM "Titan" managed by the app.

    ``model_path`` is the local directory (``models/<name>``) the model is saved to.
    """
    name: str                      # model name; also the directory name under models/
    base_model: str                # checkpoint it was built from (a hub id, or "unknown")
    size: str                      # free-form size label
    domain: Optional[str] = None   # optional domain tag
    model_type: str = "causal_lm"  # model family tag

    @property
    def model_path(self):
        return "models/{}".format(self.name)

@dataclass
class DiffusionConfig:
    """Settings for a diffusion "Titan"; saved under ``diffusion_models/<name>``."""
    name: str        # model name; also the directory name under diffusion_models/
    base_model: str  # checkpoint it was built from (a hub id, or "unknown")
    size: str        # free-form size label

    @property
    def model_path(self):
        return "diffusion_models/{}".format(self.name)

# Datasets
class SFTDataset(Dataset):
    """Tokenized prompt/response pairs for supervised fine-tuning of a causal LM.

    Each item is ``"{prompt} {response}"`` padded/truncated to ``max_length``.
    ``labels`` copy ``input_ids`` but are -100 (ignored by the loss) on padding
    positions and on the prompt tokens. When the prompt alone reaches
    ``max_length``, the prompt tokens are left unmasked.
    """
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data            # list of {"prompt": str, "response": str} dicts
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        prompt = self.data[idx]["prompt"]
        response = self.data[idx]["response"]
        full_text = f"{prompt} {response}"
        full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
        # squeeze(0) removes only the batch dimension, keeping a 1-D sequence even for length 1.
        input_ids = full_encoding["input_ids"].squeeze(0)
        attention_mask = full_encoding["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        # Padding must not contribute to the loss (the pad token is often the EOS token).
        labels[attention_mask == 0] = -100
        prompt_len = prompt_encoding["input_ids"].shape[1]
        if prompt_len < self.max_length:
            # Real tokens start at the first attended position: 0 with right padding,
            # after the pad run with left padding. Mask the prompt relative to that.
            start = int(torch.argmax(attention_mask))
            labels[start:start + prompt_len] = -100
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

class DiffusionDataset(Dataset):
    """Pair each image with the caption at the same index.

    Items are ``{"image": images[i], "text": texts[i]}``; the length is taken
    from ``images``.
    """
    def __init__(self, images, texts):
        self.images, self.texts = images, texts

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = dict(image=self.images[idx], text=self.texts[idx])
        return sample

# Model Builders
class ModelBuilder:
    """Load, supervised-fine-tune, save and query a Hugging Face causal LM.

    Progress is reported through Streamlit widgets. ``load_model`` and
    ``fine_tune_sft`` return ``self`` so calls can be chained.
    """
    def __init__(self):
        self.config = None      # ModelConfig describing the model, when one is supplied
        self.model = None       # set by load_model (AutoModelForCausalLM)
        self.tokenizer = None   # set by load_model (AutoTokenizer)
        self.sft_data = None    # prompt/response rows read by the last fine_tune_sft call
        self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! ๐Ÿ˜‚", "Training complete! Time for a binary coffee break. โ˜•"]

    def load_model(self, model_path: str, config: Optional["ModelConfig"] = None):
        """Load model and tokenizer from ``model_path`` (hub id or local directory).

        The model is moved to CUDA when available, otherwise to CPU. If the
        tokenizer has no pad token, its EOS token is used for padding.
        """
        with st.spinner(f"Loading {model_path}... โณ"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if config:
                self.config = config
            self.model.to("cuda" if torch.cuda.is_available() else "cpu")
        st.success(f"Model loaded! ๐ŸŽ‰ {random.choice(self.jokes)}")
        return self

    @staticmethod
    def _read_sft_csv(csv_path: str):
        """Return ``[{"prompt": ..., "response": ...}, ...]`` from a CSV with those columns.

        ``utf-8-sig`` strips a leading byte-order mark (common in spreadsheet
        exports) that would otherwise corrupt the first header name.
        ``newline=""`` lets the csv module handle quoted fields containing newlines.

        Raises:
            KeyError: if a row lacks the ``prompt`` or ``response`` column.
        """
        with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
            return [{"prompt": row["prompt"], "response": row["response"]} for row in csv.DictReader(f)]

    def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
        """Fine-tune the loaded model on the prompt/response pairs in ``csv_path``.

        Runs AdamW (lr 2e-5) for ``epochs`` passes over shuffled batches of
        ``batch_size``, with labels built by SFTDataset, and writes the mean loss
        per epoch. If the CSV has no data rows, an error is shown and nothing is trained.

        Raises:
            KeyError: if the CSV lacks a ``prompt`` or ``response`` column.
        """
        self.sft_data = self._read_sft_csv(csv_path)
        if not self.sft_data:
            # An empty dataset would make DataLoader(shuffle=True) raise ValueError.
            st.error("No prompt/response rows found in the CSV; nothing to fine-tune.")
            return self
        dataset = SFTDataset(self.sft_data, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
        self.model.train()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        for epoch in range(epochs):
            with st.spinner(f"Training epoch {epoch + 1}/{epochs}... โš™๏ธ"):
                total_loss = 0
                for batch in dataloader:
                    optimizer.zero_grad()
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    labels = batch["labels"].to(device)
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
        st.success(f"SFT Fine-tuning completed! ๐ŸŽ‰ {random.choice(self.jokes)}")
        return self

    def save_model(self, path: str):
        """Save model and tokenizer into directory ``path``, creating it if needed."""
        with st.spinner("Saving model... ๐Ÿ’พ"):
            # Create the target itself: os.makedirs(os.path.dirname(path)) raised
            # FileNotFoundError when ``path`` had no directory component.
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
        st.success(f"Model saved at {path}! โœ…")

    def evaluate(self, prompt: str, status_container=None):
        """Sample up to 50 new tokens after ``prompt`` and return the decoded output sequence.

        Errors are not raised. They are reported to ``status_container`` (if
        given) and returned as ``"Error: <message>"``.
        """
        self.model.eval()
        if status_container:
            status_container.write("Preparing to evaluate... ๐Ÿง ")
        try:
            with torch.no_grad():
                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
                outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            if status_container:
                status_container.error(f"Oops! Something broke: {str(e)} ๐Ÿ’ฅ")
            return f"Error: {str(e)}"

class DiffusionBuilder:
    """Load, fine-tune, save and sample a diffusers Stable Diffusion pipeline.

    Progress is reported through Streamlit widgets. ``load_model`` and
    ``fine_tune_sft`` return ``self``.
    """
    def __init__(self):
        self.config = None    # DiffusionConfig describing the model, when one is supplied
        self.pipeline = None  # set by load_model (StableDiffusionPipeline)

    def load_model(self, model_path: str, config: Optional["DiffusionConfig"] = None):
        """Load a StableDiffusionPipeline from ``model_path`` onto CUDA when available, else CPU."""
        from diffusers import StableDiffusionPipeline
        with st.spinner(f"Loading diffusion model {model_path}... โณ"):
            self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
            self.pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
            if config:
                self.config = config
        st.success(f"Diffusion model loaded! ๐ŸŽจ")
        return self

    @staticmethod
    def _pixels_to_tensor(pixels):
        """Convert an (H, W, 3) array of 0-255 RGB values to a (1, 3, H, W) float32 tensor in [-1, 1]."""
        scaled = np.asarray(pixels, dtype=np.float32) / 127.5 - 1.0
        return torch.from_numpy(scaled).permute(2, 0, 1).unsqueeze(0).contiguous()

    def fine_tune_sft(self, images, texts, epochs=3):
        """Fine-tune the pipeline's UNet on (PIL image, caption) pairs with a noise-prediction MSE loss.

        Each image is converted to RGB, resized to
        ``unet.config.sample_size * vae_scale_factor`` pixels square, encoded by
        the VAE, and noised at a random timestep. The UNet is trained to predict
        that noise given the text-encoder embedding of the caption; only UNet
        parameters are optimized. If ``images`` is empty, an error is shown and
        nothing is trained.
        """
        if not images:
            st.error("No images provided; nothing to fine-tune.")
            return self
        pipe = self.pipeline
        device = pipe.device
        dataset = DiffusionDataset(images, texts)
        # The default collate function cannot batch PIL images, so each
        # batch_size=1 batch is just its single sample dict.
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda samples: samples[0])
        optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=1e-5)
        resolution = pipe.unet.config.sample_size * pipe.vae_scale_factor
        num_train_timesteps = pipe.scheduler.config.num_train_timesteps
        pipe.unet.train()
        for epoch in range(epochs):
            with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... โš™๏ธ"):
                total_loss = 0
                for sample in dataloader:
                    optimizer.zero_grad()
                    rgb = sample["image"].convert("RGB").resize((resolution, resolution))
                    pixels = self._pixels_to_tensor(np.array(rgb)).to(device=device, dtype=pipe.vae.dtype)
                    # The VAE and text encoder are not optimized; don't build graphs for them.
                    with torch.no_grad():
                        latents = pipe.vae.encode(pixels).latent_dist.sample() * pipe.vae.config.scaling_factor
                        token_ids = pipe.tokenizer(
                            sample["text"],
                            padding="max_length",
                            max_length=pipe.tokenizer.model_max_length,
                            truncation=True,
                            return_tensors="pt",
                        ).input_ids.to(device)
                        text_embeddings = pipe.text_encoder(token_ids)[0]
                    noise = torch.randn_like(latents)
                    timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=latents.device)
                    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
                    # NOTE(review): the target is the noise itself (epsilon prediction); a
                    # v-prediction checkpoint would need a velocity target instead.
                    pred_noise = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
                    loss = torch.nn.functional.mse_loss(pred_noise.float(), noise.float())
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
        # Leave the UNet in inference mode for generate().
        pipe.unet.eval()
        st.success("Diffusion SFT Fine-tuning completed! ๐ŸŽจ")
        return self

    def save_model(self, path: str):
        """Save the whole pipeline into directory ``path``, creating it if needed."""
        with st.spinner("Saving diffusion model... ๐Ÿ’พ"):
            # Create the target itself: os.makedirs(os.path.dirname(path)) raised
            # FileNotFoundError when ``path`` had no directory component.
            os.makedirs(path, exist_ok=True)
            self.pipeline.save_pretrained(path)
        st.success(f"Diffusion model saved at {path}! โœ…")

    def generate(self, prompt: str):
        """Render ``prompt`` with 50 inference steps and return the first image."""
        return self.pipeline(prompt, num_inference_steps=50).images[0]

# Utility Functions
def generate_filename(sequence, ext="png"):
    """Build ``<sequence><timestamp>.<ext>`` from the current US/Central time.

    The timestamp uses the strftime pattern ``%d%m%Y%H%M%S%p``: day, month,
    year, 24-hour time, then the AM/PM marker.
    """
    import pytz
    from datetime import datetime
    stamp = datetime.now(tz=pytz.timezone('US/Central')).strftime("%d%m%Y%H%M%S%p")
    return "{}{}.{}".format(sequence, stamp, ext)

def get_download_link(file_path, mime_type="text/plain", label="Download"):
    """Return an HTML ``<a>`` tag that downloads ``file_path`` via an inline base64 data URI.

    The tag is rendered with ``unsafe_allow_html=True``, and file names can come
    from user input (e.g. the model name typed in the Build tab). The file name
    and label are therefore HTML-escaped.
    """
    import html
    with open(file_path, 'rb') as f:
        payload = base64.b64encode(f.read()).decode()
    file_name = html.escape(os.path.basename(file_path), quote=True)
    return f'<a href="data:{mime_type};base64,{payload}" download="{file_name}">{html.escape(label)} ๐Ÿ“ฅ</a>'

def zip_directory(directory_path, zip_path):
    """Write every file under ``directory_path`` into a deflate-compressed archive at ``zip_path``.

    Archive names are relative to the parent of ``directory_path``, so the
    top-level folder name is kept inside the zip. Empty directories are not stored.
    """
    parent = os.path.dirname(directory_path)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for folder, _subdirs, names in os.walk(directory_path):
            for full_path in (os.path.join(folder, name) for name in names):
                archive.write(full_path, os.path.relpath(full_path, parent))

def get_model_files(model_type="causal_lm"):
    """List saved model directories: ``models/*`` for "causal_lm", otherwise ``diffusion_models/*``."""
    root = "models" if model_type == "causal_lm" else "diffusion_models"
    return list(filter(os.path.isdir, glob.glob(root + "/*")))

def get_gallery_files(file_types):
    """Return the sorted names of working-directory files whose extension is in ``file_types``.

    Extensions are given without the leading dot.
    """
    matches = []
    for extension in file_types:
        matches.extend(glob.glob(f"*.{extension}"))
    matches.sort()
    return matches

def update_gallery(size=None):
    """Render saved PNG images in a two-column sidebar grid, each with a download link.

    Args:
        size: number of grid rows to show (at most ``2 * size`` images). When
            omitted, the module-level ``gallery_size`` slider value is used; it
            is only read if there is at least one image to show.
    """
    media_files = get_gallery_files(["png"])
    if not media_files:
        return
    rows = gallery_size if size is None else size
    cols = st.sidebar.columns(2)
    for idx, file in enumerate(media_files[:rows * 2]):
        with cols[idx % 2]:
            # Context manager releases the image file handle once it is rendered.
            with Image.open(file) as img:
                st.image(img, caption=file, use_container_width=True)
            st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)

# Mock Search Tool for RAG
def mock_search(query: str) -> str:
    """Stand-in web search: canned trend text when the query mentions "superhero" (any case)."""
    mentions_superhero = "superhero" in query.lower()
    if not mentions_superhero:
        return "No relevant results found."
    return "Latest trends: Gold-plated Batman statues, VR superhero battles."

class PartyPlannerAgent:
    """Demo "agentic RAG" party planner backed by a causal LM.

    ``plan_party`` gathers context from ``mock_search`` and asks the model for a
    plan. It then returns a fixed travel-time / luxury-idea table; the generated
    text is not part of the returned table.
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate(self, prompt: str) -> str:
        """Sample up to 100 new tokens after ``prompt`` and return the decoded first output sequence."""
        self.model.eval()
        with torch.no_grad():
            encoded = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
            generated = self.model.generate(**encoded, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
            return self.tokenizer.decode(generated[0], skip_special_tokens=True)

    def plan_party(self, task: str) -> pd.DataFrame:
        """Return a two-row DataFrame of venues, travel times to Wayne Manor, and luxury ideas."""
        context = mock_search("superhero party trends")
        # NOTE(review): the generated plan text is currently discarded.
        self.generate(f"Given this context: '{context}'\n{task}")
        venues = {"Wayne Manor": (42.3601, -71.0589), "New York": (40.7128, -74.0060)}
        home = venues["Wayne Manor"]
        travel_times = {}
        for venue, coords in venues.items():
            if venue != "Wayne Manor":
                travel_times[venue] = calculate_cargo_travel_time(coords, home)
        rows = [
            {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues"},
            {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles"},
        ]
        return pd.DataFrame(rows)

class CVPartyPlannerAgent:
    """Demo "agentic RAG" party planner backed by a text-to-image pipeline."""
    def __init__(self, pipeline):
        self.pipeline = pipeline

    def generate(self, prompt: str) -> Image.Image:
        """Run the pipeline for 50 inference steps and return the first image."""
        result = self.pipeline(prompt, num_inference_steps=50)
        return result.images[0]

    def plan_party(self, task: str) -> pd.DataFrame:
        """Return a fixed two-row Theme / Image Idea table.

        ``task`` and the mock search context are combined into a prompt that is
        not used yet.
        """
        context = mock_search("superhero party trends")
        _prompt = f"Given this context: '{context}'\n{task}"  # NOTE(review): currently unused
        themes = [("Batman", "Gold-plated Batman statue"), ("Avengers", "VR superhero battle scene")]
        return pd.DataFrame([{"Theme": theme, "Image Idea": idea} for theme, idea in themes])

def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0, route_factor: float = 1.1, overhead_hours: float = 1.0) -> float:
    """Estimate flight time in hours between two (latitude, longitude) points given in degrees.

    Computes the haversine great-circle distance on a 6371 km sphere and
    multiplies it by ``route_factor`` (allowance for non-direct routing). The
    result is divided by ``cruising_speed_kmh``, a fixed ``overhead_hours`` is
    added, and the total is rounded to 2 decimals. Defaults reproduce the
    original fixed 1.1 factor and 1-hour overhead.
    """
    lat1, lon1 = map(math.radians, origin_coords)
    lat2, lon2 = map(math.radians, destination_coords)
    EARTH_RADIUS_KM = 6371.0
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` just above 1 for near-antipodal points,
    # which would make asin(sqrt(a)) raise a math domain error.
    c = 2 * math.asin(math.sqrt(min(1.0, a)))
    distance = EARTH_RADIUS_KM * c
    flight_time = distance * route_factor / cruising_speed_kmh + overhead_hours
    return round(flight_time, 2)

# Main App
st.title("SFT Tiny Titans ๐Ÿš€ (Small but Mighty!)")

# Sidebar Galleries
st.sidebar.header("Media Gallery ๐ŸŽจ")
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)  # read by update_gallery()
update_gallery()

st.sidebar.subheader("Model Management ๐Ÿ—‚๏ธ")
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
_sidebar_is_lm = model_type == "Causal LM"
model_dirs = get_model_files("causal_lm" if _sidebar_is_lm else "diffusion")
selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
# The Load button is only shown once a saved model has been selected.
if selected_model != "None":
    if st.sidebar.button("Load Model ๐Ÿ“‚"):
        _builder_cls, _config_cls = (ModelBuilder, ModelConfig) if _sidebar_is_lm else (DiffusionBuilder, DiffusionConfig)
        builder = _builder_cls()
        config = _config_cls(name=os.path.basename(selected_model), base_model="unknown", size="small")
        builder.load_model(selected_model, config)
        st.session_state['builder'] = builder
        st.session_state['model_loaded'] = True
        st.rerun()

# Tabs
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Build Titan ๐ŸŒฑ", "Camera Snap ๐Ÿ“ท", "Fine-Tune Titan ๐Ÿ”ง", "Test Titan ๐Ÿงช", "Agentic RAG Party ๐ŸŒ"])

with tab1:
    st.header("Build Titan ๐ŸŒฑ")
    model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
    base_model = st.selectbox("Select Tiny Model", 
        ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else 
        ["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
    # Streamlit treats a widget whose default value changes as a new widget. A
    # default built from time.time() on every rerun therefore reset the box and
    # discarded the name the user typed (including on the Download click). Fix
    # the default once per session instead.
    if 'default_model_name' not in st.session_state:
        st.session_state['default_model_name'] = f"tiny-titan-{int(time.time())}"
    model_name = st.text_input("Model Name", st.session_state['default_model_name'])
    if st.button("Download Model โฌ‡๏ธ"):
        if not model_name.strip():
            st.error("Please enter a model name.")
        else:
            # Download the base checkpoint, save a local copy under the chosen name,
            # and make it the active model.
            config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small")
            builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
            builder.load_model(base_model, config)
            builder.save_model(config.model_path)
            st.session_state['builder'] = builder
            st.session_state['model_loaded'] = True
            st.rerun()

def _render_camera_column(cam_index, slice_count):
    """Render one camera column: a single-snapshot widget plus a multi-frame capture button.

    Snapshots are saved in the working directory under names from
    ``generate_filename`` and appended to ``st.session_state['captured_images']``.
    """
    import hashlib

    st.subheader(f"Camera {cam_index}")
    snapshot = st.camera_input(f"Take a picture - Cam {cam_index}", key=f"cam{cam_index}")
    if snapshot:
        image_bytes = snapshot.getvalue()
        if 'saved_snapshots' not in st.session_state:
            st.session_state['saved_snapshots'] = {}  # sha256 of photo bytes -> saved filename
        saved = st.session_state['saved_snapshots']
        digest = hashlib.sha256(image_bytes).hexdigest()
        if digest not in saved:
            # camera_input keeps returning the same photo on every rerun; without
            # this check each unrelated interaction wrote another copy to disk.
            filename = generate_filename(cam_index)
            with open(filename, "wb") as f:
                f.write(image_bytes)
            saved[digest] = filename
            logger.info(f"Saved snapshot from Camera {cam_index}: {filename}")
            st.session_state['captured_images'].append(filename)
        # The sidebar gallery is drawn earlier in the script, so a new snapshot
        # appears there on the next rerun; show it here immediately.
        st.image(image_bytes, caption=saved[digest], use_container_width=True)
    if st.button(f"Capture {slice_count} Frames - Cam {cam_index} ๐Ÿ“ธ"):
        frames_key = f'cam{cam_index}_frames'
        st.session_state[frames_key] = []
        # NOTE(review): each camera_input created here gets a fresh time-based key,
        # so it has no photo yet during this run; as written this loop is unlikely
        # to capture any frames.
        for i in range(slice_count):
            img = st.camera_input(f"Frame {i} - Cam {cam_index}", key=f"cam{cam_index}_frame_{i}_{time.time()}")
            if img:
                filename = generate_filename(f"{cam_index}_{i}")
                with open(filename, "wb") as f:
                    f.write(img.getvalue())
                st.session_state[frames_key].append(filename)
                logger.info(f"Saved frame {i} from Camera {cam_index}: {filename}")
                time.sleep(1.0 / slice_count)
        st.session_state['captured_images'].extend(st.session_state[frames_key])
        for frame in st.session_state[frames_key]:
            st.image(Image.open(frame), caption=frame, use_container_width=True)

with tab2:
    st.header("Camera Snap ๐Ÿ“ท (Dual Capture!)")
    slice_count = st.number_input("Image Slice Count", min_value=1, max_value=20, value=10)
    video_length = st.number_input("Video Length (seconds)", min_value=1, max_value=30, value=10)  # not used yet
    cols = st.columns(2)
    for _cam_index, _cam_col in enumerate(cols):
        with _cam_col:
            _render_camera_column(_cam_index, slice_count)

with tab3:
    st.header("Fine-Tune Titan ๐Ÿ”ง")
    active_builder = st.session_state.get('builder')
    if active_builder is None or not st.session_state.get('model_loaded', False):
        st.warning("Please build or load a Titan first! โš ๏ธ")
    elif isinstance(active_builder, ModelBuilder):
        uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
        if uploaded_csv and st.button("Fine-Tune with Uploaded CSV ๐Ÿ”„"):
            csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
            with open(csv_path, "wb") as f:
                # getvalue() returns the whole upload regardless of the read position.
                f.write(uploaded_csv.getvalue())
            new_model_name = f"{active_builder.config.name}-sft-{int(time.time())}"
            new_config = ModelConfig(name=new_model_name, base_model=active_builder.config.base_model, size="small")
            active_builder.config = new_config
            active_builder.fine_tune_sft(csv_path)
            active_builder.save_model(new_config.model_path)
            zip_path = f"{new_config.model_path}.zip"
            zip_directory(new_config.model_path, zip_path)
            st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
    elif isinstance(active_builder, DiffusionBuilder):
        captured_images = get_gallery_files(["png"])
        if len(captured_images) < 2:
            st.info("Capture at least two images in the Camera Snap tab to build a fine-tuning dataset.")
        else:
            # slice_count is the "Image Slice Count" set in the Camera Snap tab.
            demo_data = [{"image": img, "text": f"Superhero {os.path.basename(img).split('.')[0]}"} for img in captured_images[:slice_count]]
            edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
            # Rows added in the editor can be left blank; only complete rows are trained on.
            complete_rows = edited_data.dropna(subset=["image", "text"])
            if st.button("Fine-Tune with Dataset ๐Ÿ”„"):
                if complete_rows.empty:
                    st.error("The dataset has no rows with both an image and a caption.")
                else:
                    images = [Image.open(path) for path in complete_rows["image"]]
                    texts = list(complete_rows["text"])
                    new_model_name = f"{active_builder.config.name}-sft-{int(time.time())}"
                    new_config = DiffusionConfig(name=new_model_name, base_model=active_builder.config.base_model, size="small")
                    active_builder.config = new_config
                    active_builder.fine_tune_sft(images, texts)
                    active_builder.save_model(new_config.model_path)
                    zip_path = f"{new_config.model_path}.zip"
                    zip_directory(new_config.model_path, zip_path)
                    st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
            # Serve the current dataset from memory rather than writing a new
            # timestamped CSV to disk on every rerun.
            st.download_button(
                "Download SFT Dataset CSV",
                data=edited_data[["image", "text"]].to_csv(index=False),
                file_name=f"sft_dataset_{int(time.time())}.csv",
                mime="text/csv",
            )

with tab4:
    st.header("Test Titan ๐Ÿงช")
    _test_ready = 'builder' in st.session_state and st.session_state.get('model_loaded', False)
    if not _test_ready:
        st.warning("Please build or load a Titan first! โš ๏ธ")
    elif isinstance(st.session_state['builder'], ModelBuilder):
        # Text model: generate a continuation of the prompt.
        test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
        if st.button("Run Test โ–ถ๏ธ"):
            st.write(f"**Generated Response**: {st.session_state['builder'].evaluate(test_prompt)}")
    elif isinstance(st.session_state['builder'], DiffusionBuilder):
        # Diffusion model: render an image for the prompt.
        test_prompt = st.text_area("Enter Test Prompt", "Neon Batman")
        if st.button("Run Test โ–ถ๏ธ"):
            st.image(st.session_state['builder'].generate(test_prompt), caption="Generated Image")

with tab5:
    st.header("Agentic RAG Party ๐ŸŒ")
    _rag_ready = 'builder' in st.session_state and st.session_state.get('model_loaded', False)
    if not _rag_ready:
        st.warning("Please build or load a Titan first! โš ๏ธ")
    elif isinstance(st.session_state['builder'], ModelBuilder):
        if st.button("Run NLP RAG Demo ๐ŸŽ‰"):
            _lm_builder = st.session_state['builder']
            agent = PartyPlannerAgent(_lm_builder.model, _lm_builder.tokenizer)
            plan_df = agent.plan_party("Plan a luxury superhero-themed party at Wayne Manor.")
            st.dataframe(plan_df)
    elif isinstance(st.session_state['builder'], DiffusionBuilder):
        if st.button("Run CV RAG Demo ๐ŸŽ‰"):
            agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
            plan_df = agent.plan_party("Generate images for a luxury superhero-themed party.")
            st.dataframe(plan_df)
            # One generated image per planned idea, captioned with its theme.
            for _theme, _idea in zip(plan_df["Theme"], plan_df["Image Idea"]):
                st.image(agent.generate(_idea), caption=f"{_theme} - {_idea}")

# Display Logs
st.sidebar.subheader("Action Logs ๐Ÿ“œ")
# st.empty() holds a single element, so writing several records into it showed
# only the last one; a container keeps every record captured during this run.
_log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
with st.sidebar.container():
    for record in log_records:
        # Format here instead of reading record.asctime / record.message, which
        # exist only if some handler has already formatted the record.
        st.write(_log_formatter.format(record))

# The gallery is already rendered near the top of the sidebar (update_gallery()
# after the "Gallery Size" slider); calling it again here drew a second copy.