Create tangoflux
tangoflux ADDED
@@ -0,0 +1,61 @@
from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from src.model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional, Union, List
from datasets import load_dataset, Audio
from math import pi
import json
import inspect
import yaml
from safetensors.torch import load_file


class TangoFluxInference:

    def __init__(self, name='declare-lab/TangoFlux', device="cuda"):

        self.vae = AutoencoderOobleck()

        paths = snapshot_download(repo_id=name)
        vae_weights = load_file("{}/vae.safetensors".format(paths))
        self.vae.load_state_dict(vae_weights)
        weights = load_file("{}/tangoflux.safetensors".format(paths))

        with open('{}/config.json'.format(paths), 'r') as f:
            config = json.load(f)
        self.model = TangoFlux(config)
        self.model.load_state_dict(weights, strict=False)
        # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
        self.vae.to(device)
        self.model.to(device)

    def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):

        with torch.no_grad():
            latents = self.model.inference_flow(prompt,
                                                duration=duration,
                                                num_inference_steps=steps,
                                                guidance_scale=guidance_scale)

            wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]

        waveform_end = int(duration * self.vae.config.sampling_rate)
        wave = wave[:, :waveform_end]
        return wave
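
A minimal usage sketch (not part of the added file): it assumes the class above is importable, e.g. the file is saved as tangoflux.py on the Python path, that a CUDA device is available, and that torchaudio is installed. The prompt string is only illustrative.

import torchaudio
from tangoflux import TangoFluxInference

model = TangoFluxInference(name='declare-lab/TangoFlux', device='cuda')

# generate() returns a CPU waveform tensor of shape [channels, samples],
# already trimmed to `duration` seconds.
audio = model.generate('Hammer slowly hitting the wooden table',
                       steps=50, duration=10, guidance_scale=4.5)

# The Oobleck VAE config provides the sampling rate, so it is not hard-coded here.
torchaudio.save('output.wav', audio, sample_rate=model.vae.config.sampling_rate)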