jbilcke-hf HF Staff Claude commited on
Commit
681d9b5
·
1 Parent(s): 3b24690

Fix torch.compile optimization issues in image editing

Browse files

- Add

@torch
.compiler.allow_in_graph decorators for PIL operations
- Create writable numpy arrays to avoid compilation errors
- Isolate torch.compile to core inference function only
- Preserve timing functionality outside compiled scope

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

Files changed (1) hide show
  1. gradio_tabs/img_edit.py +12 -4
gradio_tabs/img_edit.py CHANGED
@@ -37,16 +37,20 @@ labels_v = [
37
  ]
38
 
39
 
 
40
  def load_image(img, size):
41
  img = Image.open(img).convert('RGB')
42
  w, h = img.size
43
  img = img.resize((size, size))
44
  img = np.asarray(img)
 
 
45
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
46
 
47
  return img / 255.0, w, h
48
 
49
 
 
50
  def img_preprocessing(img_path, size):
51
  img, w, h = load_image(img_path, size) # [0, 1]
52
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
@@ -105,11 +109,15 @@ def img_postprocessing(img, w, h):
105
 
106
  def img_edit(gen, device):
107
 
 
 
 
 
 
108
  @spaces.GPU
109
  @torch.inference_mode()
110
- @torch.compile
111
  def edit_img(image, *selected_s):
112
- # Start timing
113
  start_time = time.time()
114
  print(f"[edit_img] Starting image editing...")
115
 
@@ -120,9 +128,9 @@ def img_edit(gen, device):
120
  preprocess_end = time.time()
121
  print(f"[edit_img] Preprocessing took: {(preprocess_end - preprocess_start) * 1000:.2f} ms")
122
 
123
- # Model inference timing
124
  inference_start = time.time()
125
- edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
126
  inference_end = time.time()
127
  print(f"[edit_img] Model inference took: {(inference_end - inference_start) * 1000:.2f} ms")
128
 
 
37
  ]
38
 
39
 
40
+ @torch.compiler.allow_in_graph
41
  def load_image(img, size):
42
  img = Image.open(img).convert('RGB')
43
  w, h = img.size
44
  img = img.resize((size, size))
45
  img = np.asarray(img)
46
+ # Make a writable copy to avoid torch.compile issues
47
+ img = np.copy(img)
48
  img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
49
 
50
  return img / 255.0, w, h
51
 
52
 
53
+ @torch.compiler.allow_in_graph
54
  def img_preprocessing(img_path, size):
55
  img, w, h = load_image(img_path, size) # [0, 1]
56
  img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
 
109
 
110
  def img_edit(gen, device):
111
 
112
+ @torch.compile
113
+ def compiled_inference(image_tensor, selected_s):
114
+ """Compiled version of just the model inference"""
115
+ return gen.edit_img(image_tensor, labels_v, selected_s)
116
+
117
  @spaces.GPU
118
  @torch.inference_mode()
 
119
  def edit_img(image, *selected_s):
120
+ # Start timing (outside compiled function)
121
  start_time = time.time()
122
  print(f"[edit_img] Starting image editing...")
123
 
 
128
  preprocess_end = time.time()
129
  print(f"[edit_img] Preprocessing took: {(preprocess_end - preprocess_start) * 1000:.2f} ms")
130
 
131
+ # Model inference timing (compile only the core computation)
132
  inference_start = time.time()
133
+ edited_image_tensor = compiled_inference(image_tensor, selected_s)
134
  inference_end = time.time()
135
  print(f"[edit_img] Model inference took: {(inference_end - inference_start) * 1000:.2f} ms")
136