jbilcke-hf commited on
Commit
4e60091
·
verified ·
1 Parent(s): 2ff7eb5

Update gradio_tabs/img_edit.py

Browse files
Files changed (1) hide show
  1. gradio_tabs/img_edit.py +37 -21
gradio_tabs/img_edit.py CHANGED
@@ -55,21 +55,31 @@ def img_preprocessing(img_path, size):
55
  return imgs_norm, w, h
56
 
57
 
58
- def resize(img, size):
59
- transform = torchvision.transforms.Compose([
60
- torchvision.transforms.Resize((size,size), antialias=True),
61
- ])
62
 
63
- return transform(img)
 
 
 
 
 
 
 
 
 
64
 
65
 
66
- def resize_back(img, w, h):
67
- transform = torchvision.transforms.Compose([
68
- torchvision.transforms.Resize((h, w), antialias=True),
69
- ])
70
 
71
- return transform(img)
72
 
 
 
 
 
73
 
74
  def img_denorm(img):
75
  img = img.clamp(-1, 1).cpu()
@@ -78,17 +88,23 @@ def img_denorm(img):
78
  return img
79
 
80
 
81
- def img_postprocessing(image, w, h):
82
-
83
- image = resize_back(image, w, h)
84
- image = image.permute(0, 2, 3, 1)
85
- edited_image = img_denorm(image)
86
- img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
87
-
88
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
89
- imageio.imwrite(temp_file.name, img_output, quality=8)
90
- return temp_file.name
91
-
 
 
 
 
 
 
92
 
93
  def img_edit(gen, device):
94
 
 
55
  return imgs_norm, w, h
56
 
57
 
58
+ # Pre-compile resize transforms for better performance
59
+ resize_transform_cache = {}
 
 
60
 
61
+ def get_resize_transform(size):
62
+ """Get cached resize transform - creates once, reuses many times"""
63
+ if size not in resize_transform_cache:
64
+ # Only create the transform if it doesn't exist in cache
65
+ resize_transform_cache[size] = torchvision.transforms.Resize(
66
+ size,
67
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
68
+ antialias=True
69
+ )
70
+ return resize_transform_cache[size]
71
 
72
 
73
+ def resize(img, size):
74
+ """Use cached resize transform"""
75
+ transform = get_resize_transform((size, size))
76
+ return transform(img)
77
 
 
78
 
79
+ def resize_back(img, w, h):
80
+ """Use cached resize transform for back operation"""
81
+ transform = get_resize_transform((h, w))
82
+ return transform(img)
83
 
84
  def img_denorm(img):
85
  img = img.clamp(-1, 1).cpu()
 
88
  return img
89
 
90
 
91
+ def img_postprocessing(img, w, h):
92
+ # Resize on GPU (using cached transform)
93
+ image = resize_back(image, w, h)
94
+
95
+ # Denormalize ON GPU (avoid early CPU transfer)
96
+ image = image.clamp(-1, 1) # Still on GPU
97
+ image = (image - image.min()) / (image.max() - image.min()) # Still on GPU
98
+
99
+ # Single optimized CPU transfer
100
+ image = image.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
101
+ img_output = (image.cpu().numpy() * 255).astype(np.uint8) # Single CPU transfer
102
+
103
+ # Use PIL directly (faster than imageio)
104
+ pil_image = Image.fromarray(img_output)
105
+
106
+ # return the PIL image directly
107
+ return pil_image
108
 
109
  def img_edit(gen, device):
110