import gc
import json
import os
import sys

import numpy as np
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# set visible devices to 0
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# protect these imports from formatters that would hoist them above the sys.path tweak
if True:
    import torch
    from optimum.quanto import freeze
    from diffusers import FluxTransformer2DModel, FluxPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler
    from toolkit.util.quantize import quantize, get_qtype
    from transformers import T5EncoderModel, T5TokenizerFast, CLIPTextModel, CLIPTokenizer
    from torchvision import transforms

qtype = "qfloat8"
dtype = torch.bfloat16

# base_model_path = "black-forest-labs/FLUX.1-dev"
base_model_path = "ostris/Flex.1-alpha"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

prompt = "Photo of a man and a woman in a park, sunny day"

output_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output")
output_path = os.path.join(output_root, "flex_timestep_weights.json")
img_output_path = os.path.join(output_root, "flex_timestep_weights.png")
os.makedirs(output_root, exist_ok=True)

quantization_type = get_qtype(qtype)


def flush():
    torch.cuda.empty_cache()
    gc.collect()


pil_to_tensor = transforms.ToTensor()

with torch.no_grad():
    print("Loading Transformer...")
    transformer = FluxTransformer2DModel.from_pretrained(
        base_model_path,
        subfolder="transformer",
        torch_dtype=dtype,
    )
    transformer.to(device, dtype=dtype)

    print("Quantizing Transformer...")
    quantize(transformer, weights=quantization_type)
    freeze(transformer)
    flush()

    print("Loading Scheduler...")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")

    print("Loading Autoencoder...")
    vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
    vae.to(device, dtype=dtype)
    flush()

    print("Loading Text Encoder...")
    tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2")
    text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
    text_encoder_2.to(device, dtype=dtype)

    print("Quantizing Text Encoder...")
    quantize(text_encoder_2, weights=get_qtype(qtype))
    freeze(text_encoder_2)
    flush()

    print("Loading CLIP...")
    text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
    text_encoder.to(device, dtype=dtype)

    print("Making pipe...")
    # Build the pipeline without the two quantized modules, then attach them
    pipe: FluxPipeline = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    pipe.to(device, dtype=dtype)

    print("Encoding prompt...")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        prompt,
        prompt_2=prompt,
        device=device,
    )

    generator = torch.manual_seed(42)

    height = 1024
    width = 1024

    print("Generating image...")

    # Work around a dtype drift bug in diffusers/torch: keep the latents in
    # the working dtype between denoising steps
    def callback_on_step_end(pipe, i, t, callback_kwargs):
        latents = callback_kwargs["latents"]
        if latents.dtype != dtype:
            latents = latents.to(dtype)
        return {"latents": latents}

    img = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        height=height,
        width=width,
        num_inference_steps=30,
        guidance_scale=3.5,
        generator=generator,
        callback_on_step_end=callback_on_step_end,
    ).images[0]

    img.save(img_output_path)
    print(f"Image saved to {img_output_path}")
{img_output_path}") print("Encoding image...") # img is a PIL image. convert it to a -1 to 1 tensor img = pil_to_tensor(img) img = img.unsqueeze(0) # add batch dimension img = img * 2 - 1 # convert to -1 to 1 range img = img.to(device, dtype=dtype) latents = vae.encode(img).latent_dist.sample() shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0 latents = vae.config['scaling_factor'] * (latents - shift) num_channels_latents = pipe.transformer.config.in_channels // 4 l_height = 2 * (int(height) // (pipe.vae_scale_factor * 2)) l_width = 2 * (int(width) // (pipe.vae_scale_factor * 2)) packed_latents = pipe._pack_latents(latents, 1, num_channels_latents, l_height, l_width) packed_latents, latent_image_ids = pipe.prepare_latents( 1, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, packed_latents, ) print("Calculating timestep weights...") torch.manual_seed(8675309) noise = torch.randn_like(packed_latents, device=device, dtype=dtype) # Create linear timesteps from 1000 to 0 num_train_timesteps = 1000 timesteps_torch = torch.linspace(1000, 1, num_train_timesteps, device='cpu') timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) timestep_weights = torch.zeros(num_train_timesteps, dtype=torch.float32, device=device) guidance = torch.full([1], 1.0, device=device, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]) pbar = tqdm(range(num_train_timesteps), desc="loss: 0.000000 scaler: 0.0000") for i in pbar: timestep = timesteps[i:i+1].to(device) t_01 = (timestep / 1000).to(device) t_01 = t_01.reshape(-1, 1, 1) noisy_latents = (1.0 - t_01) * packed_latents + t_01 * noise noise_pred = pipe.transformer( hidden_states=noisy_latents, # torch.Size([1, 4096, 64]) timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, return_dict=False, )[0] target = noise - packed_latents loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float()) loss = loss # determine scaler to multiply loss by to make it 1 scaler = 1.0 / (loss + 1e-6) timestep_weights[i] = scaler pbar.set_description(f"loss: {loss.item():.6f} scaler: {scaler.item():.4f}") print("normalizing timestep weights...") # normalize the timestep weights so they are a mean of 1.0 timestep_weights = timestep_weights / timestep_weights.mean() timestep_weights = timestep_weights.cpu().numpy().tolist() print("Saving timestep weights...") with open(output_path, 'w') as f: json.dump(timestep_weights, f) print(f"Timestep weights saved to {output_path}") print("Done!") flush()