Update TangoFlux.py
Browse files- TangoFlux.py +6 -1
TangoFlux.py
CHANGED
@@ -26,9 +26,14 @@ class TangoFluxInference:
|
|
26 |
def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
|
27 |
|
28 |
|
29 |
-
self.vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae')
|
|
|
30 |
|
|
|
|
|
31 |
paths = snapshot_download(repo_id=name)
|
|
|
|
|
32 |
weights = load_file("{}/tangoflux.safetensors".format(paths))
|
33 |
|
34 |
with open('{}/config.json'.format(paths),'r') as f:
|
|
|
26 |
def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
|
27 |
|
28 |
|
29 |
+
#self.vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae')
|
30 |
+
self.vae = AutoencoderOobleck()
|
31 |
|
32 |
+
#paths = snapshot_download(repo_id=name)
|
33 |
+
|
34 |
paths = snapshot_download(repo_id=name)
|
35 |
+
vae_weights = load_file("{}/vae.safetensors".format(paths))
|
36 |
+
self.vae.load_state_dict(vae_weights)
|
37 |
weights = load_file("{}/tangoflux.safetensors".format(paths))
|
38 |
|
39 |
with open('{}/config.json'.format(paths),'r') as f:
|