[feat] add torch.compile

pipeline_ace_step.py  CHANGED  +21 -25
@@ -2,17 +2,14 @@ import random
 import time
 import os
 import re
-import glob
 
 import torch
-import torch.nn as nn
 from loguru import logger
 from tqdm import tqdm
 import json
 import math
 from huggingface_hub import hf_hub_download
 
-# from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
@@ -64,12 +61,12 @@ class ACEStepPipeline:
 
     def load_checkpoint(self, checkpoint_dir=None):
         device = self.device
-
+
         dcae_model_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
         vocoder_model_path = os.path.join(checkpoint_dir, "music_vocoder")
         ace_step_model_path = os.path.join(checkpoint_dir, "ace_step_transformer")
         text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
-
+
         files_exist = (
             os.path.exists(os.path.join(dcae_model_path, "config.json")) and
             os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")) and
@@ -154,9 +151,9 @@ class ACEStepPipeline:
         self.loaded = True
 
         # compile
-
-
-
+        self.music_dcae = torch.compile(self.music_dcae)
+        self.ace_step_transformer = torch.compile(self.ace_step_transformer)
+        self.text_encoder_model = torch.compile(self.text_encoder_model)
 
     def get_text_embeddings(self, texts, device, text_max_length=256):
         inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
@@ -223,7 +220,7 @@ class ACEStepPipeline:
 
     def get_lang(self, text):
         language = "en"
-        try:
+        try:
            _ = self.lang_segment.getTexts(text)
            langCounts = self.lang_segment.getCounts()
            language = langCounts[0][0]
@@ -341,10 +338,10 @@ class ACEStepPipeline:
         retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
         # to make sure mean = 0, std = 1
         target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
-
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
-
-        # guidance interval
+
+        # guidance interval
         start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
         end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
         logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
@@ -353,20 +350,20 @@ class ACEStepPipeline:
 
     def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
         handlers = []
-
+
         def hook(module, input, output):
             output[:] *= tau
             return output
-
+
         for i in range(l_min, l_max):
             handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
             handlers.append(handler)
-
+
         encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
-
+
         for hook in handlers:
             hook.remove()
-
+
         return encoder_hidden_states
 
     # P(speaker, text, lyric)
@@ -399,7 +396,7 @@ class ACEStepPipeline:
                 torch.zeros_like(lyric_token_ids),
                 lyric_mask,
             )
-
+
         encoder_hidden_states_no_lyric = None
         if do_double_condition_guidance:
             # P(null_speaker, text, lyric_weaker)
@@ -426,11 +423,11 @@ class ACEStepPipeline:
 
     def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
         handlers = []
-
+
         def hook(module, input, output):
             output[:] *= tau
             return output
-
+
         for i in range(l_min, l_max):
             handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
             handlers.append(handler)
@@ -438,13 +435,12 @@ class ACEStepPipeline:
             handlers.append(handler)
 
         sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
-
+
         for hook in handlers:
             hook.remove()
-
+
         return sample
 
-
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
             # expand the latents if we are doing classifier free guidance
             latents = target_latents
@@ -549,7 +545,7 @@ class ACEStepPipeline:
         ).sample
 
         target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
-
+
         return target_latents
 
     def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
@@ -624,7 +620,7 @@ class ACEStepPipeline:
             oss_steps = list(map(int, oss_steps.split(",")))
         else:
             oss_steps = []
-
+
         texts = [prompt]
         encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
         encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)