Arrcttacsrks commited on
Commit
0931b5b
·
verified ·
1 Parent(s): 4c38b91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -34
app.py CHANGED
@@ -24,49 +24,59 @@ def load_and_preprocess_image(image):
24
  return img.astype(np.float32)
25
 
26
  def swap_faces(source_image, target_image):
27
- # Load the ONNX model
28
- session = ort.InferenceSession(model_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Preprocess the images
31
- source_img = load_and_preprocess_image(source_image)
32
- target_img = load_and_preprocess_image(target_image)
33
-
34
- # Prepare input data for the model
35
- source_input = np.expand_dims(source_img.transpose(2, 0, 1), 0) # Shape: (1, 3, 512, 512)
36
- target_input = np.expand_dims(target_img.transpose(2, 0, 1), 0) # Shape: (1, 3, 512, 512)
37
-
38
- # Get input names dynamically
39
- input_names = [input.name for input in session.get_inputs()]
40
-
41
- # Create input dictionary with correct shapes
42
- input_dict = {
43
- input_names[0]: source_input, # First input for source image
44
- input_names[1]: target_input # Second input for target image
45
- }
46
-
47
- # Run inference
48
- result = session.run(None, input_dict)[0]
49
-
50
- # Post-process the result
51
- result = result[0].transpose(1, 2, 0) # Convert from NCHW to HWC format
52
- result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) # Convert back to RGB
53
- return np.clip(result * 255, 0, 255).astype(np.uint8)
54
 
55
  # Create Gradio interface
56
  interface = gr.Interface(
57
  fn=swap_faces,
58
  inputs=[
59
- gr.Image(label="Source Face"),
60
- gr.Image(label="Target Image")
61
  ],
62
  outputs=gr.Image(label="Result"),
63
  title="Face Swap using SimSwap",
64
  description="Upload a source face and a target image to swap faces. The source face will be transferred onto the target image.",
65
- allow_flagging="never",
66
- examples=[
67
- ["path_to_example_source.jpg", "path_to_example_target.jpg"] # Replace with actual example image paths if available
68
- ]
69
  )
70
 
71
- # Launch the interface with share=True for public access
72
- interface.launch(share=True)
 
 
24
  return img.astype(np.float32)
25
 
26
  def swap_faces(source_image, target_image):
27
+ try:
28
+ # Load the ONNX model
29
+ session = ort.InferenceSession(model_file)
30
+
31
+ # Get input names
32
+ input_names = [input.name for input in session.get_inputs()]
33
+
34
+ # Print input shapes for debugging
35
+ for input in session.get_inputs():
36
+ print(f"Input '{input.name}' expects shape: {input.shape}")
37
+
38
+ # Preprocess the images
39
+ source_img = load_and_preprocess_image(source_image)
40
+ target_img = load_and_preprocess_image(target_image)
41
+
42
+ # Reshape inputs according to model requirements
43
+ # For the first input (assuming it's the image input)
44
+ source_input = source_img.transpose(2, 0, 1)[np.newaxis, ...] # Shape: (1, 3, 512, 512)
45
+
46
+ # For the second input (onnx::Gemm_1), reshape to rank 2 as required by the error message
47
+ target_features = target_img.transpose(2, 0, 1).reshape(-1, 512) # Reshape to 2D array
48
+
49
+ # Create input dictionary
50
+ input_dict = {
51
+ input_names[0]: source_input.astype(np.float32),
52
+ input_names[1]: target_features.astype(np.float32)
53
+ }
54
+
55
+ # Run inference
56
+ result = session.run(None, input_dict)[0]
57
+
58
+ # Post-process the result
59
+ result = result[0].transpose(1, 2, 0) # Convert from NCHW to HWC format
60
+ result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) # Convert back to RGB
61
+ return np.clip(result * 255, 0, 255).astype(np.uint8)
62
 
63
+ except Exception as e:
64
+ print(f"Error during face swapping: {str(e)}")
65
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Create Gradio interface
68
  interface = gr.Interface(
69
  fn=swap_faces,
70
  inputs=[
71
+ gr.Image(label="Source Face", type="numpy"),
72
+ gr.Image(label="Target Image", type="numpy")
73
  ],
74
  outputs=gr.Image(label="Result"),
75
  title="Face Swap using SimSwap",
76
  description="Upload a source face and a target image to swap faces. The source face will be transferred onto the target image.",
77
+ allow_flagging="never"
 
 
 
78
  )
79
 
80
+ # Launch the interface
81
+ if __name__ == "__main__":
82
+ interface.launch()