import torch
from torch import nn
from networks.encoder import Encoder
from networks.decoder import Decoder
import numpy as np
from tqdm import tqdm
from einops import rearrange, repeat
import time
from contextlib import contextmanager


@contextmanager
def timing_context(label, enabled=True):
    """Context manager for timing that doesn't break torch.compile."""
    if not enabled:
        yield
        return
    start = time.time()
    yield
    end = time.time()
    print(f"[Generator.edit_img] {label} took: {(end - start) * 1000:.2f} ms")


class Generator(nn.Module):
    def __init__(self, size, style_dim=512, motion_dim=40, scale=1):
        super(Generator, self).__init__()

        style_dim = style_dim * scale

        # encoder / decoder
        self.enc = Encoder(style_dim, motion_dim, scale)
        self.dec = Decoder(style_dim, motion_dim, scale)

        # Lazily cached device and a dict for reusable tensors, so repeated
        # calls avoid re-resolving the parameter device or re-allocating
        self._device = None
        self._cached_tensors = {}

    def device(self):
        if self._device is None:
            self._device = next(self.parameters()).device
        return self._device
    def get_alpha(self, x):
        # Motion code (alpha) of a single image
        return self.enc.enc_motion(x)

    def edit_img(self, img_source, d_l, v_l):
        """Edit img_source by offsetting the motion dimensions in d_l by the amounts in v_l."""
        return self._edit_img_core(img_source, d_l, v_l)
    def edit_img_with_timing(self, img_source, d_l, v_l):
        """Version with timing for debugging - not compiled."""
        start_time = time.time()
        print("[Generator.edit_img] Starting image editing...")

        with timing_context("enc_2r encoding"):
            z_s2r, feat_rgb = self.enc.enc_2r(img_source)

        with timing_context("enc_r2t encoding"):
            alpha_r2s = self.enc.enc_r2t(z_s2r)

        with timing_context("Alpha modification"):
            # Create tensor directly on the same device as alpha_r2s
            v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
            alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        with timing_context("Decoding"):
            img_recon = self.dec(z_s2r, [alpha_r2s], feat_rgb)

        # Total time
        end_time = time.time()
        total_time_ms = (end_time - start_time) * 1000
        print(f"[Generator.edit_img] Total execution time: {total_time_ms:.2f} ms")
        print("[Generator.edit_img] ----------------------------------------")
        return img_recon
    def _edit_img_core(self, img_source, d_l, v_l):
        """Core edit_img logic without timing - can be compiled."""
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)
        # Create tensor directly on the same device as alpha_r2s
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
        img_recon = self.dec(z_s2r, [alpha_r2s], feat_rgb)
        return img_recon
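
    # NOTE (not in the original source): because _edit_img_core does no Python
    # I/O, it is the natural target for compilation on PyTorch >= 2.0, e.g.
    #   gen._edit_img_core = torch.compile(gen._edit_img_core)
    # after construction; edit_img then transparently uses the compiled path.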
    def animate(self, img_source, vid_target, d_l, v_l):
        alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)

        # Apply the requested offsets to the selected motion dimensions
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        for i in tqdm(range(vid_target.size(1))):
            img_target = vid_target[:, i, :, :, :]
            alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
            img_recon = self.dec(z_s2r, alpha, feat_rgb)
            vid_target_recon.append(img_recon.unsqueeze(2))

        vid_target_recon = torch.cat(vid_target_recon, dim=2)  # BCTHW
        return vid_target_recon
    def animate_batch(self, img_source, vid_target, d_l, v_l, chunk_size):
        b, t, c, h, w = vid_target.size()
        alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])  # 1 x 40
        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)

        # Apply the requested offsets to the selected motion dimensions
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        # Replicate the per-image tensors once so each chunk of frames can be
        # decoded as a single batch
        bs = chunk_size
        chunks = t // bs
        alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
        alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
        feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
        z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)

        for i in range(chunks + 1):
            if i == chunks:
                # Trailing partial chunk; skip it entirely when t is an exact
                # multiple of chunk_size
                bs = t - i * bs
                if bs == 0:
                    break
                img_target = vid_target[:, i * chunk_size:, :, :, :]
                alpha_start_r = alpha_start_r[:bs]
                alpha_r2s_r = alpha_r2s_r[:bs]
                feat_rgb_r = [feat[:bs] for feat in feat_rgb_r]
                z_s2r_r = z_s2r_r[:bs]
            else:
                img_target = vid_target[:, i * bs:(i + 1) * bs, :, :, :]
            alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target.squeeze(0), alpha_start_r)
            img_recon = self.dec(z_s2r_r, alpha, feat_rgb_r)  # bs x 3 x h x w
            vid_target_recon.append(img_recon)

        vid_target_recon = torch.cat(vid_target_recon, dim=0).unsqueeze(0)  # 1 x T x C x H x W
        vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
        return vid_target_recon  # BCTHW
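
    # Example (illustrative only): animate a source image with a driving video
    # in chunks of 8 frames to bound peak memory; img_source is (1, 3, H, W),
    # vid_target is (1, T, 3, H, W), and the result is BCTHW:
    #   vid = gen.animate_batch(img_source, vid_target, d_l, v_l, chunk_size=8)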
    def edit_vid(self, vid_target, d_l, v_l):
        # Use the first frame as the identity source
        img_source = vid_target[:, 0, :, :, :]
        alpha_start = self.get_alpha(img_source)
        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)

        # Apply the requested offsets to the selected motion dimensions
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        for i in tqdm(range(vid_target.size(1))):
            img_target = vid_target[:, i, :, :, :]
            alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
            img_recon = self.dec(z_s2r, alpha, feat_rgb)
            vid_target_recon.append(img_recon.unsqueeze(2))

        vid_target_recon = torch.cat(vid_target_recon, dim=2)  # BCTHW
        return vid_target_recon
    def edit_vid_batch(self, vid_target, d_l, v_l, chunk_size):
        b, t, c, h, w = vid_target.size()
        # Use the first frame as the identity source
        img_source = vid_target[:, 0, :, :, :]
        alpha_start = self.get_alpha(img_source)  # 1 x 40
        vid_target_recon = []
        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        alpha_r2s = self.enc.enc_r2t(z_s2r)

        # Apply the requested offsets to the selected motion dimensions
        v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor

        # Replicate the per-image tensors once so each chunk of frames can be
        # decoded as a single batch
        bs = chunk_size
        chunks = t // bs
        alpha_start_r = repeat(alpha_start, 'b c -> (repeat b) c', repeat=bs)
        alpha_r2s_r = repeat(alpha_r2s, 'b c -> (repeat b) c', repeat=bs)
        feat_rgb_r = [repeat(feat, 'b c h w -> (repeat b) c h w', repeat=bs) for feat in feat_rgb]
        z_s2r_r = repeat(z_s2r, 'b c -> (repeat b) c', repeat=bs)

        for i in range(chunks + 1):
            if i == chunks:
                # Trailing partial chunk; skip it entirely when t is an exact
                # multiple of chunk_size
                bs = t - i * bs
                if bs == 0:
                    break
                img_target = vid_target[:, i * chunk_size:, :, :, :]
                alpha_start_r = alpha_start_r[:bs]
                alpha_r2s_r = alpha_r2s_r[:bs]
                feat_rgb_r = [feat[:bs] for feat in feat_rgb_r]
                z_s2r_r = z_s2r_r[:bs]
            else:
                img_target = vid_target[:, i * bs:(i + 1) * bs, :, :, :]
            alpha = self.enc.enc_transfer_vid(alpha_r2s_r, img_target.squeeze(0), alpha_start_r)
            img_recon = self.dec(z_s2r_r, alpha, feat_rgb_r)  # bs x 3 x h x w
            vid_target_recon.append(img_recon)

        vid_target_recon = torch.cat(vid_target_recon, dim=0).unsqueeze(0)  # 1 x T x C x H x W
        vid_target_recon = rearrange(vid_target_recon, 'b t c h w -> b c t h w')
        return vid_target_recon  # BCTHW
    def interpolate_img(self, img_source, d_l, v_l):
        vid_target_recon = []
        step = 16
        v_start = np.array([0.] * len(v_l))
        v_end = np.array(v_l)
        stride = (v_end - v_start) / step

        z_s2r, feat_rgb = self.enc.enc_2r(img_source)
        v_tmp = v_start

        # Ramp the edit up to v_l, then back down to zero
        for i in range(step):
            v_tmp = v_tmp + stride
            alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
            img_recon = self.dec(z_s2r, alpha, feat_rgb)
            vid_target_recon.append(img_recon.unsqueeze(2))
        for i in range(step):
            v_tmp = v_tmp - stride
            alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
            img_recon = self.dec(z_s2r, alpha, feat_rgb)
            vid_target_recon.append(img_recon.unsqueeze(2))

        # Second cycle: for edits touching dimensions 6-9 (assumes
        # len(v_l) >= 10), repeat the same up/down ramp; otherwise swing
        # through the negative range and back
        if (v_l[6] != 0) or (v_l[7] != 0) or (v_l[8] != 0) or (v_l[9] != 0):
            for i in range(step):
                v_tmp = v_tmp + stride
                alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
                img_recon = self.dec(z_s2r, alpha, feat_rgb)
                vid_target_recon.append(img_recon.unsqueeze(2))
            for i in range(step):
                v_tmp = v_tmp - stride
                alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
                img_recon = self.dec(z_s2r, alpha, feat_rgb)
                vid_target_recon.append(img_recon.unsqueeze(2))
        else:
            for i in range(step):
                v_tmp = v_tmp - stride
                alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
                img_recon = self.dec(z_s2r, alpha, feat_rgb)
                vid_target_recon.append(img_recon.unsqueeze(2))
            for i in range(step):
                v_tmp = v_tmp + stride
                alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
                img_recon = self.dec(z_s2r, alpha, feat_rgb)
                vid_target_recon.append(img_recon.unsqueeze(2))

        vid_target_recon = torch.cat(vid_target_recon, dim=2)  # BCTHW
        return vid_target_recon
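

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; the checkpoint path,
# image size, and edit indices below are hypothetical).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    gen = Generator(size=256).eval()
    # gen.load_state_dict(torch.load("generator.pt", map_location="cpu"))

    img = torch.randn(1, 3, 256, 256)   # dummy source image
    d_l = [0, 1]                        # motion dimensions to edit
    v_l = [0.5, -0.3]                   # offsets applied to those dimensions

    with torch.no_grad():
        out = gen.edit_img(img, d_l, v_l)
    print(out.shape)                    # e.g. torch.Size([1, 3, 256, 256]) for a 256px model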