Update app.py
Browse files
app.py
CHANGED
@@ -136,6 +136,28 @@ automasker = AutoMasker(
|
|
136 |
)
|
137 |
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
|
141 |
@spaces.GPU(duration=120)
|
@@ -480,6 +502,7 @@ def app_gradio():
|
|
480 |
)
|
481 |
|
482 |
|
|
|
483 |
demo.queue().launch(share=True, show_error=True)
|
484 |
|
485 |
|
|
|
136 |
)
|
137 |
|
138 |
|
139 |
+
# Flux-based CatVTON
|
140 |
+
access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
|
141 |
+
flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
|
142 |
+
pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo, use_auth_token=access_token)
|
143 |
+
pipeline_flux.load_lora_weights(
|
144 |
+
os.path.join(repo_path, "flux-lora"),
|
145 |
+
weight_name='pytorch_lora_weights.safetensors'
|
146 |
+
)
|
147 |
+
pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))
|
148 |
+
|
149 |
+
|
150 |
+
# Mask-free CatVTON
|
151 |
+
catvton_mf_repo = "zhengchong/CatVTON-MaskFree"
|
152 |
+
repo_path_mf = snapshot_download(repo_id=catvton_mf_repo, use_auth_token=access_token)
|
153 |
+
pipeline_p2p = CatVTONPix2PixPipeline(
|
154 |
+
base_ckpt=args.p2p_base_model_path,
|
155 |
+
attn_ckpt=repo_path_mf,
|
156 |
+
attn_ckpt_version="mix-48k-1024",
|
157 |
+
weight_dtype=init_weight_dtype(args.mixed_precision),
|
158 |
+
use_tf32=args.allow_tf32,
|
159 |
+
device='cuda'
|
160 |
+
)
|
161 |
|
162 |
|
163 |
@spaces.GPU(duration=120)
|
|
|
502 |
)
|
503 |
|
504 |
|
505 |
+
|
506 |
demo.queue().launch(share=True, show_error=True)
|
507 |
|
508 |
|