Rex Cheng committed

Commit 164c335 · Parent: 627e0b8

speed up inference

Files changed:
- app.py (+2 -1)
- mmaudio/eval_utils.py (+20 -17)
- mmaudio/ext/autoencoder/autoencoder.py (+5 -1)
- mmaudio/model/utils/features_utils.py (+7 -5)
    	
app.py CHANGED

@@ -48,7 +48,8 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
                                   synchformer_ckpt=model.synchformer_ckpt,
                                   enable_conditions=True,
                                   mode=model.mode,
-                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path)
+                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                  need_vae_encoder=False)
     feature_utils = feature_utils.to(device, dtype).eval()

     return net, feature_utils, seq_cfg
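In the demo, the model only decodes latents back to audio; the VAE encoder is never called, so the new need_vae_encoder=False flag lets FeaturesUtils drop it. A minimal sketch of the resulting construction call, assuming only the keyword arguments visible in this commit; the checkpoint paths, device, and dtype below are placeholders, not the demo's actual values:

import torch
from mmaudio.model.utils.features_utils import FeaturesUtils

feature_utils = FeaturesUtils(
    tod_vae_ckpt='weights/vae.pth',                  # placeholder path
    synchformer_ckpt='weights/synchformer.pth',      # placeholder path
    bigvgan_vocoder_ckpt='weights/bigvgan_16k.pth',  # placeholder path
    enable_conditions=True,
    mode='16k',
    need_vae_encoder=False,  # inference-only: the VAE encoder is unused, so drop it
)
feature_utils = feature_utils.to('cuda', torch.bfloat16).eval()  # placeholder device/dtype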
    	
mmaudio/eval_utils.py CHANGED

@@ -76,29 +76,37 @@ all_model_cfg: dict[str, ModelConfig] = {
 }


-def generate(
+def generate(
+    clip_video: Optional[torch.Tensor],
+    sync_video: Optional[torch.Tensor],
+    text: Optional[list[str]],
+    *,
+    negative_text: Optional[list[str]] = None,
+    feature_utils: FeaturesUtils,
+    net: MMAudio,
+    fm: FlowMatching,
+    rng: torch.Generator,
+    cfg_strength: float,
+    clip_batch_size_multiplier: int = 40,
+    sync_batch_size_multiplier: int = 40,
+) -> torch.Tensor:
     device = feature_utils.device
     dtype = feature_utils.dtype

     bs = len(text)
     if clip_video is not None:
         clip_video = clip_video.to(device, dtype, non_blocking=True)
-        clip_features = feature_utils.encode_video_with_clip(clip_video,
+        clip_features = feature_utils.encode_video_with_clip(clip_video,
+                                                             batch_size=bs *
+                                                             clip_batch_size_multiplier)
     else:
         clip_features = net.get_empty_clip_sequence(bs)

     if sync_video is not None:
         sync_video = sync_video.to(device, dtype, non_blocking=True)
-        sync_features = feature_utils.encode_video_with_sync(sync_video,
+        sync_features = feature_utils.encode_video_with_sync(sync_video,
+                                                             batch_size=bs *
+                                                             sync_batch_size_multiplier)
     else:
         sync_features = net.get_empty_sync_sequence(bs)

@@ -185,14 +193,9 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
     data_chunk = reader.pop_chunks()
     clip_chunk = data_chunk[0]
     sync_chunk = data_chunk[1]
-    print('clip', clip_chunk.shape, clip_chunk.dtype, clip_chunk.max())
-    print('sync', sync_chunk.shape, sync_chunk.dtype, sync_chunk.max())
     assert clip_chunk is not None
     assert sync_chunk is not None

-    for i in range(reader.num_out_streams):
-        print(reader.get_out_stream_info(i))
-
     clip_frames = clip_transform(clip_chunk)
     sync_frames = sync_transform(sync_chunk)
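The new batch_size arguments let generate() control how many frames the CLIP and Synchformer feature extractors process per forward pass (bs * 40 by default). A generic sketch of the chunked-encode pattern such a batch size feeds into; the encode_in_chunks helper, the dummy encoder, and the tensor shapes are stand-ins for illustration, not the actual MMAudio modules:

import torch

def encode_in_chunks(encoder, x: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Run the encoder over slices of `batch_size` along dim 0, so peak memory
    # stays bounded while each forward pass still processes many frames.
    outputs = []
    for i in range(0, x.shape[0], batch_size):
        outputs.append(encoder(x[i:i + batch_size]))
    return torch.cat(outputs, dim=0)

frames = torch.randn(256, 3, 224, 224)   # pretend video frames
encoder = torch.nn.Flatten()             # stand-in for a real visual encoder
features = encode_in_chunks(encoder, frames, batch_size=1 * 40)  # bs=1, multiplier=40
print(features.shape)                    # torch.Size([256, 150528])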
    	
mmaudio/ext/autoencoder/autoencoder.py CHANGED

@@ -15,7 +15,8 @@ class AutoEncoderModule(nn.Module):
                  *,
                  vae_ckpt_path,
                  vocoder_ckpt_path: Optional[str] = None,
-                 mode: Literal['16k', '44k']):
+                 mode: Literal['16k', '44k'],
+                 need_vae_encoder: bool = True):
         super().__init__()
         self.vae: VAE = get_my_vae(mode).eval()
         vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')

@@ -35,6 +36,9 @@ class AutoEncoderModule(nn.Module):
         for param in self.parameters():
             param.requires_grad = False

+        if not need_vae_encoder:
+            del self.vae.encoder
+
     @torch.inference_mode()
     def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
         return self.vae.encode(x)
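Deleting the encoder submodule removes its parameters from the module entirely, so an inference-only setup never pays for them in GPU memory. A toy sketch of that effect; ToyVAE and its layer sizes are made up for illustration, and only the `del` behaviour mirrors the change above:

import torch.nn as nn

class ToyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(512, 64)   # stand-in for the real VAE encoder
        self.decoder = nn.Linear(64, 512)   # stand-in for the real VAE decoder

vae = ToyVAE()
print(sum(p.numel() for p in vae.parameters()))   # encoder + decoder parameters
del vae.encoder   # what AutoEncoderModule now does when need_vae_encoder=False
print(sum(p.numel() for p in vae.parameters()))   # decoder parameters only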
    	
mmaudio/model/utils/features_utils.py CHANGED

@@ -41,6 +41,7 @@ class FeaturesUtils(nn.Module):
         synchformer_ckpt: Optional[str] = None,
         enable_conditions: bool = True,
         mode=Literal['16k', '44k'],
+        need_vae_encoder: bool = True,
     ):
         super().__init__()

@@ -64,19 +65,18 @@ class FeaturesUtils(nn.Module):
         if tod_vae_ckpt is not None:
             self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                          vocoder_ckpt_path=bigvgan_vocoder_ckpt,
-                                         mode=mode)
+                                         mode=mode,
+                                         need_vae_encoder=need_vae_encoder)
         else:
             self.tod = None
         self.mel_converter = MelConverter()

     def compile(self):
         if self.clip_model is not None:
-            self.encode_video_with_clip = torch.compile(self.encode_video_with_clip)
             self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
             self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
         if self.synchformer is not None:
             self.synchformer = torch.compile(self.synchformer)
-        self.tod.encode = torch.compile(self.tod.encode)
         self.decode = torch.compile(self.decode)
         self.vocode = torch.compile(self.vocode)

@@ -121,9 +121,11 @@
         outputs = []
         if batch_size < 0:
             batch_size = b
+        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
+        for i in range(0, b * num_segments, batch_size):
             outputs.append(self.synchformer(x[i:i + batch_size]))
-        x = torch.cat(outputs, dim=0)
+        x = torch.cat(outputs, dim=0)
+        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
         return x

     @torch.inference_mode()
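The batched Synchformer path now flattens the (batch, segment) axes into one, runs the chunked loop across all segments from all videos at once, and then folds the segment axis back into the time axis. A shape-only sketch of that rearrangement; dummy_synchformer, the segment length, and the feature width are made-up stand-ins, not the real model's dimensions:

import torch
from einops import rearrange

b, s, t, c, h, w = 2, 8, 16, 3, 224, 224   # batch, segments, frames per segment, channels, H, W
x = torch.randn(b, s, t, c, h, w)

def dummy_synchformer(seg: torch.Tensor) -> torch.Tensor:
    # Stand-in: map each segment to a (1, 8, 768) feature block; only the shapes matter here.
    return torch.randn(seg.shape[0], 1, 8, 768)

x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')   # (16, 1, 16, 3, 224, 224)
outputs = []
batch_size = 4   # the real code uses bs * sync_batch_size_multiplier
for i in range(0, b * s, batch_size):
    outputs.append(dummy_synchformer(x[i:i + batch_size]))
x = torch.cat(outputs, dim=0)                        # (16, 1, 8, 768)
x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)    # (2, 64, 768)
print(x.shape)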
