pablo commited on
Commit
ee7e3f6
·
1 Parent(s): 9b31e74
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -51,11 +51,6 @@ def estimate_depth(image: Image) -> Image:
51
  output = prediction.cpu().numpy()
52
 
53
  output= 255 * output/np.max(output)
54
-
55
- # If 3 channels convert to 1
56
- if (len(output.shape) == 3):
57
- if (output.shape[2] == 3):
58
- output = output[:, :, 0]
59
 
60
  return Image.fromarray(output.astype("uint8"))
61
 
@@ -73,8 +68,10 @@ def predict(dict, depth, prompt="", negative_prompt="", guidance_scale=7.5, step
73
  negative_prompt = None
74
  scheduler_class_name = scheduler.split("-")[0]
75
 
76
- if (depth is None):
77
  depth_image = estimate_depth(image)
 
 
78
 
79
  scheduler = getattr(diffusers, scheduler_class_name)
80
  pipe.scheduler = scheduler.from_pretrained("Intel/ldm3d-4c", subfolder="scheduler")
@@ -132,6 +129,7 @@ with image_blocks as demo:
132
  with gr.Column():
133
  image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload",height=400)
134
  depth = gr.Image(source='upload', elem_id="depth_upload", type="pil", label="Upload",height=400)
 
135
  print(depth)
136
 
137
  with gr.Row(elem_id="prompt-container", mobile_collapse=False, equal_height=True):
 
51
  output = prediction.cpu().numpy()
52
 
53
  output= 255 * output/np.max(output)
 
 
 
 
 
54
 
55
  return Image.fromarray(output.astype("uint8"))
56
 
 
68
  negative_prompt = None
69
  scheduler_class_name = scheduler.split("-")[0]
70
 
71
+ if (depth.value == None):
72
  depth_image = estimate_depth(image)
73
+ else:
74
+ depth_image = depth.convert("L")
75
 
76
  scheduler = getattr(diffusers, scheduler_class_name)
77
  pipe.scheduler = scheduler.from_pretrained("Intel/ldm3d-4c", subfolder="scheduler")
 
129
  with gr.Column():
130
  image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload",height=400)
131
  depth = gr.Image(source='upload', elem_id="depth_upload", type="pil", label="Upload",height=400)
132
+
133
  print(depth)
134
 
135
  with gr.Row(elem_id="prompt-container", mobile_collapse=False, equal_height=True):