import gradio as gr
import os.path as osp
import torch
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from safetensors.torch import safe_open
from PIL import Image

from semanticist.engine.trainer_utils import instantiate_from_config
from semanticist.stage2.gpt import GPT_models
from semanticist.stage2.generate import generate
from semanticist.utils.datasets import vae_transforms
from imagenet_classes import imagenet_classes

transform = vae_transforms('test')


def norm_ip(img, low, high):
    """Clamp `img` to [low, high] in place, then rescale to [0, 1]."""
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))


def norm_range(t, value_range):
    """Normalize with an explicit range if given, else with t's own min/max."""
    if value_range is not None:
        norm_ip(t, value_range[0], value_range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))


def convert_np(img):
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 numpy array."""
    return img.mul(255).add_(0.5).clamp_(0, 255) \
        .permute(1, 2, 0).to("cpu", torch.uint8).numpy()


def convert_PIL(img):
    """Convert a CHW float tensor in [0, 1] to a PIL image."""
    return Image.fromarray(convert_np(img))


def norm_slots(slots):
    """Standardize slot vectors to zero mean and unit std along the last dim."""
    mean = torch.mean(slots, dim=-1, keepdim=True)
    std = torch.std(slots, dim=-1, keepdim=True)
    return (slots - mean) / std


def load_state_dict(state_dict, model):
    """Helper to load a state dict with proper prefix handling."""
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    # Remove the '_orig_mod.' prefix left by torch.compile, if present.
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    # print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")


def load_safetensors(path, model):
    """Helper to load a safetensors checkpoint."""
    with safe_open(path, framework="pt", device="cpu") as f:
        state_dict = {k: f.get_tensor(k) for k in f.keys()}
    load_state_dict(state_dict, model)


def load_checkpoint(ckpt_path, model):
    """Load a checkpoint from a directory, a .safetensors file, or a torch file."""
    if ckpt_path is None or not osp.exists(ckpt_path):
        return
    if osp.isdir(ckpt_path):
        # ckpt_path is something like 'path/to/models/step10/'
        model_path = osp.join(ckpt_path, "model.safetensors")
        if osp.exists(model_path):
            load_safetensors(model_path, model)
    else:
        # ckpt_path is something like 'path/to/models/step10.pt'
        if ckpt_path.endswith(".safetensors"):
            load_safetensors(ckpt_path, model)
        else:
            state_dict = torch.load(ckpt_path, map_location="cpu")
            load_state_dict(state_dict, model)
    print(f"Loaded checkpoint from {ckpt_path}")


device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Is CUDA available: {torch.cuda.is_available()}")
if device == 'cuda':
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

ckpt_path = hf_hub_download(repo_id='tennant/semanticist',
                            filename="semanticist_ar_gen_L.pkl",
                            cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
config_path = 'configs/autoregressive_xl.yaml'
cfg = OmegaConf.load(config_path)
params = cfg.trainer.params

ae_model = instantiate_from_config(params.ae_model).to(device)
ae_model_path = hf_hub_download(repo_id='tennant/semanticist',
                                filename="semanticist_tok_XL.pkl",
                                cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
load_checkpoint(ae_model_path, ae_model)
ae_model.eval()

gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(device)
load_checkpoint(ckpt_path, gpt_model)
gpt_model.eval()
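
# Optional sanity check (not part of the original flow): print rough parameter
# counts for the two loaded models. A minimal sketch that only assumes both
# objects are standard torch.nn.Modules; uncomment to run.
# n_ae = sum(p.numel() for p in ae_model.parameters())
# n_gpt = sum(p.numel() for p in gpt_model.parameters())
# print(f"ae_model: {n_ae / 1e6:.1f}M params, gpt_model: {n_gpt / 1e6:.1f}M params")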

def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
    """Decode `slots` several times, keeping only the first n tokens each time."""
    n_slots_inf = []
    for num_slots_to_inference in nums:
        drop_mask = model.nested_sampler(slots.shape[0], device, num_slots_to_inference)
        recon_n = model.sample(slots, drop_mask=drop_mask, cfg=cfg)
        n_slots_inf.append(recon_n)
    return [convert_np(recon[0]) for recon in n_slots_inf]


num_slots = params.ae_model.params.num_slots
slot_dim = params.ae_model.params.slot_dim


def generate_from_class(class_id, cfg_scale):
    """Generate slot tokens for an ImageNet class with the AR model."""
    with torch.no_grad():
        dtype = torch.float
        # The AR model is trained to generate only the first 32 tokens.
        num_slots_to_gen = 32
        with torch.autocast(device, dtype=dtype):
            slots_gen = generate(
                gpt_model,
                torch.tensor([class_id]).to(device),
                num_slots_to_gen,
                cfg_scale=cfg_scale,
                cfg_schedule="linear",
            )
        # Pad the remaining positions with the decoder's learned null condition.
        if num_slots_to_gen < num_slots:
            null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
            null_slots = null_slots[:, num_slots_to_gen:, :]
            slots_gen = torch.cat([slots_gen, null_slots], dim=1)
        return slots_gen
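
# Optional: a minimal sketch of exercising the two helpers above without the
# Gradio UI. The class index (207) and CFG scale are arbitrary example values,
# assuming the standard ImageNet-1k label order; uncomment to smoke-test.
# slots = generate_from_class(207, cfg_scale=4.0)
# frames = viz_diff_slots(ae_model, slots, [1, 4, 16, 32], cfg=4.0)
# Image.fromarray(frames[-1]).save("sample_class207.png")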

with gr.Blocks() as demo:
    with gr.Row():
        # First column - input and configuration
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            # ImageNet class selection instead of an image input
            imagenet_classes = dict(enumerate(imagenet_classes))
            class_choices = [f"{id}: {name}" for id, name in imagenet_classes.items()]
            # Dropdown for class selection
            class_dropdown = gr.Dropdown(
                choices=class_choices,
                label="Select ImageNet Class",
                value=class_choices[0] if class_choices else None
            )
            # Option to enter the class ID directly
            class_id_input = gr.Number(
                label="Or enter class ID directly (0-999)",
                value=0,
                minimum=0,
                maximum=999,
                step=1
            )
            with gr.Group():
                gr.Markdown("### Configuration")
                show_gallery = gr.Checkbox(label="Show Gallery", value=True)
                slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
                labels_input = gr.Textbox(
                    label="Number of tokens to reconstruct (comma-separated)",
                    value="1, 2, 4, 8, 16",
                    placeholder="Enter comma-separated numbers for the number of slots to use"
                )

        # Second column - output (the gallery is conditionally visible)
        with gr.Column(scale=1):
            gr.Markdown("## Output")
            # Container for conditional rendering
            with gr.Group(visible=True) as gallery_container:
                gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
            # Always-visible output image
            output_image = gr.Image(label="Generated Image", type="numpy")

    # Handle form submission
    submit_btn = gr.Button("Generate")

    # Define the processing logic
    def update_outputs(class_selection, class_id, show_gallery_value, slider_value, labels_text):
        # Determine which class to use: the dropdown selection wins over the direct input.
        if class_selection:
            selected_class_id = int(class_selection.split(":")[0])
        else:
            selected_class_id = int(class_id)

        # Toggle the gallery container via a returned update; mutating
        # gallery_container.visible after the UI is built has no effect.
        gallery_update = gr.update(visible=show_gallery_value)

        # Parse the token counts from the text input, falling back to
        # defaults if none are provided or parsing fails.
        try:
            if labels_text and "," in labels_text:
                labels = [int(label.strip()) for label in labels_text.split(",")]
            else:
                labels = [1, 4, 16, 64, 256]
        except ValueError:
            labels = [1, 4, 16, 64, 256]
        # Pad to at least three entries for the gallery layout.
        while len(labels) < 3:
            labels.append(256)

        # Generate slots for the selected class, then decode the full reconstruction.
        slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
        recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
        # Always compute the token-count decomposition for potential gallery display.
        model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)

        if not show_gallery_value:
            # Only the single generated image is shown.
            return gallery_update, [], recon
        # Pair each decomposition with its token-count caption.
        gallery_images = [
            (recon, f'Generated from class {selected_class_id}'),
        ] + [(img, f'Gen. with {label} tokens') for img, label in zip(model_decompose, labels)]
        return gallery_update, gallery_images, recon

    # Connect the inputs and outputs
    submit_btn.click(
        fn=update_outputs,
        inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image]
    )

    # Also update visibility when the checkbox changes
    show_gallery.change(
        fn=lambda value: gr.update(visible=value),
        inputs=[show_gallery],
        outputs=[gallery_container]
    )

    # Add examples
    examples = [
        # ["0: tench, Tinca tinca", 0, True, 4.0, "1,2,4,8,16"],
        ["1: goldfish", 1, True, 4.0, "1,2,4,8,16"],
        # ["2: great white shark, white shark", 2, True, 4.0, "1,2,4,8,16"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image],
        fn=update_outputs,
        cache_examples=False
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()
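
# For remote use, Gradio's launch() also accepts standard options such as
# share=True (temporary public URL) and server_name="0.0.0.0" (bind all
# interfaces). Shown commented out, since the original launches with defaults:
# demo.launch(share=True, server_name="0.0.0.0")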