sky24h committed
Commit 31abe01 · Parent: 1a1a5fe
Files changed (1):
  1. inference_utils.py (+18, -20)
inference_utils.py CHANGED
@@ -1,5 +1,6 @@
 import os
-import torch, random
+import torch
+import random
 
 seed = 1024
 random.seed(seed)
@@ -10,12 +11,12 @@ torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
 
 # SPIGA ckpt downloading always fails, so we download it manually and put it in the right place.
-import site
+import spiga
 from gdown import download
 
-user_site_packages_path = site.getusersitepackages()
+pkg_path = spiga.__file__.replace("/__init__.py", "")
 spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
-ckpt_path = os.path.join(user_site_packages_path, "spiga/models/weights/spiga_300wpublic.pt")
+ckpt_path = os.path.join(pkg_path, "spiga/models/weights/spiga_300wpublic.pt")
 if not os.path.exists(ckpt_path):
     os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
     download(id=spiga_file_id, output=ckpt_path)
@@ -30,7 +31,6 @@ from diffusers import DDIMScheduler, ControlNetModel
 from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
 from detail_encoder.encoder_plus import detail_encoder
 
-
 detector = FaceDetector(weight_path="./models/mobilenet0.25_Final.pth")
 
 
@@ -64,21 +64,21 @@ def concatenate_images(image_files, output_file):
 
 def init_pipeline():
     # Initialize the model
-    model_id = "runwayml/stable-diffusion-v1-5" # or your local sdv1-5 path
+    model_id = "runwayml/stable-diffusion-v1-5" # or your local sdv1-5 path
     base_path = "./checkpoints/stablemakeup"
     folder_id = "1397t27GrUyLPnj17qVpKWGwg93EcaFfg"
     if not os.path.exists(base_path):
         download_folder(id=folder_id, output=base_path, quiet=False, use_cookies=False)
     makeup_encoder_path = base_path + "/pytorch_model.bin"
-    id_encoder_path = base_path + "/pytorch_model_1.bin"
-    pose_encoder_path = base_path + "/pytorch_model_2.bin"
-
-    Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
-    id_encoder = ControlNetModel.from_unet(Unet)
-    pose_encoder = ControlNetModel.from_unet(Unet)
-    makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", "cuda", dtype=torch.float32)
-    id_state_dict = torch.load(id_encoder_path)
-    pose_state_dict = torch.load(pose_encoder_path)
+    id_encoder_path = base_path + "/pytorch_model_1.bin"
+    pose_encoder_path = base_path + "/pytorch_model_2.bin"
+
+    Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
+    id_encoder = ControlNetModel.from_unet(Unet)
+    pose_encoder = ControlNetModel.from_unet(Unet)
+    makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", "cuda", dtype=torch.float32)
+    id_state_dict = torch.load(id_encoder_path)
+    pose_state_dict = torch.load(pose_encoder_path)
     makeup_state_dict = torch.load(makeup_encoder_path)
     id_encoder.load_state_dict(id_state_dict, strict=False)
     pose_encoder.load_state_dict(pose_state_dict, strict=False)
@@ -99,10 +99,8 @@ pipeline, makeup_encoder = init_pipeline()
 
 
 def inference(id_image_pil, makeup_image_pil, guidance_scale=1.6, size=512):
-    id_image = id_image_pil.resize((size, size))
+    id_image = id_image_pil.resize((size, size))
     makeup_image = makeup_image_pil.resize((size, size))
-    pose_image = get_draw(id_image, size=size)
-    result_img = makeup_encoder.generate(
-        id_image=[id_image, pose_image], makeup_image=makeup_image, pipe=pipeline, guidance_scale=guidance_scale
-    )
+    pose_image = get_draw(id_image, size=size)
+    result_img = makeup_encoder.generate(id_image=[id_image, pose_image], makeup_image=makeup_image, pipe=pipeline, guidance_scale=guidance_scale)
     return result_img
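The substantive change is how the SPIGA weights directory is located. site.getusersitepackages() returns the user-level site-packages directory, which is not necessarily where spiga is installed (in a virtualenv or a system-wide install it is a different directory, so the checkpoint landed where SPIGA never looked). Deriving the path from spiga.__file__ targets the package as actually imported. Below is a minimal sketch of the same idea, not the committed code: os.path.dirname stands in for the string replace on __file__, which assumes a "/" separator.

import os

import spiga
from gdown import download

# Directory the imported spiga package resolves to, e.g.
# <env>/lib/python3.x/site-packages/spiga. Unlike
# spiga.__file__.replace("/__init__.py", ""), dirname() does not
# assume "/" as the path separator.
pkg_path = os.path.dirname(os.path.abspath(spiga.__file__))

# Same target location and download-if-missing guard as the commit.
ckpt_path = os.path.join(pkg_path, "spiga/models/weights/spiga_300wpublic.pt")
if not os.path.exists(ckpt_path):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    download(id="1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC", output=ckpt_path)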
 
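For context, the seeding block that the first two hunks cut through (seed = 1024 through torch.backends.cudnn.benchmark = False) pins the RNGs and cuDNN behavior for reproducible runs; lines 7-10 of the file fall between the hunks and are not shown. A commonly used full recipe looks like the sketch below; this illustrates the pattern and is not a reconstruction of the hidden lines.

import random
import torch

seed = 1024
random.seed(seed)                          # Python's stdlib RNG
torch.manual_seed(seed)                    # CPU RNG; recent PyTorch also seeds every CUDA device
torch.cuda.manual_seed_all(seed)           # explicit seeding for multi-GPU / older versions
torch.backends.cudnn.deterministic = True  # restrict cuDNN to deterministic kernels
torch.backends.cudnn.benchmark = False     # disable autotuning, which can pick different kernels per run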