Spaces: Running on L40S
Commit · 681d9b5
Parent(s): 3b24690
Fix torch.compile optimization issues in image editing
- Add @torch.compiler.allow_in_graph decorators for PIL operations (see the sketch below)
- Create writable numpy arrays to avoid compilation errors
- Isolate torch.compile to core inference function only
- Preserve timing functionality outside compiled scope
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
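For context on the first change: `torch.compiler.allow_in_graph` tells the compiler frontend (Dynamo) to record a function as a single opaque call in the compiled graph rather than tracing into it, which is how untraceable Python work such as PIL calls can avoid graph breaks. A minimal sketch of the pattern, with a hypothetical `opaque_scale` helper standing in for the repo's PIL-based loader:

```python
import torch

# Stand-in for a helper torch.compile cannot trace (e.g. PIL/numpy work).
# allow_in_graph keeps the call as one opaque node instead of graph-breaking.
@torch.compiler.allow_in_graph
def opaque_scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

@torch.compile
def pipeline(x: torch.Tensor) -> torch.Tensor:
    y = opaque_scale(x)  # recorded as a single call, not inlined
    return torch.relu(y)

print(pipeline(torch.randn(4)))
```

Note that `allow_in_graph` assumes the wrapped function's inputs and outputs are types the graph can represent; helpers taking non-tensor arguments (like the file path here) may still fall back to eager execution.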
- gradio_tabs/img_edit.py +12 -4
gradio_tabs/img_edit.py CHANGED
@@ -37,16 +37,20 @@ labels_v = [
 ]
 
 
+@torch.compiler.allow_in_graph
 def load_image(img, size):
     img = Image.open(img).convert('RGB')
     w, h = img.size
     img = img.resize((size, size))
     img = np.asarray(img)
+    # Make a writable copy to avoid torch.compile issues
+    img = np.copy(img)
     img = np.transpose(img, (2, 0, 1))  # 3 x 256 x 256
 
     return img / 255.0, w, h
 
 
+@torch.compiler.allow_in_graph
 def img_preprocessing(img_path, size):
     img, w, h = load_image(img_path, size)  # [0, 1]
     img = torch.from_numpy(img).unsqueeze(0).float()  # [0, 1]
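The writable copy matters because `np.asarray` on a PIL image returns a read-only view of the image buffer, and `torch.from_numpy` complains about non-writable arrays, which surfaces as errors under torch.compile. A standalone check of that behavior:

```python
import numpy as np
from PIL import Image

img = Image.new('RGB', (256, 256))

arr = np.asarray(img)
print(arr.flags.writeable)  # False: read-only view of PIL's buffer

arr = np.copy(arr)
print(arr.flags.writeable)  # True: safe to hand to torch.from_numpy
```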
@@ -105,11 +109,15 @@ def img_postprocessing(img, w, h):
 
 def img_edit(gen, device):
 
+    @torch.compile
+    def compiled_inference(image_tensor, selected_s):
+        """Compiled version of just the model inference"""
+        return gen.edit_img(image_tensor, labels_v, selected_s)
+
     @spaces.GPU
     @torch.inference_mode()
-    @torch.compile
     def edit_img(image, *selected_s):
-        # Start timing
+        # Start timing (outside compiled function)
        start_time = time.time()
        print(f"[edit_img] Starting image editing...")
 
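Moving `@torch.compile` from `edit_img` onto the small `compiled_inference` closure keeps the timing, printing, and preprocessing glue out of the traced region, so only the tensor computation is compiled. A minimal sketch of the same pattern, with `torch.nn.Linear` standing in for `gen.edit_img`:

```python
import time
import torch

def make_editor(gen):
    # Compile only the pure tensor computation.
    @torch.compile
    def compiled_inference(x):
        return gen(x)  # stand-in for gen.edit_img(...)

    # Timing, logging, and any Python-side glue stay outside the
    # compiled region, so Dynamo never has to trace them.
    def edit(x):
        t0 = time.time()
        y = compiled_inference(x)
        print(f"inference: {(time.time() - t0) * 1000:.2f} ms")
        return y

    return edit

edit = make_editor(torch.nn.Linear(8, 8))
edit(torch.randn(1, 8))
```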
@@ -120,9 +128,9 @@ def img_edit(gen, device):
         preprocess_end = time.time()
         print(f"[edit_img] Preprocessing took: {(preprocess_end - preprocess_start) * 1000:.2f} ms")
 
-        # Model inference timing
+        # Model inference timing (compile only the core computation)
         inference_start = time.time()
-        edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
+        edited_image_tensor = compiled_inference(image_tensor, selected_s)
         inference_end = time.time()
         print(f"[edit_img] Model inference took: {(inference_end - inference_start) * 1000:.2f} ms")
 
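When reading the inference timing around a compiled call, note that the first invocation includes the one-time compilation cost; only later calls reflect steady-state latency. A quick sketch that makes the warm-up visible:

```python
import time
import torch

f = torch.compile(lambda x: torch.relu(x) * 2)
x = torch.randn(1024)

for i in range(3):
    t0 = time.time()
    f(x)
    # Call 0 pays the one-time compile cost; calls 1+ are steady state.
    print(f"call {i}: {(time.time() - t0) * 1000:.2f} ms")
```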