Spaces:
Runtime error
Runtime error
Linoy Tsaban
committed on
Commit
·
17db690
1
Parent(s):
3fcb5ce
Update inversion_utils.py
Browse files - inversion_utils.py +6 -22
inversion_utils.py
CHANGED
|
@@ -29,27 +29,11 @@ def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
|
|
| 29 |
image = image[offset:offset + w]
|
| 30 |
image = np.array(Image.fromarray(image).resize((512, 512)))
|
| 31 |
image = torch.from_numpy(image).float() / 127.5 - 1
|
| 32 |
-
image = image.permute(2, 0, 1).unsqueeze(0).to(device)
|
| 33 |
|
| 34 |
return image
|
| 35 |
|
| 36 |
|
| 37 |
-
def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'):
|
| 38 |
-
from PIL import Image
|
| 39 |
-
from glob import glob
|
| 40 |
-
if img_name is not None:
|
| 41 |
-
path = os.path.join(folder, img_name)
|
| 42 |
-
else:
|
| 43 |
-
path = glob(folder + "*")[idx]
|
| 44 |
-
|
| 45 |
-
img = Image.open(path).resize((img_size,
|
| 46 |
-
img_size))
|
| 47 |
-
|
| 48 |
-
img = pil_to_tensor(img).to(device)
|
| 49 |
-
|
| 50 |
-
if img.shape[1]== 4:
|
| 51 |
-
img = img[:,:3,:,:]
|
| 52 |
-
return img
|
| 53 |
|
| 54 |
def mu_tilde(model, xt,x0, timestep):
|
| 55 |
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
|
|
@@ -77,10 +61,10 @@ def sample_xts_from_x0(model, x0, num_inference_steps=50):
|
|
| 77 |
|
| 78 |
timesteps = model.scheduler.timesteps.to(model.device)
|
| 79 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
| 80 |
-
xts = torch.zeros(variance_noise_shape).to(x0.device)
|
| 81 |
for t in reversed(timesteps):
|
| 82 |
idx = t_to_idx[int(t)]
|
| 83 |
-
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
|
| 84 |
xts = torch.cat([xts, x0 ],dim = 0)
|
| 85 |
|
| 86 |
return xts
|
|
@@ -151,7 +135,7 @@ def inversion_forward_process(model, x0,
|
|
| 151 |
if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
|
| 152 |
xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
|
| 153 |
alpha_bar = model.scheduler.alphas_cumprod
|
| 154 |
-
zs = torch.zeros(size=variance_noise_shape, device=model.device)
|
| 155 |
|
| 156 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
| 157 |
xt = x0
|
|
@@ -230,7 +214,7 @@ def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=
|
|
| 230 |
# 8. Add noise if eta > 0
|
| 231 |
if eta > 0:
|
| 232 |
if variance_noise is None:
|
| 233 |
-
variance_noise = torch.randn(model_output.shape, device=model.device)
|
| 234 |
sigma_z = eta * variance ** (0.5) * variance_noise
|
| 235 |
prev_sample = prev_sample + sigma_z
|
| 236 |
|
|
@@ -248,7 +232,7 @@ def inversion_reverse_process(model,
|
|
| 248 |
|
| 249 |
batch_size = len(prompts)
|
| 250 |
|
| 251 |
-
cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
|
| 252 |
|
| 253 |
text_embeddings = encode_text(model, prompts)
|
| 254 |
uncond_embedding = encode_text(model, [""] * batch_size)
|
|
|
|
| 29 |
image = image[offset:offset + w]
|
| 30 |
image = np.array(Image.fromarray(image).resize((512, 512)))
|
| 31 |
image = torch.from_numpy(image).float() / 127.5 - 1
|
| 32 |
+
image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype =torch.float16)
|
| 33 |
|
| 34 |
return image
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def mu_tilde(model, xt,x0, timestep):
|
| 39 |
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
|
|
|
|
| 61 |
|
| 62 |
timesteps = model.scheduler.timesteps.to(model.device)
|
| 63 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
| 64 |
+
xts = torch.zeros(variance_noise_shape).to(x0.device, dtype =torch.float16)
|
| 65 |
for t in reversed(timesteps):
|
| 66 |
idx = t_to_idx[int(t)]
|
| 67 |
+
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0, dtype =torch.float16) * sqrt_one_minus_alpha_bar[t]
|
| 68 |
xts = torch.cat([xts, x0 ],dim = 0)
|
| 69 |
|
| 70 |
return xts
|
|
|
|
| 135 |
if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
|
| 136 |
xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
|
| 137 |
alpha_bar = model.scheduler.alphas_cumprod
|
| 138 |
+
zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype =torch.float16)
|
| 139 |
|
| 140 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
| 141 |
xt = x0
|
|
|
|
| 214 |
# 8. Add noise if eta > 0
|
| 215 |
if eta > 0:
|
| 216 |
if variance_noise is None:
|
| 217 |
+
variance_noise = torch.randn(model_output.shape, device=model.device, dtype =torch.float16)
|
| 218 |
sigma_z = eta * variance ** (0.5) * variance_noise
|
| 219 |
prev_sample = prev_sample + sigma_z
|
| 220 |
|
|
|
|
| 232 |
|
| 233 |
batch_size = len(prompts)
|
| 234 |
|
| 235 |
+
cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device, dtype=torch.float16)
|
| 236 |
|
| 237 |
text_embeddings = encode_text(model, prompts)
|
| 238 |
uncond_embedding = encode_text(model, [""] * batch_size)
|