import os
from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)


class SD3TextEncoderWithMask(nn.Module):
    """Frozen SD3 text-encoding stack: two CLIP encoders (pooled) plus T5 (sequence + mask).

    Tokenizers are loaded eagerly in ``__init__``; the three encoders are loaded lazily on
    first use, placed on ``device_0`` and frozen. ``device_1`` is only stored as an
    attribute for code that places other work there; this class never uses it.
    """

    def __init__(
        self,
        model_path,
        torch_dtype,
        device_0: Union[str, torch.device] = "cuda:0",
        device_1: Union[str, torch.device] = "cuda:1",
    ):
        """
        Args:
            model_path: Directory containing the ``tokenizer``, ``tokenizer_2``,
                ``tokenizer_3``, ``text_encoder``, ``text_encoder_2`` and
                ``text_encoder_3`` subfolders.
            torch_dtype: dtype passed to ``from_pretrained`` for every encoder.
            device_0: Device the text encoders are loaded onto (default ``"cuda:0"``).
            device_1: Device exposed as ``self.device_1`` for other work
                (default ``"cuda:1"``).
        """
        super().__init__()
        self.device_0 = torch.device(device_0)  # text encoders live here
        self.device_1 = torch.device(device_1)  # reserved for other tasks

        # Tokenizers for the two CLIP models and T5.
        self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, "tokenizer_2"))
        self.tokenizer_3 = T5TokenizerFast.from_pretrained(
            os.path.join(model_path, "tokenizer_3")
        )

        # Encoders are loaded on demand by _load_models_if_needed().
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None

        self.model_path = model_path
        self.torch_dtype = torch_dtype
        self.tokenizer_max_length = self.tokenizer.model_max_length

        # No encoder parameters exist yet, so this only covers anything registered so far;
        # each encoder is frozen individually when _load_models_if_needed() loads it.
        self._freeze()

    def _freeze(self):
        """Disable gradients for every parameter currently registered on this module."""
        for param in self.parameters():
            param.requires_grad = False

    def _load_frozen(self, model_cls, subfolder):
        """Load ``model_cls`` from ``model_path/subfolder`` onto ``device_0`` with grads off."""
        model = model_cls.from_pretrained(
            os.path.join(self.model_path, subfolder), torch_dtype=self.torch_dtype
        ).to(self.device_0)
        # Freeze at load time: the _freeze() call in __init__ ran before any encoder existed.
        model.requires_grad_(False)
        return model

    def _load_models_if_needed(self):
        """Load (and freeze) any of the three encoders that has not been loaded yet."""
        if self.text_encoder is None:
            self.text_encoder = self._load_frozen(CLIPTextModelWithProjection, "text_encoder")
        if self.text_encoder_2 is None:
            self.text_encoder_2 = self._load_frozen(
                CLIPTextModelWithProjection, "text_encoder_2"
            )
        if self.text_encoder_3 is None:
            self.text_encoder_3 = self._load_frozen(T5EncoderModel, "text_encoder_3")

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 128,
    ):
        """Encode ``prompt`` with the T5 encoder.

        Args:
            prompt: A prompt or list of prompts.
            num_images_per_prompt: How many times each prompt's rows are repeated.
            device: Device for the returned tensors; defaults to the T5 encoder's device.
            max_sequence_length: Pad/truncate length for the T5 tokenizer.

        Returns:
            ``(prompt_embeds, prompt_attention_mask)``. Both tensors repeat each prompt's
            row ``num_images_per_prompt`` times consecutively (``[p0, p0, p1, p1, ...]``)
            so that embeddings and mask stay aligned.
        """
        self._load_models_if_needed()
        prompt = [prompt] if isinstance(prompt, str) else prompt

        encoder_device = self.text_encoder_3.device
        if device is None:
            device = encoder_device

        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        # Inputs must be on the encoder's device; outputs are moved to `device` afterwards.
        text_input_ids = text_inputs.input_ids.to(encoder_device)
        prompt_attention_mask = text_inputs.attention_mask.to(encoder_device)

        prompt_embeds = self.text_encoder_3(text_input_ids, attention_mask=prompt_attention_mask)[0]
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_3.dtype, device=device)
        prompt_attention_mask = prompt_attention_mask.to(device)

        # Use the same duplication for embeddings and mask. The former repeat/view pair
        # interleaved the embeddings but tiled the mask, misaligning rows when
        # batch_size > 1 and num_images_per_prompt > 1.
        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        prompt_attention_mask = prompt_attention_mask.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return prompt_embeds, prompt_attention_mask

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_model_index: int = 0,
    ):
        """Return the projected (pooled) CLIP embedding of ``prompt``.

        Args:
            prompt: A prompt or list of prompts.
            num_images_per_prompt: How many times each prompt's row is repeated.
            device: Device for the returned tensor; defaults to the chosen encoder's device.
            clip_model_index: ``0`` selects tokenizer/text_encoder, ``1`` selects
                tokenizer_2/text_encoder_2.

        Returns:
            ``text_embeds`` with each prompt's row repeated ``num_images_per_prompt``
            times consecutively, matching the row order of the T5 outputs.
        """
        self._load_models_if_needed()
        # Without this, a bare string would be treated as a batch of characters.
        prompt = [prompt] if isinstance(prompt, str) else prompt

        tokenizer = (self.tokenizer, self.tokenizer_2)[clip_model_index]
        text_encoder = (self.text_encoder, self.text_encoder_2)[clip_model_index]

        encoder_device = text_encoder.device
        if device is None:
            device = encoder_device

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(encoder_device)

        # Only the projected embedding is used, so hidden states are not requested.
        pooled_prompt_embeds = text_encoder(text_input_ids).text_embeds.to(device)
        return pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    def encode_prompt(self, prompt, num_images_per_prompt=1, device=None):
        """Encode ``prompt`` with both CLIP models and T5.

        Args:
            prompt: A prompt or list of prompts.
            num_images_per_prompt: How many times each prompt's rows are repeated.
            device: Device for all returned tensors; defaults to ``device_0``.

        Returns:
            ``(prompt_embeds, prompt_attention_mask, pooled_prompt_embeds)``, where
            ``pooled_prompt_embeds`` concatenates both CLIP embeddings on the last dim.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        target_device = self.device_0 if device is None else device

        pooled_prompt_embeds = torch.cat(
            [
                self._get_clip_prompt_embeds(
                    prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=target_device,
                    clip_model_index=index,
                )
                for index in (0, 1)
            ],
            dim=-1,
        )
        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
            prompt, num_images_per_prompt=num_images_per_prompt, device=target_device
        )
        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds

    def forward(self, input_prompts):
        """Encode ``input_prompts`` without tracking gradients.

        Returns:
            ``(prompt_embeds, prompt_attention_mask, pooled_prompt_embeds)`` on ``device_0``.
        """
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(
                input_prompts
            )
        return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds


class OtherModel(nn.Module):
    """Example module that runs on a separate device (``cuda:1`` by default)."""

    def __init__(self, device: Union[str, torch.device] = "cuda:1"):
        super().__init__()
        self.fc = nn.Linear(512, 512).to(device)  # example layer

    def forward(self, x):
        return self.fc(x)


# Demo only: guarded so that importing this module does not allocate on cuda:1
# (which fails on machines with fewer than two GPUs).
if __name__ == "__main__":
    other_model = OtherModel().to("cuda:1")
    input_data = torch.randn(64, 512).to("cuda:1")
    output = other_model(input_data)
    print(output)