import os import copy import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer from diffusers import DPMSolverMultistepScheduler, AutoencoderDC, FlowMatchEulerDiscreteScheduler from safetensors.torch import load_file from .qwen2_5_vit import Qwen2_5_VisionTransformer from .modeling_qwen2_native import Qwen2ForCausalLM from .sana_transformer import SanaTransformer2DModel from .sana_loss import SANALoss from copy import deepcopy from IPython import embed import logging logger = logging.getLogger(__name__) from .Templates_native import ( DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN, DEFAULT_GEN_IMAGE_PATCH_TOKEN, DEFAULT_GEN_IM_START_TOKEN, DEFAULT_GEN_IM_END_TOKEN, PLACEHOLDER_IMAGE_TOKEN_IN_TEXT, DEFAULT_END_OF_CHUNK_TOKEN, DEFAULT_END_OF_AUDIO_TOKEN, DEFAULT_AUDIO_PATCH_TOKEN, DEFAULT_AU_START_TOKEN, DEFAULT_AU_END_TOKEN, DEFAULT_GEN_AUDIO_PATCH_TOKEN, DEFAULT_GEN_AU_START_TOKEN, DEFAULT_GEN_AU_END_TOKEN, PLACEHOLDER_AUDIO_TOKEN_IN_TEXT, DEFAULT_FRAME_PATCH_TOKEN, interleave_tokens, ) additional_special_tokens_qwen2 = [ "[item]", "", "", "", "", "", "
", "", "", "", "", "", "", "", "" ] def expand_gen_embeds_as_learnable_scales( clip_feat, image_grid_thw, scales, isgen_indicators, learnable_queries_1d, ): resized_clip_feat = [] new_image_grid_thw = [] assert image_grid_thw.ndim == 2 bsz = len(image_grid_thw) assert clip_feat.ndim == 2 feat_dim = clip_feat.shape[1] n_clip_token_cum = 0 assert len(isgen_indicators) == bsz #assert image_grid_thw.ndim == 3 for bsid in range(bsz): thw = image_grid_thw[bsid].tolist() assert thw[0] == 1 assert thw[1] % 2 == 0 assert thw[2] % 2 == 0 clip_h = thw[1] // 2 clip_w = thw[2] // 2 n_clip_token = clip_h * clip_w assert n_clip_token_cum + n_clip_token <= clip_feat.shape[0] if isgen_indicators[bsid]: for scale in scales: clip_feat_one = torch.zeros(scale * scale, feat_dim).to(clip_feat.dtype).to(clip_feat.device) resized_clip_feat.append(clip_feat_one) if learnable_queries_1d: new_image_grid_thw.append([1, 2, scale * scale * 2]) else: new_image_grid_thw.append([1, scale * 2, scale * 2]) else: clip_feat_one = clip_feat[n_clip_token_cum : n_clip_token_cum + n_clip_token, :] resized_clip_feat.append(clip_feat_one) new_image_grid_thw.append(thw) n_clip_token_cum += n_clip_token assert n_clip_token_cum == clip_feat.shape[0] encoder_hidden_states = torch.cat(resized_clip_feat, dim=0) return encoder_hidden_states, torch.tensor(new_image_grid_thw, dtype=image_grid_thw.dtype).to(image_grid_thw.device) def append_understand_embeds_with_learnable_scales( clip_feat, image_grid_thw, scales, dtype, device, feat_dim, learnable_queries_1d, ): if clip_feat is not None: assert feat_dim == clip_feat.shape[-1] assert dtype == clip_feat.dtype assert device == clip_feat.device assert clip_feat.ndim == 2 else: assert image_grid_thw is None fake_learnable_embed = torch.zeros(256, feat_dim).to(dtype).to(device) clip_feat = torch.cat([clip_feat, fake_learnable_embed], dim=0) if clip_feat is not None else fake_learnable_embed fake_image_grid_thw = torch.tensor([[1, 32, 32]], dtype=torch.long).to(device) image_grid_thw = torch.cat([image_grid_thw, fake_image_grid_thw], dim=0) if image_grid_thw is not None else fake_image_grid_thw return expand_gen_embeds_as_learnable_scales( clip_feat, image_grid_thw, scales, isgen_indicators=[False for _ in range(image_grid_thw.shape[0]-1)] + [True], learnable_queries_1d=learnable_queries_1d, ) def expand_gen_input_ids_as_learnable_scales( text_ids, labels, attention_mask, scales, start_token_id, end_token_id, patch_token_id, num_learnable_queries, ): assert text_ids.ndim == 2 assert text_ids.shape == labels.shape assert text_ids.shape == attention_mask.shape default_scaled_tokens = [] for scale in scales: default_scaled_tokens.append(start_token_id) default_scaled_tokens.extend([patch_token_id for _ in range(scale * scale)]) default_scaled_tokens.append(end_token_id) text_ids_list = text_ids.cpu().tolist() labels_list = labels.cpu().tolist() attention_mask_list = attention_mask.cpu().tolist() new_text_ids_list = [] new_labels_list = [] new_attention_mask_list = [] for text_ids_one_batch, labels_one_batch, attention_mask_one_batch in zip(text_ids_list, labels_list, attention_mask_list): assert len(text_ids_one_batch) == len(labels_one_batch) assert len(text_ids_one_batch) == len(attention_mask_one_batch) start_idx = [i for i, j in enumerate(labels_one_batch) if j == start_token_id] end_idx = [i for i, j in enumerate(labels_one_batch) if j == end_token_id] assert len(start_idx) == 1, start_idx assert len(end_idx) == 1, end_idx start_idx = start_idx[0] end_idx = end_idx[0] assert end_idx - start_idx == num_learnable_queries + 1, (start_idx, end_idx) assert text_ids_one_batch[start_idx] == start_token_id and text_ids_one_batch[end_idx] == end_token_id text_ids_one_batch[start_idx: end_idx+1] = deepcopy(default_scaled_tokens) labels_one_batch[start_idx: end_idx+1] = deepcopy(default_scaled_tokens) attention_mask_one_batch[start_idx: end_idx+1] = [1 for _ in range(len(default_scaled_tokens))] new_text_ids_list.append(text_ids_one_batch) new_labels_list.append(labels_one_batch) new_attention_mask_list.append(attention_mask_one_batch) return ( torch.tensor(new_text_ids_list, dtype=text_ids.dtype).to(text_ids.device), torch.tensor(new_labels_list, dtype=labels.dtype).to(labels.device), torch.tensor(new_attention_mask_list, dtype=attention_mask.dtype).to(attention_mask.device) ) def append_input_ids_with_learnable_scales( text_ids, scales, start_token_id, end_token_id, patch_token_id, ): assert text_ids.shape[0] == 1 assert text_ids[0][-1].tolist() == start_token_id labels = torch.cat([ torch.ones_like(text_ids[:,:-1]) * 0 - 100, torch.tensor([[start_token_id, patch_token_id, end_token_id]]).to(text_ids.dtype).to(text_ids.device), ], dim=1) text_ids = torch.cat([ text_ids, torch.tensor([[patch_token_id, end_token_id]]).to(text_ids.dtype).to(text_ids.device), ], dim=1) assert labels.shape == text_ids.shape attention_mask = torch.ones_like(text_ids) text_ids, labels, attention_mask = expand_gen_input_ids_as_learnable_scales( text_ids, labels, attention_mask, scales, start_token_id, end_token_id, patch_token_id, num_learnable_queries=1, ) return text_ids, labels class Ming_Uni_Inference(nn.Module): def __init__(self, inference_model_path): super(Ming_Uni_Inference, self).__init__() self.inference_model_path = inference_model_path print('loading from pretrained:',inference_model_path) self.load_from_huggingface() #embed() def init_tokens(self): num_query_token=2560 num_query_token_video=64 num_query_token_audio=32 num_decoder_image_token=1024 num_decoder_audio_token=512 self.glm_tokenizer.add_special_tokens( {"additional_special_tokens": additional_special_tokens_qwen2} ) num_new_tokens = self.glm_tokenizer.add_tokens( interleave_tokens, special_tokens=True, ) logger.warning("init_mm_specail_tokens: generation_num_tokens = {}".format(num_new_tokens)) self.glm_config.first_signal_token = self.glm_tokenizer.convert_tokens_to_ids("[IMG0]") self.glm_config.image_start_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN) self.glm_config.image_end_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN) self.glm_config.image_patch_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN) self.glm_config.video_start_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_VID_START_TOKEN) self.glm_config.video_end_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_VID_END_TOKEN) self.glm_config.gen_image_start_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_GEN_IM_START_TOKEN) self.glm_config.gen_image_end_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_GEN_IM_END_TOKEN) self.glm_config.gen_image_patch_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_GEN_IMAGE_PATCH_TOKEN) self.glm_config.placeholder_image_token_in_text = self.glm_tokenizer.convert_tokens_to_ids( PLACEHOLDER_IMAGE_TOKEN_IN_TEXT ) # noqa self.glm_config.end_of_chunk_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_END_OF_CHUNK_TOKEN) self.glm_config.end_of_audio_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_END_OF_AUDIO_TOKEN) self.glm_config.audio_start_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_AU_START_TOKEN) self.glm_config.audio_end_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_AU_END_TOKEN) self.glm_config.audio_patch_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_AUDIO_PATCH_TOKEN) self.glm_config.gen_audio_start_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_GEN_AU_START_TOKEN) self.glm_config.gen_audio_end_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_GEN_AU_END_TOKEN) self.glm_config.gen_audio_patch_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_GEN_AUDIO_PATCH_TOKEN) self.glm_config.placeholder_audio_token_in_text = self.glm_tokenizer.convert_tokens_to_ids( PLACEHOLDER_AUDIO_TOKEN_IN_TEXT ) # noqa self.glm_config.frame_patch_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_FRAME_PATCH_TOKEN) self.glm_config.video_patch_token = self.glm_tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN) self.glm_config.num_image_token = num_query_token self.glm_config.num_video_token = num_query_token_video self.glm_config.num_audio_token = num_query_token_audio self.glm_config.num_decoder_image_token = num_decoder_image_token self.glm_config.num_decoder_audio_token = num_decoder_audio_token def load_from_huggingface(self): # Load Qwen2_5_vit self.eva_encoder = Qwen2_5_VisionTransformer.from_pretrained( os.path.join(self.inference_model_path, 'qwen2_5_vit'), attn_implementation="flash_attention_2", trust_remote_code=True, force_download=True, ) # Load Qwen2_5_llm (GLM model) self.glm_tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.inference_model_path, 'qwen2_5_llm')) self.glm_config = Qwen2ForCausalLM.from_pretrained(os.path.join(self.inference_model_path, 'qwen2_5_llm')).config self.init_tokens() self.glm_config.audio_vocab_size = 4099 self.glm_config.audio_id_shift = 151699 self.glm_config.spatial_merge_size = 2 self.glm_config.tokens_per_second = 2 self.glm_config._attn_implementation = "flash_attention_2" self.glm_config.use_llm_3drope = True self.glm_model = Qwen2ForCausalLM.from_pretrained(os.path.join(self.inference_model_path, 'qwen2_5_llm'), config=self.glm_config) # Load SANA # self.scheduler = DPMSolverMultistepScheduler.from_pretrained(self.inference_model_path, subfolder="scheduler") # self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.inference_model_path, subfolder="scheduler") # self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler) # self.vae = AutoencoderDC.from_pretrained(self.inference_model_path, subfolder="vae") # self.train_model = SanaTransformer2DModel.from_pretrained(self.inference_model_path, subfolder="transformer") # self.train_model = SanaModel_withMLP(self.train_model, vision_dim=self.glm_model.config.hidden_size) # Ensure vision_dim is properly defined/set # mlp_checkpoint_path = os.path.join(self.inference_model_path, 'mlp', 'model.safetensors') # assert os.path.exists(mlp_checkpoint_path), "MLP checkpoint path does not exist." # inference_load_denoising_pretrained_weights(self.train_model, mlp_checkpoint_path) self.diffloss = SANALoss( model_path=self.inference_model_path, scheduler_path=self.inference_model_path, vision_dim=self.glm_model.config.hidden_size, mlp_checkpoint_path=os.path.join(self.inference_model_path, 'mlp', 'model.safetensors'), trainable_params="", ) # Load MLP self.image_emb_dim = 8192 mlp_modules_img = [nn.Linear(self.image_emb_dim, self.glm_model.config.hidden_size)] for _ in range(1, 2): mlp_modules_img.append(nn.GELU()) mlp_modules_img.append(nn.Linear(self.glm_model.config.hidden_size, self.glm_model.config.hidden_size)) self.linear_proj = nn.Sequential(*mlp_modules_img) temp_state_dict = load_file(os.path.join(self.inference_model_path, 'mlp', 'model.safetensors')) modified_state_dict = { '0.weight': temp_state_dict['linear_proj.0.weight'], '0.bias': temp_state_dict['linear_proj.0.bias'], '2.weight': temp_state_dict['linear_proj.2.weight'], '2.bias': temp_state_dict['linear_proj.2.bias'] } self.linear_proj.load_state_dict(modified_state_dict, strict=True) self.norm_query_embeds = True # Load connector self.connector = AutoModelForCausalLM.from_pretrained(os.path.join(self.inference_model_path, 'connector')) for layer in self.connector.model.layers: layer.self_attn.is_causal = False self.proj_in = nn.Linear(self.glm_model.config.hidden_size, self.connector.config.hidden_size) self.proj_out = nn.Linear(self.connector.config.hidden_size, self.glm_model.config.hidden_size) temp_state_dict = load_file(os.path.join(self.inference_model_path, 'mlp', 'model.safetensors')) modified_state_dict_in = { 'weight': temp_state_dict['proj_in.weight'], 'bias': temp_state_dict['proj_in.bias'] } self.proj_in.load_state_dict(modified_state_dict_in, strict=True) modified_state_dict_out = { 'weight': temp_state_dict['proj_out.weight'], 'bias': temp_state_dict['proj_out.bias'] } self.proj_out.load_state_dict(modified_state_dict_out, strict=True) self.num_learnable_queries = 256 self.use_multi_scale = True self.scales = [4, 8, 16] self.learnable_queries_1d = True self.query_tokens_dict = nn.ParameterDict() total_tokens = 0 for scale in self.scales: num_tokens = scale * scale self.query_tokens_dict[f"{scale}x{scale}"] = nn.Parameter( torch.nn.functional.normalize(torch.randn(num_tokens, self.glm_model.config.hidden_size), dim=-1) ) self.query_tokens_dict[f"{scale}x{scale}"].data = temp_state_dict[f"query_tokens_dict.{scale}x{scale}"] total_tokens += num_tokens # 计算各尺度的累积索引 self.scale_indices = [] current_idx = 0 for scale in self.scales: current_idx += scale * scale self.scale_indices.append(current_idx) logger.info("All models load done.") @torch.no_grad() def image_gen_generate( self, samples, steps=20, seed=42, cfg=7.0, height=512, width=512, num_max_output_tokens=100, ): """ Args: samples (dict): A dictionary containing the output of processor steps (int): Number of inference steps for diffusion height (int): height for output image width (int): width for output image Returns: result_word (str): output words result_image (PIL.Image): output image """ assert samples["input_ids"].ndim == 2 assert samples["input_ids"].shape[0] == 1 if samples["input_ids"][0][-1].tolist() != self.glm_config.image_start_token: print("Warning: No found at the end of prompt, back to chat mode.") image_embed_list = [] if ("image" in samples) and (samples["image"] is not None): device = samples["image"].device images = samples["image"] if not isinstance(images, list): images = [images] else: device = samples["input_ids"].device images = [] image_embed_list = [] image_grid_thw = None for idx, item in enumerate(images): if len(images) > 0 and images[idx].size(0) > 0: with torch.cuda.amp.autocast(dtype=torch.bfloat16): pixel_values = images[idx].type(self.eva_encoder.get_dtype()) image_grid_thw = samples["image_grid_thw"] eva_image_feat = self.eva_encoder(pixel_values, grid_thw=image_grid_thw) image_embed_list.append(eva_image_feat) image_embeds = None inputs_opt_visual = None device = samples["input_ids"].device if len(image_embed_list) > 0: with torch.cuda.amp.autocast(dtype=torch.bfloat16): image_embeds = torch.cat(image_embed_list).to(device) image_embeds = image_embeds.float() inputs_opt_visual = self.linear_proj(image_embeds) if self.norm_query_embeds: inputs_opt_visual = torch.nn.functional.normalize(inputs_opt_visual, dim=-1) else: inputs_opt_visual = inputs_opt_visual * self.query_embeds_scale # if self.half_glm: # inputs_opt_visual = inputs_opt_visual.half() inputs = {} inputs["input_ids"] = samples["input_ids"].to(device) assert "position_ids" not in samples or samples["position_ids"] is None inputs["position_ids"] = None inputs["attention_mask"] = samples["generation_attention_mask"].to(device) query_embeds_image = inputs_opt_visual query_embeds_video = None image_grid_thw_video = None inputs["query_embeds_image"] = query_embeds_image inputs["query_embeds_video"] = query_embeds_video inputs["image_grid_thw"] = image_grid_thw inputs["image_grid_thw_video"] = image_grid_thw_video output_str = "" new_token_ids = None new_query_embeds_images = None assert inputs["input_ids"].shape[0] == 1 assert inputs["position_ids"] is None num_remaining_image_gen_token = 0 curr_image_grid_thw = inputs["image_grid_thw"] for _ in range(num_max_output_tokens): assert num_remaining_image_gen_token >= 0 curr_input_ids = torch.cat([inputs["input_ids"], new_token_ids], dim=1) if new_token_ids is not None else inputs["input_ids"] assert num_remaining_image_gen_token >= 0 true_input_ids = curr_input_ids if num_remaining_image_gen_token == 0 else curr_input_ids[:,:-1 * (num_remaining_image_gen_token + 1)] curr_query_embeds_image = inputs["query_embeds_image"] if new_query_embeds_images is not None: if curr_query_embeds_image is None: curr_query_embeds_image = new_query_embeds_images else: curr_query_embeds_image = torch.cat([ curr_query_embeds_image, new_query_embeds_images ], dim=0) if true_input_ids[0][-1].tolist() == self.glm_config.image_start_token: assert num_remaining_image_gen_token == 0 apppended_query_embeds_image, curr_image_grid_thw = append_understand_embeds_with_learnable_scales( clip_feat=curr_query_embeds_image, image_grid_thw=curr_image_grid_thw, scales=self.scales, dtype=torch.bfloat16, device=device, feat_dim=self.glm_model.config.hidden_size, learnable_queries_1d=self.learnable_queries_1d, ) curr_input_ids, labels = append_input_ids_with_learnable_scales( text_ids=true_input_ids, scales=self.scales, start_token_id=self.glm_model.config.image_start_token, end_token_id=self.glm_model.config.image_end_token, patch_token_id=self.glm_model.config.image_patch_token, ) learnable_queries_repeat = torch.cat( [self.query_tokens_dict[f"{scale}x{scale}"] for scale in self.scales], dim=0, ) # 现在基于更新后的text_ids和labels计算inner_gen_mask image_token_mask = (curr_input_ids == self.glm_model.config.image_patch_token).to(device) inner_gen_mask = torch.masked_select(labels, image_token_mask) == self.glm_model.config.image_patch_token inner_gen_mask = inner_gen_mask.unsqueeze(-1).expand_as(apppended_query_embeds_image).to(apppended_query_embeds_image.device) apppended_query_embeds_image = apppended_query_embeds_image.masked_scatter( inner_gen_mask, learnable_queries_repeat ) assert new_token_ids is None new_token_ids = curr_input_ids[:, true_input_ids.shape[1]:] assert new_query_embeds_images is None new_query_embeds_images = apppended_query_embeds_image[curr_query_embeds_image.shape[0]:, :] if curr_query_embeds_image is not None else apppended_query_embeds_image continue curr_position_ids = self.glm_model.get_rope_index(curr_input_ids, curr_image_grid_thw)[0] true_position_ids = curr_position_ids[:,:,:true_input_ids.shape[1]] outputs = self.glm_model( input_ids=true_input_ids, query_embeds_image=curr_query_embeds_image, query_embeds_video=inputs["query_embeds_video"], query_embeds_audio=None, target_embeds=None, position_ids=true_position_ids, attention_mask=None, labels=None, weights=None, image_grid_thw=curr_image_grid_thw, image_grid_thw_video=image_grid_thw_video, ) if new_query_embeds_images is not None: assert labels.shape == true_input_ids.shape gen_image_mask = labels == self.glm_model.config.image_patch_token assert gen_image_mask.sum().cpu().item() == new_query_embeds_images.shape[0] hidden_states_gen = outputs.last_hidden_state[gen_image_mask].view(outputs.last_hidden_state.shape[0], -1, outputs.last_hidden_state.shape[-1]) assert hidden_states_gen.shape[1] == new_query_embeds_images.shape[0] scale_start_idxes = [0] + self.scale_indices[:-1] scale_end_idxes = self.scale_indices assert scale_end_idxes[-1] == hidden_states_gen.shape[1] new_query_embeds_images = {} for scale, scale_start_idx, scale_end_idx in zip(self.scales, scale_start_idxes, scale_end_idxes): scale_name = f"{scale}x{scale}" scale_hidden = hidden_states_gen[:, scale_start_idx : scale_end_idx, :] scale_embeds = self.proj_in(scale_hidden) seq_shape = scale_embeds.shape with torch.cuda.amp.autocast(dtype=torch.bfloat16): scale_embeds = self.connector( inputs_embeds=scale_embeds, attention_mask=torch.ones(seq_shape[0],1,seq_shape[1],seq_shape[1]).to(scale_embeds.device), output_hidden_states=True ).hidden_states[-1] scale_embeds = self.proj_out(scale_embeds) scale_embeds = torch.nn.functional.normalize(scale_embeds, dim=-1) new_query_embeds_images[scale_name] = scale_embeds break assert num_remaining_image_gen_token == 0 new_token_id = outputs.logits[:,-1:,:].argmax(dim=-1) if (new_token_id.tolist())[0][0] == self.eos_token_id: break new_token_ids = torch.cat([new_token_ids, new_token_id], dim=1) if new_token_ids is not None else new_token_id output_str = output_str + self.glm_tokenizer.decode(new_token_id.tolist()[0]) #multiscale_result = None if self.diffloss is not None and new_query_embeds_images is not None: #print("curr_image_grid_thw: ", curr_image_grid_thw) imgs = [] for scale in self.scales: imgs.append(self.diffloss.sample(new_query_embeds_images[f"{scale}x{scale}"], steps=steps, seed=seed, cfg=cfg, height=height, width=width)) #multiscale_result = concat_horizontal(imgs) new_query_embeds_images = imgs[-1] # if self.use_multi_scale: # return output_str, new_query_embeds_images, multiscale_result return output_str, new_query_embeds_images # Usage example: # from MingUniInference import Ming_Uni_Inference # model = Ming_Uni_Inference('/videomm/share/models/xinyu/test1')