import os
from typing import TYPE_CHECKING, List, Optional

import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from .src.models.transformers import OmniGen2Transformer2DModel
from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
from .src.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler,
)
from PIL import Image
from transformers import (
    CLIPProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
import torch.nn.functional as F
if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {"num_train_timesteps": 1000}

BASE_MODEL_PATH = "OmniGen2/OmniGen2"
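
# OmniGen2Model wraps the OmniGen2 components (the Qwen2.5 VL MLLM used as the
# text encoder, the OmniGen2 transformer, and an AutoencoderKL VAE) behind
# ai-toolkit's BaseModel interface so the trainer can load, quantize, sample
# from, and save the model like any other architecture.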
class OmniGen2Model(BaseModel):
    arch = "omnigen2"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["OmniGen2Transformer2DModel"]
        self._control_latent = None

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        return 16
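
    # Loading flow: the Qwen2.5 VL processor / MLLM and the VAE come from
    # extras_name_or_path, while the transformer comes from name_or_path, so a
    # fine-tuned transformer checkpoint can be paired with the stock extras.
    # Optional quanto quantization and low_vram CPU offloading happen here too.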
    def load_model(self):
        dtype = self.torch_dtype
        self.print_and_status_update("Loading OmniGen2 model")
        # will be updated if we detect an existing checkpoint in the training folder
        model_path = self.model_config.name_or_path
        extras_path = self.model_config.extras_name_or_path

        scheduler = OmniGen2Model.get_train_scheduler()

        self.print_and_status_update("Loading Qwen2.5 VL")
        processor = CLIPProcessor.from_pretrained(
            extras_path, subfolder="processor", use_fast=True
        )
        mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            extras_path, subfolder="mllm", torch_dtype=torch.bfloat16
        )
        mllm.to(self.device_torch, dtype=dtype)
        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Qwen2.5 VL model")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(mllm, weights=quantization_type)
            freeze(mllm)

        if self.low_vram:
            # unload it for now
            mllm.to("cpu")

        flush()

        self.print_and_status_update("Loading transformer")
        transformer = OmniGen2Transformer2DModel.from_pretrained(
            model_path, subfolder="transformer", torch_dtype=torch.bfloat16
        )
        if not self.low_vram:
            transformer.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing transformer")
            quantization_type = get_qtype(self.model_config.qtype)
            quantize(transformer, weights=quantization_type)
            freeze(transformer)

        if self.low_vram:
            # unload it for now
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Loading vae")
        vae = AutoencoderKL.from_pretrained(
            extras_path, subfolder="vae", torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        flush()

        self.print_and_status_update("Loading Qwen2.5 VLProcessor")

        flush()

        if self.low_vram:
            self.print_and_status_update("Moving everything to device")
            # move it all back
            transformer.to(self.device_torch, dtype=dtype)
            vae.to(self.device_torch, dtype=dtype)
            mllm.to(self.device_torch, dtype=dtype)

        # set to eval mode
        # transformer.eval()
        vae.eval()
        mllm.eval()
        mllm.requires_grad_(False)

        pipe: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            mllm=mllm,
            processor=processor,
        )

        flush()

        text_encoder_list = [mllm]
        tokenizer_list = [processor]

        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder_list  # list of text encoders
        self.tokenizer = tokenizer_list  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

        self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
            transformer.config.axes_dim_rope,
            transformer.config.axes_lens,
            theta=10000,
        )

        self.print_and_status_update("Model Loaded")
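
    # Inference uses OmniGen2's own flow-match Euler scheduler (with dynamic
    # time shift) rather than the CustomFlowMatchEulerDiscreteScheduler used
    # for training, while sharing the already-loaded transformer, VAE, and MLLM.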
    def get_generation_pipeline(self):
        scheduler = OmniFlowMatchEuler(
            dynamic_time_shift=True, num_train_timesteps=1000
        )
        pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=self.model,
            vae=self.vae,
            scheduler=scheduler,
            mllm=self.text_encoder[0],
            processor=self.tokenizer[0],
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: OmniGen2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        input_images = []
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            # resize to width and height
            if control_img.size != (gen_config.width, gen_config.height):
                control_img = control_img.resize(
                    (gen_config.width, gen_config.height), Image.BILINEAR
                )
            input_images = [control_img]

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_attention_mask=conditional_embeds.attention_mask,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_attention_mask=unconditional_embeds.attention_mask,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            text_guidance_scale=gen_config.guidance_scale,
            image_guidance_scale=1.0,  # reference image guidance scale. Add this for controls
            latents=gen_config.latents,
            align_res=False,
            generator=generator,
            input_images=input_images,
            **extra,
        ).images[0]
        return img
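
    # The trainer passes timesteps on a 0-1000 scale where larger means noisier,
    # while the OmniGen2 transformer expects a 0-1 value running in the opposite
    # direction, so the timestep is rescaled and flipped (1 - t / 1000) below.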
    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        try:
            timestep = timestep.expand(latent_model_input.shape[0]).to(
                latent_model_input.dtype
            )
        except Exception:
            pass

        timesteps = timestep / 1000  # convert to 0 to 1 scale
        # timestep for model starts at 0 instead of 1. So we need to reverse them
        timestep = 1 - timesteps

        model_pred = self.model(
            latent_model_input,
            timestep,
            text_embeddings.text_embeds,
            self.freqs_cis,
            text_embeddings.attention_mask,
            ref_image_hidden_states=self._control_latent,
        )

        return model_pred
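
    # Control / reference images are VAE-encoded here and cached on the model as
    # a per-sample list of latents, which get_noise_prediction then forwards to
    # the transformer as ref_image_hidden_states.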
    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
        # reset the control latent
        self._control_latent = None
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # not packed here, so just pass them through so they can be packed later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(
                    self.vae_device_torch, dtype=self.torch_dtype
                )

                # if it does not match the size of batch.tensor (bs, ch, h, w), resize it
                # todo, we may not need to do this, check
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width

                if (
                    control_tensor.shape[2] != target_h
                    or control_tensor.shape[3] != target_w
                ):
                    control_tensor = F.interpolate(
                        control_tensor, size=(target_h, target_w), mode="bilinear"
                    )

                control_latent = self.encode_images(control_tensor).to(
                    latents.device, latents.dtype
                )
                self._control_latent = [
                    [x.squeeze(0)]
                    for x in torch.chunk(control_latent, control_latent.shape[0], dim=0)
                ]

        return latents.detach()
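
    # Prompts are wrapped in the pipeline's chat template before being encoded
    # by the Qwen2.5 VL MLLM; embeddings are computed with max_sequence_length=256
    # and returned together with their attention mask.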
    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 256
        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt=prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            max_sequence_length=max_sequence_length,
        )
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_attention_mask
        return pe

    def get_model_has_grad(self):
        # return from a weight if it has grad
        return False

    def get_te_has_grad(self):
        # the Qwen2.5 VL text encoder is kept frozen, so it never has grads
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)
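
    # Flow-matching target: the model is trained to predict the velocity toward
    # the clean latents, hence (latents - noise) rather than the more common
    # (noise - latents); the sign is consistent with the flipped timestep
    # direction used in get_noise_prediction.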
    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        # return (noise - batch.latents).detach()
        return (batch.latents - noise).detach()

    def get_transformer_block_names(self) -> Optional[List[str]]:
        # omnigen2 has a few block groups: noise_refiner, ref_image_refiner,
        # context_refiner, and layers. Include all but ref_image_refiner until
        # image refiner training is added.
        if self.model_config.model_kwargs.get("use_image_refiner", False):
            return ["noise_refiner", "context_refiner", "ref_image_refiner", "layers"]
        return ["noise_refiner", "context_refiner", "layers"]
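
    # LoRA key remapping between ai-toolkit and ComfyUI naming. Illustrative
    # (hypothetical) key:
    #   transformer.layers.0.attn.to_q.lora_A.weight      (ai-toolkit)
    #   diffusion_model.layers.0.attn.to_q.lora_A.weight  (ComfyUI)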
    def convert_lora_weights_before_save(self, state_dict):
        # keys currently start with "transformer." but need to start with
        # "diffusion_model." for ComfyUI
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # saved with the "diffusion_model." prefix but ai-toolkit expects "transformer."
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        return "omnigen2"
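
# Illustrative summary of the ModelConfig fields this class reads (the actual
# schema lives in toolkit.config_modules.ModelConfig):
#   name_or_path:        transformer weights (e.g. BASE_MODEL_PATH)
#   extras_name_or_path: processor / mllm / vae subfolders
#   quantize, qtype:           quantize the transformer with optimum.quanto
#   quantize_te, qtype_te:     quantize the Qwen2.5 VL text encoder
#   model_kwargs.use_image_refiner: also target the ref_image_refiner blocks for LoRA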