ouclxy committed on
Commit b0560e7 · verified · 1 Parent(s): 753f533

Update gradio_app.py

Files changed (1)
  1. gradio_app.py +126 -49
gradio_app.py CHANGED
@@ -198,6 +198,122 @@ def _import_inference_bits():
  # -----------------------------------------------------------------------------
  SD15_PATH, _, _ = _download_models()
 
+ # -----------------------------------------------------------------------------
+ # Global model loading (CPU) so GPU task only does inference
+ # -----------------------------------------------------------------------------
+ def _resolve_trained_model_dir() -> str:
+ tm_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
+ if tm_dir is None and os.path.isdir("pretrain"):
+ tm_dir = os.path.abspath("pretrain")
+ if tm_dir is None:
+ raise RuntimeError("Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.")
+ return tm_dir
+
+
+ # Lazy globals
+ G_ARGS = None
+ G_INFER_CONFIG = None
+ G_TOKENIZER = None
+ G_IMAGE_ENCODER = None
+ G_VAE = None
+ G_UNET2 = None
+ G_CONTROLNET = None
+ G_DENOISING_UNET = None
+ G_CC_PROJ = None
+ G_HAIR_ENCODER = None
+
+
+ def _load_models_cpu_once():
+ global G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE
+ global G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
+
+ if all(x is not None for x in (
+ G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE,
+ G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
+ )):
+ return
+
+ class _Args:
+ pretrained_model_name_or_path = SD15_PATH or os.path.abspath("stable-diffusion-v1-5/stable-diffusion-v1-5")
+ model_path = _resolve_trained_model_dir()
+ image_encoder = "openai/clip-vit-large-patch14"
+ controlnet_model_name_or_path = None
+ revision = None
+ output_dir = "gradio_outputs"
+ seed = 42
+ num_validation_images = 1
+ validation_ids = []
+ validation_hairs = []
+ use_fp16 = False
+ align_before_infer = True
+ align_size = 1024
+
+ G_ARGS = _Args()
+
+ # Import heavy libs only here
+ from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
+ from test_stablehairv2 import UNet3DConditionModel, CCProjection, ControlNetModel
+ from omegaconf import OmegaConf
+
+ # Config
+ G_INFER_CONFIG = OmegaConf.load('./configs/inference/inference_v2.yaml')
+
+ # Tokenizer / encoders / vae (CPU)
+ G_TOKENIZER = AutoTokenizer.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="tokenizer",
+ revision=G_ARGS.revision)
+ G_IMAGE_ENCODER = CLIPVisionModelWithProjection.from_pretrained(G_ARGS.image_encoder, revision=G_ARGS.revision)
+ G_VAE = AutoencoderKL.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="vae",
+ revision=G_ARGS.revision)
+
+ # UNet2D with 8-channel conv_in (CPU)
+ G_UNET2 = UNet2DConditionModel.from_pretrained(
+ G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, torch_dtype=torch.float32
+ )
+ conv_in_8 = torch.nn.Conv2d(8, G_UNET2.conv_in.out_channels, kernel_size=G_UNET2.conv_in.kernel_size,
+ padding=G_UNET2.conv_in.padding)
+ conv_in_8.requires_grad_(False)
+ G_UNET2.conv_in.requires_grad_(False)
+ torch.nn.init.zeros_(conv_in_8.weight)
+ conv_in_8.weight[:, :4, :, :].copy_(G_UNET2.conv_in.weight)
+ conv_in_8.bias.copy_(G_UNET2.conv_in.bias)
+ G_UNET2.conv_in = conv_in_8
+
+ # ControlNet (CPU)
+ G_CONTROLNET = ControlNetModel.from_unet(G_UNET2)
+ state_dict2 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model.bin"), map_location="cpu")
+ G_CONTROLNET.load_state_dict(state_dict2, strict=False)
+
+ # UNet3D (CPU)
+ prefix = "motion_module"
+ ckpt_num = "4140000"
+ save_path = os.path.join(G_ARGS.model_path, f"{prefix}-{ckpt_num}.pth")
+ G_DENOISING_UNET = UNet3DConditionModel.from_pretrained_2d(
+ G_ARGS.pretrained_model_name_or_path,
+ save_path,
+ subfolder="unet",
+ unet_additional_kwargs=G_INFER_CONFIG.unet_additional_kwargs,
+ )
+
+ # CC projection (CPU)
+ G_CC_PROJ = CCProjection()
+ state_dict3 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_1.bin"), map_location="cpu")
+ G_CC_PROJ.load_state_dict(state_dict3, strict=False)
+
+ # Hair encoder (CPU)
+ from ref_encoder.reference_unet import ref_unet
+ G_HAIR_ENCODER = ref_unet.from_pretrained(
+ G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, low_cpu_mem_usage=False,
+ device_map=None, ignore_mismatched_sizes=True
+ )
+ state_dict4 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_2.bin"), map_location="cpu")
+ G_HAIR_ENCODER.load_state_dict(state_dict4, strict=False)
+
+
+ try:
+ _load_models_cpu_once()
+ except Exception as _e:
+ print(f"[init] Model preload warning: {_e}", flush=True)
+
 
  # -----------------------------------------------------------------------------
  # Gradio inference
@@ -211,7 +327,7 @@ def inference(id_image, hair_image):
  # ZeroGPU: force the 'cuda' device (torch.cuda.is_available may be False under ZeroGPU).
  device = torch.device("cuda")
 
- # Import dependencies
+ # Import dependencies (lightweight functions only; the large models are no longer loaded here)
  (
  log_validation,
  UNet3DConditionModel,
@@ -280,54 +396,15 @@ def inference(id_image, hair_image):
  )
  logger = logging.getLogger(__name__)
 
- # Load tokenizer/encoders/vae
- tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
- revision=args.revision)
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder, revision=args.revision).to(device)
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",
- revision=args.revision).to(device, dtype=torch.float32)
-
- from omegaconf import OmegaConf
- infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')
-
- # UNet2D with 8-channel conv_in
- unet2 = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
- ).to(device)
- conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
- padding=unet2.conv_in.padding)
- conv_in_8.requires_grad_(False)
- unet2.conv_in.requires_grad_(False)
- torch.nn.init.zeros_(conv_in_8.weight)
- conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
- conv_in_8.bias.copy_(unet2.conv_in.bias)
- unet2.conv_in = conv_in_8
-
- controlnet = ControlNetModel.from_unet(unet2).to(device)
- state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location="cpu")
- controlnet.load_state_dict(state_dict2, strict=False)
-
- prefix = "motion_module"
- ckpt_num = "4140000"
- save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
- denoising_unet = UNet3DConditionModel.from_pretrained_2d(
- args.pretrained_model_name_or_path,
- save_path,
- subfolder="unet",
- unet_additional_kwargs=infer_config.unet_additional_kwargs,
- ).to(device)
-
- cc_projection = CCProjection().to(device)
- state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location="cpu")
- cc_projection.load_state_dict(state_dict3, strict=False)
-
- from ref_encoder.reference_unet import ref_unet
- Hair_Encoder = ref_unet.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
- device_map=None, ignore_mismatched_sizes=True
- ).to(device)
- state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
- Hair_Encoder.load_state_dict(state_dict4, strict=False)
+ # Move the already-loaded global models to the GPU
+ tokenizer = G_TOKENIZER
+ image_encoder = G_IMAGE_ENCODER.to(device)
+ vae = G_VAE.to(device, dtype=torch.float32)
+ unet2 = G_UNET2.to(device)
+ controlnet = G_CONTROLNET.to(device)
+ denoising_unet = G_DENOISING_UNET.to(device)
+ cc_projection = G_CC_PROJ.to(device)
+ Hair_Encoder = G_HAIR_ENCODER.to(device)
 
  # Run inference
  log_validation(
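
For context, the restructuring in this commit follows the usual ZeroGPU recipe: do every heavy from_pretrained call on CPU at import time, and keep the GPU-decorated handler down to moving the preloaded weights onto cuda and running the forward pass. The sketch below is a minimal, self-contained illustration of that pattern on a Hugging Face ZeroGPU Space; the tiny Sequential model and the run handler are illustrative stand-ins, not code from gradio_app.py.

# Minimal ZeroGPU sketch: CPU preload at import time, GPU only inside the handler.
# Illustrative only -- the model below is a stand-in, not the StableHair pipeline.
import gradio as gr
import spaces
import torch

# Heavy initialisation happens once, on CPU, when the Space boots (no GPU yet).
MODEL = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 3, 3, padding=1),
)

@spaces.GPU(duration=120)  # a GPU is attached only while this function runs
def run(image):
    device = torch.device("cuda")   # CUDA is available inside the decorated call
    model = MODEL.to(device)        # move the preloaded weights to the GPU
    x = torch.as_tensor(image, dtype=torch.float32, device=device)
    x = x.permute(2, 0, 1).unsqueeze(0) / 255.0
    with torch.no_grad():
        y = model(x)
    y = (y.clamp(0, 1) * 255).squeeze(0).permute(1, 2, 0)
    return y.to(torch.uint8).cpu().numpy()

demo = gr.Interface(fn=run, inputs=gr.Image(), outputs=gr.Image())

if __name__ == "__main__":
    demo.launch()

Note that under ZeroGPU, torch.cuda.is_available() can report False at import time, which is why the commit hard-codes torch.device("cuda") inside the inference function instead of branching on availability.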