import os
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
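
# Attention processor for Conceptrol / OminiControl on FLUX. With
# name="conceptrol" it rebuilds the attention bias so that only the registered
# textual-concept tokens attend to the image-condition tokens, and at the
# designated concept-specific block it extracts a per-latent concept mask
# (log(A) in the paper) that biases attention in the following layers.
# With name="ominicontrol" it falls back to plain scaled-dot-product attention
# with the provided mask.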
class Conceptrol:
    def __init__(self, config):
        if "name" not in config:
            raise KeyError(
                "name has to be provided as 'conceptrol' or 'ominicontrol'"
            )
        name = config["name"]
        if name not in ["conceptrol", "ominicontrol"]:
            raise ValueError(
                f"Name must be one of ['conceptrol', 'ominicontrol'], got {name}"
            )
        log_attn_map = config.get("log_attn_map", False)

        # static
        self.NUM_BLOCKS = 19  # fixed for FLUX
        self.M = 512  # number of text tokens, fixed for FLUX
        self.N = 1024  # number of latent / image-condition tokens, fixed for FLUX
        self.EP = -10e6  # effectively negative infinity for attention masking
        self.CONCEPT_BLOCK_IDX = 18

        # fixed during one generation
        self.name = name

        # variable during one generation
        self.textual_concept_mask = None
        self.forward_count = 0

        # logged for visualization
        if log_attn_map:
            self.attn_maps = {"latent_to_concept": [], "latent_to_image": []}

    def __call__(
        self,
        query: torch.FloatTensor,
        key: torch.FloatTensor,
        attention_mask: torch.Tensor,
        c_factor: torch.Tensor = torch.ones(1),  # 1-element tensor; log(c_factor[0]) is the condition bias
    ) -> torch.Tensor:
        if not hasattr(self, "textual_concept_idx"):
            raise AttributeError(
                "textual_concept_idx must be registered before calling Conceptrol"
            )

        # Plain attention for ominicontrol (no concept-specific masking)
        if self.name == "ominicontrol":
            scale_factor = 1 / math.sqrt(query.size(-1))
            attention_weight = (
                query @ key.transpose(-2, -1) * scale_factor + attention_mask
            )
            attention_probs = torch.softmax(
                attention_weight, dim=-1
            )  # [B, H, M+2N, M+2N]
            return attention_probs
        if not self.textual_concept_idx[0] < self.textual_concept_idx[1]:
            raise ValueError(
                f"textual_concept_idx[0] must be less than textual_concept_idx[1], "
                f"got {self.textual_concept_idx[0]} >= {self.textual_concept_idx[1]}"
            )

        ### Reset the attention mask predefined in ominicontrol
        attention_mask = torch.zeros_like(attention_mask)
        bias = torch.log(c_factor[0])

        # attention of image condition to latent
        attention_mask[-self.N :, self.M : -self.N] = bias
        # attention of latent to image condition
        attention_mask[self.M : -self.N, -self.N :] = bias
        # attention of textual concept to image condition
        attention_mask[
            self.textual_concept_idx[0] : self.textual_concept_idx[1], -self.N :
        ] = bias
        # attention of other words to image condition (set to negative inf)
        attention_mask[: self.textual_concept_idx[0], -self.N :] = self.EP
        attention_mask[self.textual_concept_idx[1] : self.M, -self.N :] = self.EP

        # No textual_concept_mask yet means we are still in layers before the
        # first concept-specific block, so use an all-zero (no-op) bias
        if self.textual_concept_mask is None:
            self.textual_concept_mask = (
                torch.zeros_like(attention_mask).unsqueeze(0).unsqueeze(0)
            )
        ### Compute attention
        scale_factor = 1 / math.sqrt(query.size(-1))
        attention_weight = (
            query @ key.transpose(-2, -1) * scale_factor
            + attention_mask
            + self.textual_concept_mask
        )
        # [B, H, M+2N, M+2N]
        attention_probs = torch.softmax(attention_weight, dim=-1)

        ### Extract the textual concept mask if this is the concept-specific block
        is_concept_block = (
            self.forward_count % self.NUM_BLOCKS == self.CONCEPT_BLOCK_IDX
        )
        if is_concept_block:
            # Shape: [B, H, N, S], where S is the number of subject tokens
            textual_concept_mask_local = attention_probs[
                :,
                :,
                self.M : -self.N,
                self.textual_concept_idx[0] : self.textual_concept_idx[1],
            ]
            # Take the ratio within the text context
            textual_concept_mask_local = textual_concept_mask_local / torch.sum(
                attention_probs[:, :, self.M : -self.N, : self.M], dim=-1, keepdim=True
            )
            # Average over subject tokens and heads, Shape: [B, 1, N, 1]
            textual_concept_mask_local = torch.mean(
                textual_concept_mask_local, dim=(-1, 1), keepdim=True
            )
            # Normalize so the mean over latent tokens is 1
            textual_concept_mask_local = textual_concept_mask_local / torch.mean(
                textual_concept_mask_local, dim=-2, keepdim=True
            )

            self.textual_concept_mask = (
                torch.zeros_like(attention_mask).unsqueeze(0).unsqueeze(0)
            )
            # log(A) in the paper
            self.textual_concept_mask[:, :, self.M : -self.N, -self.N :] = torch.log(
                textual_concept_mask_local
            )

        self.forward_count += 1
        return attention_probs

    def register(self, textual_concept_idx):
        self.textual_concept_idx = textual_concept_idx

    def visualize_attn_map(self, config_name: str, subject: str):
        save_dir = f"attn_maps/{config_name}/{subject}"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for attn_map_name, attn_maps in self.attn_maps.items():
            if "token_to_token" in attn_map_name:
                continue
            rows, cols = 8, 19
            fig, axes = plt.subplots(
                rows, cols, figsize=(64 * cols / 100, 64 * rows / 100)
            )
            fig.subplots_adjust(
                wspace=0.1, hspace=0.1
            )  # Adjust spacing between subplots

            # Plot each array in the list on the grid
            for i, ax in enumerate(axes.flatten()):
                if i < len(attn_maps):  # Only plot existing arrays
                    attn_map = attn_maps[i] / np.amax(attn_maps[i])
                    ax.imshow(attn_map, cmap="viridis")
                    ax.axis("off")  # Turn off axes for clarity
                else:
                    ax.axis("off")  # Turn off unused subplots

            fig.set_size_inches(64 * cols / 100, 64 * rows / 100)
            save_path = os.path.join(save_dir, f"{attn_map_name}.jpg")
            fig.savefig(save_path)
            plt.close(fig)

        for attn_map_name, attn_maps in self.attn_maps.items():
            if "token_to_token" not in attn_map_name:
                continue
            rows, cols = 8, 19
            fig, axes = plt.subplots(
                rows, cols, figsize=(2560 * cols / 100, 2560 * rows / 100)
            )
            fig.subplots_adjust(
                wspace=0.1, hspace=0.1
            )  # Adjust spacing between subplots

            # Plot each array in the list on the grid
            for i, ax in enumerate(axes.flatten()):
                if i < len(attn_maps):  # Only plot existing arrays
                    attn_map = attn_maps[i] / np.amax(attn_maps[i])
                    ax.imshow(attn_map, cmap="viridis")
                    ax.axis("off")  # Turn off axes for clarity
                else:
                    ax.axis("off")  # Turn off unused subplots

            fig.set_size_inches(2560 * cols / 100, 2560 * rows / 100)
            save_path = os.path.join(save_dir, f"{attn_map_name}.jpg")
            fig.savefig(save_path)
            plt.close(fig)

        # Reset logged maps and per-generation state
        for attn_map_name in self.attn_maps.keys():
            self.attn_maps[attn_map_name] = []
        self.textual_concept_mask = None
        self.forward_count = 0
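

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only): it exercises register() and one
# attention call with random tensors. Shapes follow the FLUX layout assumed
# above (M = 512 text tokens, N = 1024 latent tokens, N image-condition
# tokens); the concept span (4, 8) and the batch / head / feature sizes are
# hypothetical placeholders, not values used by the actual pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    conceptrol = Conceptrol({"name": "conceptrol"})
    conceptrol.register((4, 8))  # (start, end) token indices of the concept in the prompt

    B, H, D = 1, 1, 8
    M, N = conceptrol.M, conceptrol.N
    S = M + 2 * N  # full sequence: text + latent + image-condition tokens

    query = torch.randn(B, H, S, D)
    key = torch.randn(B, H, S, D)
    attention_mask = torch.zeros(S, S)
    c_factor = torch.ones(1)  # condition strength; log(1) = 0 adds no extra bias

    attention_probs = conceptrol(query, key, attention_mask, c_factor)
    print(attention_probs.shape)  # torch.Size([1, 1, 2560, 2560])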