Josh Brown Kramer commited on
Commit
74091ca
·
1 Parent(s): 4df3bee

Added support for witch and werewolf

Browse files
Files changed (2) hide show
  1. app.py +11 -6
  2. requirements.txt +1 -1
app.py CHANGED
@@ -16,8 +16,12 @@ from faceparsing2 import get_face_mask
16
  # model.load_state_dict(torch.load('models/your_pix2pixhd_model.pth'))
17
  # model.eval()
18
 
19
- model_path = hf_hub_download(repo_id="jbrownkramer/makemeazombie", filename="smaller512x512_32bit.onnx")
20
- ort_session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
 
 
 
 
21
 
22
  # --- 2. Define the prediction function ---
23
  # def predict(input_image):
@@ -34,17 +38,17 @@ ort_session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider
34
 
35
  # # return output_image
36
 
37
- def predict(input_image, mode):
38
  input_image = input_image.convert("RGB")
39
  if mode == "Classic":
40
  # Use the transition_onnx function for side-by-side comparison
41
- zombie_image = zombie.transition_onnx(input_image, ort_session)
42
  if zombie_image is None:
43
  return "No face found"
44
  return zombie_image
45
  elif mode == "In Place":
46
  im_array = np.array(input_image)
47
- zombie_image = zombie.make_faces_zombie_from_array(im_array, None, ort_session)
48
  if zombie_image is None:
49
  return "No face found"
50
  return zombie_image
@@ -61,7 +65,8 @@ demo = gr.Interface(
61
  fn=predict,
62
  inputs=[
63
  gr.Image(type="pil", label="Input Image"),
64
- gr.Dropdown(choices=["Classic", "In Place"], value="Classic", label="Mode")
 
65
  ],
66
  outputs=gr.Image(type="pil", label="Output Image"),
67
  title=title,
 
16
  # model.load_state_dict(torch.load('models/your_pix2pixhd_model.pth'))
17
  # model.eval()
18
 
19
+ model_map = {"zombie":"smaller512x512_32bit.onnx","witch":"witch.onnx","werewolf":"werewolf.onnx"}
20
+
21
+ inference_session_map = {}
22
+ for model_name, model_filename in model_map.items():
23
+ model_path = hf_hub_download(repo_id="jbrownkramer/makemeazombie", filename=model_filename)
24
+ inference_session_map[model_name] = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
25
 
26
  # --- 2. Define the prediction function ---
27
  # def predict(input_image):
 
38
 
39
  # # return output_image
40
 
41
+ def predict(input_image, mode, model_name):
42
  input_image = input_image.convert("RGB")
43
  if mode == "Classic":
44
  # Use the transition_onnx function for side-by-side comparison
45
+ zombie_image = zombie.transition_onnx(input_image, inference_session_map[model_name])
46
  if zombie_image is None:
47
  return "No face found"
48
  return zombie_image
49
  elif mode == "In Place":
50
  im_array = np.array(input_image)
51
+ zombie_image = zombie.make_faces_zombie_from_array(im_array, None, inference_session_map[model_name])
52
  if zombie_image is None:
53
  return "No face found"
54
  return zombie_image
 
65
  fn=predict,
66
  inputs=[
67
  gr.Image(type="pil", label="Input Image"),
68
+ gr.Dropdown(choices=["Classic", "In Place"], value="Classic", label="Mode"),
69
+ gr.Dropdown(choices=["zombie", "witch", "werewolf"], value="zombie", label="Model")
70
  ],
71
  outputs=gr.Image(type="pil", label="Output Image"),
72
  title=title,
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  gradio
2
- onnxruntime
3
  opencv-python
4
  numpy
5
  mediapipe
 
1
  gradio
2
+ onnxruntime-gpu
3
  opencv-python
4
  numpy
5
  mediapipe