Spaces: Running on Zero

Update gradio_app.py

gradio_app.py CHANGED (+126 -49)
@@ -198,6 +198,122 @@ def _import_inference_bits():
 # -----------------------------------------------------------------------------
 SD15_PATH, _, _ = _download_models()
 
+# -----------------------------------------------------------------------------
+# Global model loading (CPU) so GPU task only does inference
+# -----------------------------------------------------------------------------
+def _resolve_trained_model_dir() -> str:
+    tm_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
+    if tm_dir is None and os.path.isdir("pretrain"):
+        tm_dir = os.path.abspath("pretrain")
+    if tm_dir is None:
+        raise RuntimeError("Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.")
+    return tm_dir
+
+
+# Lazy globals
+G_ARGS = None
+G_INFER_CONFIG = None
+G_TOKENIZER = None
+G_IMAGE_ENCODER = None
+G_VAE = None
+G_UNET2 = None
+G_CONTROLNET = None
+G_DENOISING_UNET = None
+G_CC_PROJ = None
+G_HAIR_ENCODER = None
+
+
+def _load_models_cpu_once():
+    global G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE
+    global G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
+
+    if all(x is not None for x in (
+        G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE,
+        G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
+    )):
+        return
+
+    class _Args:
+        pretrained_model_name_or_path = SD15_PATH or os.path.abspath("stable-diffusion-v1-5/stable-diffusion-v1-5")
+        model_path = _resolve_trained_model_dir()
+        image_encoder = "openai/clip-vit-large-patch14"
+        controlnet_model_name_or_path = None
+        revision = None
+        output_dir = "gradio_outputs"
+        seed = 42
+        num_validation_images = 1
+        validation_ids = []
+        validation_hairs = []
+        use_fp16 = False
+        align_before_infer = True
+        align_size = 1024
+
+    G_ARGS = _Args()
+
+    # Import heavy libs only here
+    from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
+    from test_stablehairv2 import UNet3DConditionModel, CCProjection, ControlNetModel
+    from omegaconf import OmegaConf
+
+    # Config
+    G_INFER_CONFIG = OmegaConf.load('./configs/inference/inference_v2.yaml')
+
+    # Tokenizer / encoders / vae (CPU)
+    G_TOKENIZER = AutoTokenizer.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="tokenizer",
+                                                revision=G_ARGS.revision)
+    G_IMAGE_ENCODER = CLIPVisionModelWithProjection.from_pretrained(G_ARGS.image_encoder, revision=G_ARGS.revision)
+    G_VAE = AutoencoderKL.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="vae",
+                                          revision=G_ARGS.revision)
+
+    # UNet2D with 8-channel conv_in (CPU)
+    G_UNET2 = UNet2DConditionModel.from_pretrained(
+        G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, torch_dtype=torch.float32
+    )
+    conv_in_8 = torch.nn.Conv2d(8, G_UNET2.conv_in.out_channels, kernel_size=G_UNET2.conv_in.kernel_size,
+                                padding=G_UNET2.conv_in.padding)
+    conv_in_8.requires_grad_(False)
+    G_UNET2.conv_in.requires_grad_(False)
+    torch.nn.init.zeros_(conv_in_8.weight)
+    conv_in_8.weight[:, :4, :, :].copy_(G_UNET2.conv_in.weight)
+    conv_in_8.bias.copy_(G_UNET2.conv_in.bias)
+    G_UNET2.conv_in = conv_in_8
+
+    # ControlNet (CPU)
+    G_CONTROLNET = ControlNetModel.from_unet(G_UNET2)
+    state_dict2 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model.bin"), map_location="cpu")
+    G_CONTROLNET.load_state_dict(state_dict2, strict=False)
+
+    # UNet3D (CPU)
+    prefix = "motion_module"
+    ckpt_num = "4140000"
+    save_path = os.path.join(G_ARGS.model_path, f"{prefix}-{ckpt_num}.pth")
+    G_DENOISING_UNET = UNet3DConditionModel.from_pretrained_2d(
+        G_ARGS.pretrained_model_name_or_path,
+        save_path,
+        subfolder="unet",
+        unet_additional_kwargs=G_INFER_CONFIG.unet_additional_kwargs,
+    )
+
+    # CC projection (CPU)
+    G_CC_PROJ = CCProjection()
+    state_dict3 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_1.bin"), map_location="cpu")
+    G_CC_PROJ.load_state_dict(state_dict3, strict=False)
+
+    # Hair encoder (CPU)
+    from ref_encoder.reference_unet import ref_unet
+    G_HAIR_ENCODER = ref_unet.from_pretrained(
+        G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, low_cpu_mem_usage=False,
+        device_map=None, ignore_mismatched_sizes=True
+    )
+    state_dict4 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_2.bin"), map_location="cpu")
+    G_HAIR_ENCODER.load_state_dict(state_dict4, strict=False)
+
+
+try:
+    _load_models_cpu_once()
+except Exception as _e:
+    print(f"[init] Model preload warning: {_e}", flush=True)
+
 
 # -----------------------------------------------------------------------------
 # Gradio inference
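For context, the hunk above follows the usual ZeroGPU split: everything heavy is built on the CPU at import time, and the GPU-decorated handler only moves weights to cuda and runs. Below is a minimal, self-contained sketch of that pattern, assuming the `spaces` package that ZeroGPU Spaces provide; the model and function names are illustrative and not taken from gradio_app.py.

# Minimal ZeroGPU sketch (illustrative names, not from gradio_app.py):
# load once on CPU at import time, touch the GPU only inside the decorated task.
import spaces                      # provided on ZeroGPU Spaces
import torch
import gradio as gr

# Hypothetical stand-in for the heavy models preloaded at import time.
MODEL = torch.nn.Linear(4, 4)

@spaces.GPU                        # a GPU is attached only while this runs
def run(x: float) -> float:
    device = torch.device("cuda")  # ZeroGPU: request the device directly
    model = MODEL.to(device)       # move the preloaded weights onto the GPU
    with torch.no_grad():
        y = model(torch.full((1, 4), x, device=device))
    return float(y.sum().cpu())

demo = gr.Interface(fn=run, inputs=gr.Number(value=1.0), outputs="number")

if __name__ == "__main__":
    demo.launch()

The GPU is attached only for the duration of the decorated call, which is what lets the Space run as "Running on Zero" while still preloading all weights once.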
@@ -211,7 +327,7 @@ def inference(id_image, hair_image):
     # ZeroGPU: force the 'cuda' device (torch.cuda.is_available may be False under ZeroGPU).
     device = torch.device("cuda")
 
-    #
+    # Import dependencies (lightweight functions only; the big models are no longer loaded here)
     (
         log_validation,
         UNet3DConditionModel,
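The one-line change above keeps the existing pattern of importing the heavy inference helpers inside the handler (via `_import_inference_bits()`) rather than at module import, so the app starts quickly. A sketch of that deferral, with hypothetical names and a hypothetical body, not the real gradio_app.py code:

# Deferred-import sketch (hypothetical names): keep start-up light by importing
# heavy dependencies only when the handler actually runs.
def _import_heavy_bits():
    # Imported here, not at module top level, so importing this module stays fast.
    import numpy as np
    from PIL import Image
    return np, Image

def handler(path: str):
    np, Image = _import_heavy_bits()
    img = np.asarray(Image.open(path).convert("RGB"))
    return img.shape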
@@ -280,54 +396,15 @@ def inference(id_image, hair_image):
     )
     logger = logging.getLogger(__name__)
 
-    #
-    tokenizer =
-    [further removed lines (old 285-292) are not recoverable from the page extraction]
-    # UNet2D with 8-channel conv_in
-    unet2 = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
-    ).to(device)
-    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
-                                padding=unet2.conv_in.padding)
-    conv_in_8.requires_grad_(False)
-    unet2.conv_in.requires_grad_(False)
-    torch.nn.init.zeros_(conv_in_8.weight)
-    conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
-    conv_in_8.bias.copy_(unet2.conv_in.bias)
-    unet2.conv_in = conv_in_8
-
-    controlnet = ControlNetModel.from_unet(unet2).to(device)
-    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location="cpu")
-    controlnet.load_state_dict(state_dict2, strict=False)
-
-    prefix = "motion_module"
-    ckpt_num = "4140000"
-    save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
-    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-        args.pretrained_model_name_or_path,
-        save_path,
-        subfolder="unet",
-        unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(device)
-
-    cc_projection = CCProjection().to(device)
-    state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location="cpu")
-    cc_projection.load_state_dict(state_dict3, strict=False)
-
-    from ref_encoder.reference_unet import ref_unet
-    Hair_Encoder = ref_unet.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
-        device_map=None, ignore_mismatched_sizes=True
-    ).to(device)
-    state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
-    Hair_Encoder.load_state_dict(state_dict4, strict=False)
+    # Move the already-loaded global models onto the GPU
+    tokenizer = G_TOKENIZER
+    image_encoder = G_IMAGE_ENCODER.to(device)
+    vae = G_VAE.to(device, dtype=torch.float32)
+    unet2 = G_UNET2.to(device)
+    controlnet = G_CONTROLNET.to(device)
+    denoising_unet = G_DENOISING_UNET.to(device)
+    cc_projection = G_CC_PROJ.to(device)
+    Hair_Encoder = G_HAIR_ENCODER.to(device)
 
     # Run inference
     log_validation(
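The conv_in surgery that this commit moves out of the handler and into `_load_models_cpu_once` is a standard way to widen a pretrained UNet's input layer: create a zero-initialized 8-channel conv, copy the pretrained 4-channel weights into the first four input channels, and keep the bias. A self-contained sketch of the same trick on a toy conv layer (a stand-in for `unet.conv_in`, so it runs without downloading the SD 1.5 weights):

import torch

# Stand-in for a pretrained UNet's conv_in: 4 latent channels in, 320 features out.
pretrained_conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

# New 8-channel input layer: zero-init, then copy the pretrained weights into
# the first 4 input channels so the extra channels start as a no-op.
conv_in_8 = torch.nn.Conv2d(8, pretrained_conv_in.out_channels,
                            kernel_size=pretrained_conv_in.kernel_size,
                            padding=pretrained_conv_in.padding)
with torch.no_grad():
    torch.nn.init.zeros_(conv_in_8.weight)
    conv_in_8.weight[:, :4, :, :].copy_(pretrained_conv_in.weight)
    conv_in_8.bias.copy_(pretrained_conv_in.bias)

# Sanity check: with zeros in the extra 4 channels, the output matches the original layer.
x4 = torch.randn(1, 4, 64, 64)
x8 = torch.cat([x4, torch.zeros(1, 4, 64, 64)], dim=1)
assert torch.allclose(pretrained_conv_in(x4), conv_in_8(x8), atol=1e-5)

Because the extra input channels are zero-initialized, the widened layer behaves exactly like the original 4-channel conv_in whenever those channels carry zeros, so the pretrained SD 1.5 UNet weights remain a valid starting point for the 8-channel model.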