import json
import os

import numpy as np
import pandas as pd
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import (
    LMSDiscreteScheduler,
    StableDiffusionImg2ImgPipeline,
)
from torch import nn
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from torchvision import transforms
def get_top_misclassified(val_classifier_json):
    """Return a mapping of class name -> top-N most confused classes from the validation report."""
    with open(val_classifier_json) as f:
        val_output = json.load(f)
    val_metrics_df = pd.DataFrame.from_dict(
        val_output["val_metrics_details"], orient="index"
    )
    return val_metrics_df["top_n_classes"].to_dict()
def get_class_list(val_classifier_json):
    """Return the sorted list of class names present in the validation report."""
    with open(val_classifier_json, "r") as f:
        data = json.load(f)
    return sorted(data["val_metrics_details"].keys())
def generateClassPairs(val_classifier_json):
    """Build unique, sorted (class_a, class_b) pairs from the top misclassified classes."""
    pairs = set()
    misclassified_classes = get_top_misclassified(val_classifier_json)
    for key, value in misclassified_classes.items():
        for v in value:
            pairs.add(tuple(sorted([key, v])))
    return sorted(pairs)
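# Usage sketch (added for illustration, not part of the original pipeline):
# "val_metrics.json" is a hypothetical path to a classifier validation report
# with the structure read above ("val_metrics_details" -> per-class "top_n_classes").
def _example_list_class_pairs():
    pairs = generateClassPairs("val_metrics.json")  # hypothetical report path
    for class_a, class_b in pairs:
        print(f"Will interpolate between '{class_a}' and '{class_b}'")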
def outputDirectory(class_pairs, synth_path, metadata_path):
    """Create one output folder per class id plus a folder for the run metadata."""
    for class_id in class_pairs:
        class_folder = f"{synth_path}/{class_id}"
        os.makedirs(class_folder, exist_ok=True)
    os.makedirs(metadata_path, exist_ok=True)
    print("Info: Output directory ready.")
def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """Load a StableDiffusionImg2ImgPipeline with optional DeepCache / torch.compile speed-ups.

    ci_cb is the (cache_interval, cache_branch_id) pair passed to DeepCache.
    """
    if scheduler is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    ).to(device)
    if cpu_offload:
        pipe.enable_model_cpu_offload()
    if apply_optimization:
        # DeepCache reuses cached U-Net features every `cache_interval` steps.
        helper = DeepCacheSDHelper(pipe=pipe)
        cache_interval, cache_branch_id = ci_cb
        helper.set_params(
            cache_interval=cache_interval, cache_branch_id=cache_branch_id
        )
        helper.enable()
    if use_torchcompile:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    return pipe
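# Usage sketch (illustrative only): the model id below is an assumption, not a
# checkpoint referenced by the original code; any Stable Diffusion v1.x model
# loadable by StableDiffusionImg2ImgPipeline should work.
def _example_build_pipeline():
    pipe = pipe_img(
        "runwayml/stable-diffusion-v1-5",  # hypothetical model id
        device="cuda",
        apply_optimization=True,  # enable DeepCache
        ci_cb=(5, 1),             # cache every 5 steps on branch 1
    )
    return pipe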
def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
):
    """Build one positive prompt per class name and an optional matching list of negative prompts."""
    if prompt_structure is None:
        prompt_structure = "a photo of a <class_name>"
    elif "<class_name>" not in prompt_structure:
        raise ValueError(
            "The prompt structure must contain the <class_name> placeholder."
        )
    if use_default_negative_prompt:
        negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )
    class1 = class_name_pairs[0]
    class2 = class_name_pairs[1]
    prompt1 = prompt_structure.replace("<class_name>", class1)
    prompt2 = prompt_structure.replace("<class_name>", class2)
    prompts = [prompt1, prompt2]
    if negative_prompt is None:
        print("Info: Negative prompt not provided, returning None.")
        return prompts, None
    negative_prompts = [negative_prompt] * len(prompts)
    return prompts, negative_prompts
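# Usage sketch: building a prompt pair for two class names with the default
# template and the built-in negative prompt. The class names are placeholders.
def _example_build_prompts():
    prompts, negative_prompts = createPrompts(
        ["golden retriever", "labrador retriever"],  # hypothetical class pair
        prompt_structure="a photo of a <class_name>",
        use_default_negative_prompt=True,
    )
    return prompts, negative_prompts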
def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """Encode the prompts and spherically interpolate between consecutive prompt embeddings.

    Returns the selected (middle) interpolation embeddings and per-step similarity metadata.
    """

    def slerp(v0, v1, num, t0=0, t1=1):
        """Spherical linear interpolation between two embeddings over `num` steps."""
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

        def interpolation(t, v0, v1, DOT_THRESHOLD=0.9995):
            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                # Nearly parallel vectors: fall back to linear interpolation.
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1
            return v2

        t = np.linspace(t0, t1, num)
        v3 = torch.tensor(np.array([interpolation(t[i], v0, v1) for i in range(num)]))
        return v3
    def get_middle_elements(lst, n):
        """Return the middle `n` elements of `lst` together with their index range."""
        if n % 2 == 0:  # even number of elements requested
            middle_index = len(lst) // 2 - 1
            start = middle_index - n // 2 + 1
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)
        else:  # odd number of elements requested
            middle_index = len(lst) // 2
            start = middle_index - n // 2
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)

    def remove_middle(data, n):
        """Drop the `n` middle elements from `data`."""
        if n < 0 or n > len(data):
            raise ValueError(
                "Invalid value for n. It must be non-negative and no larger than the list length."
            )
        middle = len(data) // 2
        if n == 1:
            return data[:middle] + data[middle + 1 :]
        elif n % 2 == 0:
            return data[: middle - n // 2] + data[middle + n // 2 :]
        else:
            return data[: middle - n // 2] + data[middle + n // 2 + 1 :]
    batch_size = len(prompts)
    # Tokenize and encode all prompts with the pipeline's text encoder.
    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]

    # Slerp between each consecutive pair of prompt embeddings.
    interpolated_prompt_embeds = []
    for i in range(batch_size - 1):
        interpolated_prompt_embeds.append(
            slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
        )

    # Keep the full interpolation for the similarity metadata below, but only
    # sample the middle of the interpolation for generation.
    full_interpolated_prompt_embeds = interpolated_prompt_embeds[:]
    interpolated_prompt_embeds[0], sample_range = get_middle_elements(
        interpolated_prompt_embeds[0], sample_mid_interpolation
    )
    if remove_n_middle > 0:
        interpolated_prompt_embeds[0] = remove_middle(
            interpolated_prompt_embeds[0], remove_n_middle
        )

    # Record, for every interpolation step, its cosine similarity to both endpoint
    # prompts and which class it is nearest to.
    prompt_metadata = dict()
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    for i in range(num_interpolation_steps):
        class1_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][0],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        class2_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][num_interpolation_steps - 1],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        relative_distance = class1_sim / (class1_sim + class2_sim)
        prompt_metadata[i] = {
            "selected": i in sample_range,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            "nearest_class": int(relative_distance < 0.5),
        }

    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
    return interpolated_prompt_embeds, prompt_metadata
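# Usage sketch tying the helpers above together: encode and slerp both the
# positive and negative prompt pairs, keeping only the middle of the interpolation.
# The step counts below are illustrative, not values mandated by the module.
def _example_interpolate(pipe, prompts, negative_prompts):
    pos_embeds, pos_meta = interpolatePrompts(
        prompts, pipe, num_interpolation_steps=16, sample_mid_interpolation=8
    )
    neg_embeds, neg_meta = interpolatePrompts(
        negative_prompts, pipe, num_interpolation_steps=16, sample_mid_interpolation=8
    )
    return pos_embeds, neg_embeds, pos_meta, neg_meta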
def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """Generate a single img2img sample from one interpolated prompt embedding."""
    # The pipeline expects batched embeddings, so add a leading batch dimension.
    if neg_embed is not None:
        npe = neg_embed[None, ...]
    else:
        npe = None
    return pipeline(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=pos_embed[None, ...],
        negative_prompt_embeds=npe,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    ).images[0]
def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """Collect run parameters, prompts, and similarity/SSIM metadata; optionally dump them to JSON."""
    metadata = dict()
    metadata["class_pairs"] = class_pairs
    metadata["path"] = path
    metadata["seed"] = seed
    metadata["params"] = {
        "CFG": guidance_scale,
        "inferenceSteps": num_inference_steps,
        "interpolationSteps": num_interpolation_steps,
        "sampleMidInterpolation": sample_mid_interpolation,
        "height": height,
        "width": width,
    }
    for i in range(len(prompts)):
        metadata[f"prompt_text_{i}"] = prompts[i]
        if negative_prompts is not None:
            metadata[f"negative_prompt_text_{i}"] = negative_prompts[i]
    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata
    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata
    if save_json:
        json_name = f"{'_'.join(class_pairs)}_{seed}.json"
        with open(os.path.join(save_path, json_name), "w") as f:
            json.dump(metadata, f, indent=4)
    return metadata
def groupbyInterpolation(dir_to_classfolder):
    """Move each generated image into a subfolder named after its interpolation step.

    Filenames follow the "{seed}-{image_id}_{interpolation_step}.{ext}" pattern used when saving.
    """
    files = [
        (f.split(sep="_")[1].split(sep=".")[0], os.path.join(dir_to_classfolder, f))
        for f in os.listdir(dir_to_classfolder)
    ]
    for interpolation_step, file_path in files:
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)
        os.rename(file_path, os.path.join(new_dir, os.path.basename(file_path)))
def ungroupInterpolation(dir_to_classfolder):
    """Inverse of groupbyInterpolation: flatten the per-step subfolders back into the class folder."""
    for interpolation_step in os.listdir(dir_to_classfolder):
        step_dir = os.path.join(dir_to_classfolder, interpolation_step)
        if os.path.isdir(step_dir):
            for f in os.listdir(step_dir):
                os.rename(
                    os.path.join(step_dir, f),
                    os.path.join(dir_to_classfolder, f),
                )
            os.rmdir(step_dir)
def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """Apply the grouping (or ungrouping) function to every class folder under data_path."""
    data_classes = sorted(os.listdir(data_path))
    fn = fn_group if group else fn_ungroup
    for c in data_classes:
        c_path = os.path.join(data_path, c)
        if os.path.isdir(c_path):
            fn(c_path)
            print(f"Processed {c}")
def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """Shuffle the subset indices and split them into `total_pair_count` groups of roughly equal size.

    When subset_len is not divisible by the group size, indices are recycled from the
    start of the range so that every group is full.
    """
    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    numbers_selection = list(range(subset_len))
    rng.shuffle(numbers)
    # Pad with recycled indices until the list can be reshaped into equal groups.
    pad = group_size * total_pair_count - subset_len
    for i in range(pad):
        numbers.append(numbers_selection[i % subset_len])
    numbers = np.array(numbers)
    groups = numbers[: group_size * total_pair_count].reshape(-1, group_size)
    return groups.tolist()
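# Usage sketch: split 10 image indices into 3 shuffled groups; each group is
# later consumed by one generation pass. The numbers here are arbitrary examples.
def _example_pair_indices():
    groups = getPairIndices(subset_len=10, total_pair_count=3, seed=0)
    for group in groups:
        print(group)  # e.g. a list of ~4 shuffled indices per group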
def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """Generate synthetic images for each class in `class_pairs` from the interpolated prompt embeddings.

    Returns per-step SSIM statistics (and the generated images when return_images=True).
    """
    # Decide which part of the interpolation each class draws its embeddings from.
    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False

    if seed is None:
        seed = torch.Generator().seed()
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    # Generate the initial U-Net latents from a random normal distribution.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
    embed_pairs_list = list(embed_pairs)

    if return_images:
        class_images = dict()
    class_ssim = dict()

    if nearest_half or furthest_half:
        if nearest_half:
            # Each class uses the half of the interpolation nearest to its own prompt.
            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
            multiplier = 2
        elif furthest_half:
            # Each class uses the half of the interpolation nearest to the opposite class.
            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
            multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1
    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        class_ssim[class_id] = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0} for i in range(embed_len)
        }
        subset_len = len(img_subsets[class_id])
        # Assign every image in the subset an interpolation step, repeating the
        # allowed step range as often as needed, then shuffle the assignment.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(group_map)
        iter_indices = class_iterables[class_id].pop()
        for image_id in iter_indices:
            img, trg = img_subsets[class_id][image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            # Track SSIM between the generated image and its source image per interpolation step.
            pred_image = transforms.ToTensor()(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            class_ssim[class_id][interpolate_step]["ssim_sum"] += ssim_score
            class_ssim[class_id][interpolate_step]["ssim_count"] += 1
            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                out_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
                if image_type == "jpg":
                    generated_image.save(out_file, format="JPEG", quality=95)
                elif image_type == "png":
                    generated_image.save(out_file, format="PNG")
                else:
                    generated_image.save(out_file)
        # Average the accumulated SSIM scores for each interpolation step.
        for i_step in range(embed_len):
            if class_ssim[class_id][i_step]["ssim_count"] > 0:
                class_ssim[class_id][i_step]["ssim_avg"] = (
                    class_ssim[class_id][i_step]["ssim_sum"]
                    / class_ssim[class_id][i_step]["ssim_count"]
                )
    if return_images:
        return class_images, class_ssim
    else:
        return class_ssim
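# End-to-end usage sketch (illustrative; the class ids, dataset objects, paths,
# and step counts below are assumptions, not values from the original module):
# generate synthetic images for one misclassified class pair.
def _example_generate_for_pair(pipe, img_subsets, class_iterables):
    class_pair = ("class_a", "class_b")  # hypothetical class ids present in img_subsets
    prompts, negative_prompts = createPrompts(
        list(class_pair), use_default_negative_prompt=True
    )
    pos_embeds, _ = interpolatePrompts(prompts, pipe, 16, 8)
    neg_embeds, _ = interpolatePrompts(negative_prompts, pipe, 16, 8)
    ssim_scores = generateImagesFromDataset(
        img_subsets,
        class_iterables,
        pipe,
        pos_embeds,
        neg_embeds,
        num_inference_steps=25,
        guidance_scale=7.5,
        save_path="synthetic",  # hypothetical output root
        class_pairs=class_pair,
    )
    return ssim_scores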
def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """Reproduce the image-to-step assignment of generateImagesFromDataset without running the
    diffusion pipeline, recording which input file, prompts, and interpolation step map to
    which output file. `prompts` is the (positive, negative) prompt text pair used for logging.
    """
    trace_dict = {
        "class_pairs": list(),
        "class_id": list(),
        "image_id": list(),
        "interpolation_step": list(),
        "embed_len": list(),
        "pos_prompt_text": list(),
        "neg_prompt_text": list(),
        "input_file_path": list(),
        "output_file_path": list(),
        "input_prompts_embed": list(),
    }
    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False
    if seed is None:
        seed = torch.Generator().seed()
    rng = np.random.default_rng(seed)
    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(
        interpolated_prompt_embeds.cpu().numpy(),
        interpolated_negative_prompts_embeds.cpu().numpy(),
    )
    embed_pairs_list = list(embed_pairs)
    if nearest_half or furthest_half:
        if nearest_half:
            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
            multiplier = 2
        elif furthest_half:
            # Uses the half of the interpolation nearest to the opposite class.
            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
            multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1
    for class_iter, class_id in enumerate(class_pairs):
        subset_len = len(img_subsets[class_id])
        # Same step assignment and shuffling as in generateImagesFromDataset.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(group_map)
        iter_indices = class_iterables[class_id].pop()
        for image_id in iter_indices:
            class_ds = img_subsets[class_id]
            interpolate_step = group_map[image_id]
            sample_count = subset_indices[class_id][0] + image_id
            input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0])
            pos_prompt = prompts[0]
            neg_prompt = prompts[1]
            output_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
            if save_prompt_embeds:
                input_prompts_embed = embed_pairs_list[interpolate_step]
            else:
                input_prompts_embed = None
            trace_dict["class_pairs"].append(class_pairs)
            trace_dict["class_id"].append(class_id)
            trace_dict["image_id"].append(image_id)
            trace_dict["interpolation_step"].append(interpolate_step)
            trace_dict["embed_len"].append(embed_len)
            trace_dict["pos_prompt_text"].append(pos_prompt)
            trace_dict["neg_prompt_text"].append(neg_prompt)
            trace_dict["input_file_path"].append(input_file)
            trace_dict["output_file_path"].append(output_file)
            trace_dict["input_prompts_embed"].append(input_prompts_embed)
    return trace_dict
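# Usage sketch: the trace dictionary is column-oriented, so it can be written
# straight to a DataFrame/CSV for later auditing. The file name is a placeholder.
def _example_save_trace(trace_dict):
    trace_df = pd.DataFrame(trace_dict)
    trace_df.to_csv("generation_trace.csv", index=False)
    return trace_df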