Commit d72fa8b · 2 parents: 19fb8a2, 4e60091
jbilcke-hf (HF Staff) committed:

Merge branch 'main' of hf.co:spaces/jbilcke-hf/LIA-X-testing

Files changed (2):
  1. gradio_tabs/img_edit.py +37 -21
  2. networks/generator.py +31 -14
gradio_tabs/img_edit.py CHANGED
@@ -55,21 +55,31 @@ def img_preprocessing(img_path, size):
     return imgs_norm, w, h
 
 
-def resize(img, size):
-    transform = torchvision.transforms.Compose([
-        torchvision.transforms.Resize((size,size), antialias=True),
-    ])
+# Pre-compile resize transforms for better performance
+resize_transform_cache = {}
 
-    return transform(img)
+def get_resize_transform(size):
+    """Get cached resize transform - creates once, reuses many times"""
+    if size not in resize_transform_cache:
+        # Only create the transform if it doesn't exist in cache
+        resize_transform_cache[size] = torchvision.transforms.Resize(
+            size,
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
+            antialias=True
+        )
+    return resize_transform_cache[size]
 
 
-def resize_back(img, w, h):
-    transform = torchvision.transforms.Compose([
-        torchvision.transforms.Resize((h, w), antialias=True),
-    ])
+def resize(img, size):
+    """Use cached resize transform"""
+    transform = get_resize_transform((size, size))
+    return transform(img)
 
-    return transform(img)
+def resize_back(img, w, h):
+    """Use cached resize transform for back operation"""
+    transform = get_resize_transform((h, w))
+    return transform(img)
 
 def img_denorm(img):
     img = img.clamp(-1, 1).cpu()
@@ -78,17 +88,23 @@ def img_denorm(img):
     return img
 
 
-def img_postprocessing(image, w, h):
-
-    image = resize_back(image, w, h)
-    image = image.permute(0, 2, 3, 1)
-    edited_image = img_denorm(image)
-    img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
-
-    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
-        imageio.imwrite(temp_file.name, img_output, quality=8)
-        return temp_file.name
-
+def img_postprocessing(image, w, h):
+    # Resize on GPU (using cached transform)
+    image = resize_back(image, w, h)
+
+    # Denormalize on GPU (avoid early CPU transfer)
+    image = image.clamp(-1, 1)  # still on GPU
+    image = (image - image.min()) / (image.max() - image.min())  # still on GPU
+
+    # Single optimized CPU transfer
+    image = image.squeeze(0).permute(1, 2, 0).contiguous()  # contiguous() for fast transfer
+    img_output = (image.cpu().numpy() * 255).astype(np.uint8)  # single CPU transfer
+
+    # Use PIL directly (faster than writing a temp file with imageio)
+    pil_image = Image.fromarray(img_output)
+
+    # Return the PIL image directly instead of a temp-file path
+    return pil_image
 
 def img_edit(gen, device):
 
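
The new get_resize_transform is plain dict memoization keyed by target size: torchvision's Resize module is built once per size and reused, instead of composing a fresh transform on every call. A minimal standalone sketch of the same pattern (the input tensor and sizes below are illustrative, not taken from the Space):

import torch
import torchvision

_resize_cache = {}

def get_resize_transform(size):
    # Build each Resize transform once, then reuse it on every call
    if size not in _resize_cache:
        _resize_cache[size] = torchvision.transforms.Resize(
            size,
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            antialias=True,
        )
    return _resize_cache[size]

img = torch.randn(1, 3, 720, 1280)          # illustrative BCHW input
a = get_resize_transform((512, 512))(img)   # first call creates the transform
b = get_resize_transform((720, 1280))(a)    # each distinct size gets its own entry
assert get_resize_transform((512, 512)) is get_resize_transform((512, 512))

Because the sizes arrive as tuples they are hashable and safe as dict keys; note that for resize_back the cache grows by one transform per distinct upload resolution, since (h, w) comes from the user's image.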
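
The rewritten img_postprocessing keeps clamping and rescaling on the GPU and pays for exactly one device-to-host copy at the end. One detail worth flagging: per-image min/max rescaling stretches each image's range to fill [0, 1], which is not the same as a fixed [-1, 1] to [0, 1] mapping. A hedged sketch of the fixed mapping with the same single-transfer structure, in case that behavior is preferred (assumes a 1xCxHxW tensor, e.g. on CUDA):

import numpy as np
import torch

def denorm_single_transfer(image: torch.Tensor) -> np.ndarray:
    # Fixed [-1, 1] -> [0, 1] mapping, computed before any CPU transfer
    image = image.clamp(-1, 1).add(1).div(2)
    # Reorder to HWC and make memory contiguous, then copy to host once
    image = image.squeeze(0).permute(1, 2, 0).contiguous()
    return (image.cpu().numpy() * 255).astype(np.uint8)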
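
Returning a PIL image instead of a temp-file path also lets Gradio skip a disk round-trip, since gr.Image components accept PIL images directly. A minimal wiring sketch (edit_fn is a hypothetical stand-in for the Space's handler, not its actual code):

import gradio as gr
from PIL import Image

def edit_fn(image: Image.Image) -> Image.Image:
    # ... run the model, then img_postprocessing(...) returns a PIL image ...
    return image  # placeholder: hand the PIL image straight back to Gradio

demo = gr.Interface(fn=edit_fn, inputs=gr.Image(type="pil"), outputs=gr.Image(type="pil"))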
networks/generator.py CHANGED
@@ -17,6 +17,16 @@ class Generator(nn.Module):
         # encoder
         self.enc = Encoder(style_dim, motion_dim, scale)
         self.dec = Decoder(style_dim, motion_dim, scale)
+
+        # Pre-allocate commonly used tensors to avoid repeated allocations
+        self._device = None
+        self._cached_tensors = {}
+
+    @property
+    def device(self):
+        if self._device is None:
+            self._device = next(self.parameters()).device
+        return self._device
 
     def get_alpha(self, x):
         return self.enc.enc_motion(x)
@@ -38,9 +48,11 @@ class Generator(nn.Module):
         enc_r2t_end = time.time()
         print(f"[Generator.edit_img] enc_r2t encoding took: {(enc_r2t_end - enc_r2t_start) * 1000:.2f} ms")
 
-        # Alpha modification timing
+        # Alpha modification timing - OPTIMIZED
         alpha_mod_start = time.time()
-        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
+        # Create tensor directly on the same device as alpha_r2s
+        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
+        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
         alpha_mod_end = time.time()
         print(f"[Generator.edit_img] Alpha modification took: {(alpha_mod_end - alpha_mod_start) * 1000:.2f} ms")
 
@@ -59,13 +71,15 @@ class Generator(nn.Module):
         return img_recon
 
     def animate(self, img_source, vid_target, d_l, v_l):
-
         alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
 
         vid_target_recon = []
         z_s2r, feat_rgb = self.enc.enc_2r(img_source)
         alpha_r2s = self.enc.enc_r2t(z_s2r)
-        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
+
+        # Optimized alpha modification
+        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
+        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
 
         for i in tqdm(range(vid_target.size(1))):
             img_target = vid_target[:, i, :, :, :]
@@ -77,14 +91,16 @@ class Generator(nn.Module):
         return vid_target_recon
 
     def animate_batch(self, img_source, vid_target, d_l, v_l, chunk_size):
-
         b,t,c,h,w = vid_target.size()
         alpha_start = self.get_alpha(vid_target[:, 0, :, :, :]) # 1x40
 
         vid_target_recon = []
         z_s2r, feat_rgb = self.enc.enc_2r(img_source)
         alpha_r2s = self.enc.enc_r2t(z_s2r)
-        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
+
+        # Optimized alpha modification
+        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
+        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
 
         bs = chunk_size
         chunks = t//bs
@@ -114,14 +130,16 @@ class Generator(nn.Module):
         return vid_target_recon # BCTHW
 
     def edit_vid(self, vid_target, d_l, v_l):
-
         img_source = vid_target[:, 0, :, :, :]
         alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
 
         vid_target_recon = []
         z_s2r, feat_rgb = self.enc.enc_2r(img_source)
         alpha_r2s = self.enc.enc_r2t(z_s2r)
-        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
+
+        # Optimized alpha modification
+        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
+        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
 
         for i in tqdm(range(vid_target.size(1))):
             img_target = vid_target[:, i, :, :, :]
@@ -133,7 +151,6 @@ class Generator(nn.Module):
         return vid_target_recon
 
     def edit_vid_batch(self, vid_target, d_l, v_l, chunk_size):
-
         b,t,c,h,w = vid_target.size()
         img_source = vid_target[:, 0, :, :, :]
         alpha_start = self.get_alpha(img_source) # 1x40
@@ -141,7 +158,10 @@ class Generator(nn.Module):
         vid_target_recon = []
         z_s2r, feat_rgb = self.enc.enc_2r(img_source)
         alpha_r2s = self.enc.enc_r2t(z_s2r)
-        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
+
+        # Optimized alpha modification
+        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
+        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
 
         bs = chunk_size
         chunks = t//bs
@@ -170,9 +190,7 @@ class Generator(nn.Module):
 
         return vid_target_recon # BCTHW
 
-
    def interpolate_img(self, img_source, d_l, v_l):
-
         vid_target_recon = []
 
         step = 16
@@ -222,5 +240,4 @@ class Generator(nn.Module):
 
         vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
 
-        return vid_target_recon
-
+        return vid_target_recon
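
The cached device property added at the top of Generator is a standard way to let an nn.Module create tensors on whatever device its weights live on, without hard-coding 'cuda'. (The _cached_tensors dict is declared in this diff but not yet used.) A standalone sketch of the pattern (the Toy module is illustrative):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self._device = None  # resolved lazily on first access, then cached

    @property
    def device(self):
        if self._device is None:
            self._device = next(self.parameters()).device
        return self._device

m = Toy()
print(m.device)                          # cpu for a fresh module
x = torch.zeros(1, 4, device=m.device)  # tensors land next to the weights

One caveat of caching: the property goes stale if the module is moved with .to(...) after the first access; resetting _device when the module moves (or not caching at all) avoids that.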
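
The change repeated across edit_img, animate, animate_batch, edit_vid, and edit_vid_batch is the same one-liner: torch.FloatTensor(v_l).unsqueeze(0).to('cuda') allocates on the CPU and then issues a separate host-to-device copy (and assumes CUDA is present), while torch.tensor(v_l, device=..., dtype=...) allocates once, directly on alpha_r2s's device, and inherits its dtype. A small sketch of the two paths (shapes and values illustrative):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
alpha_r2s = torch.randn(1, 40, device=device)  # stand-in for the encoded motion code
d_l, v_l = [3, 7], [0.5, -0.2]                 # directions to edit and their offsets

# Old path: CPU allocation followed by an explicit device copy, hard-coded to CUDA
# old = torch.FloatTensor(v_l).unsqueeze(0).to('cuda')

# New path: one allocation directly on alpha_r2s's device, matching its dtype
v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

Matching the dtype also keeps the edit correct if the model runs under autocast or is loaded in half precision, where alpha_r2s may not be float32.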