import torch
import torch.nn as nn
from transformers import Qwen2Tokenizer, Qwen3Model

from src.models.conditioner.base import BaseConditioner


class Qwen3TextEncoder(BaseConditioner):
    """Text conditioner that embeds prompts with a Qwen3 language model."""

    def __init__(self, weight_path: str, embed_dim: int = None, max_length: int = 128):
        super().__init__()
        # Qwen3 checkpoints ship with a Qwen2-style tokenizer.
        self.tokenizer = Qwen2Tokenizer.from_pretrained(weight_path, max_length=max_length, padding_side="right")
        # self.model = Qwen3Model.from_pretrained(weight_path, attn_implementation="flex_attention").to(torch.bfloat16)
        self.model = Qwen3Model.from_pretrained(weight_path).to(torch.bfloat16)
        self.model.compile()
        self.uncondition_embedding = None  # cached embedding of the empty prompt
        self.embed_dim = embed_dim
        self.max_length = max_length
        # torch._dynamo.config.optimize_ddp = False

    def _impl_condition(self, y, metadata: dict = None):
        if metadata is None:  # avoid sharing a mutable default dict across calls
            metadata = {}
        tokenized = self.tokenizer(y, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        input_ids = tokenized.input_ids.cuda()
        attention_mask = tokenized.attention_mask.cuda()
        # Number of real (non-padding) tokens per prompt, used downstream for masking.
        metadata["valid_length_y"] = torch.sum(attention_mask, dim=-1)
        # Last hidden state of the Qwen3 model: (batch, max_length, hidden_size).
        y = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Zero-pad or truncate the hidden dimension so the output matches embed_dim.
        if y.shape[2] < self.embed_dim:
            pad = torch.zeros(y.shape[0], y.shape[1], self.embed_dim - y.shape[2]).to(y.device, y.dtype)
            y = torch.cat([y, pad], dim=-1)
        if y.shape[2] > self.embed_dim:
            y = y[:, :, :self.embed_dim]
        return y

    def _impl_uncondition(self, y, metadata: dict = None):
        # Embed the empty prompt once, cache it, and broadcast it over the batch.
        if self.uncondition_embedding is None:
            self.uncondition_embedding = self._impl_condition([""])
        return self.uncondition_embedding.repeat(len(y), 1, 1)
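

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal sketch, assuming a Qwen3 checkpoint, a CUDA device, and that
# BaseConditioner dispatches to the _impl_* methods above. The checkpoint
# name "Qwen/Qwen3-0.6B" and embed_dim=1024 are placeholder assumptions.
if __name__ == "__main__":
    encoder = Qwen3TextEncoder("Qwen/Qwen3-0.6B", embed_dim=1024, max_length=128)
    encoder.model.cuda()  # inputs are moved to CUDA inside _impl_condition
    metadata = {}
    cond = encoder._impl_condition(["a photo of a cat"], metadata)  # (1, 128, 1024)
    uncond = encoder._impl_uncondition(["a photo of a cat"])        # (1, 128, 1024)
    print(cond.shape, uncond.shape, metadata["valid_length_y"])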