import torch
import torch.nn as nn

from tok.ar_dtok.ar_model import ARModel
from tok.ar_dtok.vqvae import VQVAE
from tok.ta_tok import TextAlignedTokenizer


class MMAutoEncoder(nn.Module):
    """Text-aligned encoder -> AR prior -> VQVAE decoder pipeline."""

    def __init__(self,
                 ar_path_dict,
                 encoder_path, decoder_path,
                 encoder_args={}, decoder_args={}):
        super().__init__()
        # One AR model per target resolution, e.g. {"1024px": checkpoint_path}.
        self.ar_model = nn.ModuleDict({
            resolution: ARModel.from_checkpoint(ar_path)
            for resolution, ar_path in ar_path_dict.items()
        })
        self.encoder = TextAlignedTokenizer.from_checkpoint(
            encoder_path, load_teacher=False, **encoder_args)
        self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)

    def ar_sample(self, x, args):
        # Sample decoder codebook indices from the AR model for the requested resolution.
        resolution = args.get("resolution", "1024px")
        x = self.ar_model[resolution].sample(
            x,
            cfg_scale=args.get('cfg_scale', 1.0),
            cfg_interval=args.get('cfg_interval', -1),
            temperature=args.get('temperature', 1.0),
            top_k=args.get('top_k', 0),
            top_p=args.get('top_p', 1.0)
        )
        return x

    def post_process(self, x):
        # Map decoder output in [0, 1] to uint8 images in channel-last layout.
        x = x.cpu().float().clamp(0., 1.) * 255.
        x = x.permute(0, 2, 3, 1)  # [b, h, w, c]
        x = x.to(torch.uint8)
        return x

    def encode(self, x):
        # img -> encoder feats
        return self.encoder(x.to(self.encoder.dtype))['encoded']

    def get_encoder_indices(self, x):
        # img -> encoder -> indices
        return self.encoder(x.to(self.encoder.dtype))['bottleneck_rep']

    def decode_from_encoder_indices(self, indices, args={}):
        # indices -> encoder feats -> ar -> decoder
        encoder_x = self.encoder.decode_from_bottleneck(indices)
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        x = self.post_process(decoder_x)
        return x

    def decode_from_vqvae_indices(self, indices):
        # VQVAE indices -> decoder -> uint8 images
        decoder_x = self.decoder.decode_from_bottleneck(indices)
        x = self.post_process(decoder_x)
        return x

    def forward(self, x, args={}):
        # img -> encoder feats -> ar -> decoder -> uint8 images
        encoder_x = self.encoder(x.to(self.encoder.dtype))['encoded']
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        x = self.post_process(decoder_x)
        return x
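

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: the checkpoint
    # paths, resolution key, image size, and sampling arguments below are
    # illustrative placeholders and should be replaced with real checkpoints.
    model = MMAutoEncoder(
        ar_path_dict={'1024px': 'checkpoints/ar_1024px.pth'},
        encoder_path='checkpoints/ta_tok.pth',
        decoder_path='checkpoints/vqvae.pth',
    ).eval()

    images = torch.rand(1, 3, 384, 384)  # dummy RGB batch in [0, 1]
    with torch.no_grad():
        out = model(images, args={'resolution': '1024px',
                                  'cfg_scale': 4.0,
                                  'top_k': 1200})
    print(out.shape, out.dtype)  # [b, h, w, c], torch.uint8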