Synthetic-data-gen / sourcecode.py
Chaitanya-02's picture
Update sourcecode.py
3ff558d verified
raw
history blame
19.8 kB
import json
import os
import numpy as np
import pandas as pd
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import (
LMSDiscreteScheduler,
StableDiffusionImg2ImgPipeline,
)
from torch import nn
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from torchvision import transforms
def get_top_misclassified(val_classifier_json):
    """Read the ``top_n_classes`` entry of every class in a validation report.

    Args:
        val_classifier_json: Path to a JSON file whose top-level
            ``"val_metrics_details"`` object is keyed by class name and whose
            per-class records contain a ``"top_n_classes"`` field.

    Returns:
        dict: ``{class_name: top_n_classes_value}`` for every class in the
        report.
    """
    with open(val_classifier_json) as handle:
        report = json.load(handle)
    # The per-class records are tabulated with classes as rows so the
    # "top_n_classes" column can be read out as one labelled series.
    per_class = pd.DataFrame.from_dict(report["val_metrics_details"], orient="index")
    return {
        class_name: top_classes
        for class_name, top_classes in per_class["top_n_classes"].items()
    }
def get_class_list(val_classifier_json):
    """Return the class names found in a validation report, sorted.

    Args:
        val_classifier_json: Path to a JSON file with a top-level
            ``"val_metrics_details"`` object keyed by class name.

    Returns:
        list: The keys of ``"val_metrics_details"`` in ascending order.
    """
    with open(val_classifier_json, "r") as handle:
        details = json.load(handle)["val_metrics_details"]
    class_names = list(details.keys())
    class_names.sort()
    return class_names
def generateClassPairs(val_classifier_json):
    """Build the unique, order-independent class pairs from a validation report.

    Every class returned by ``get_top_misclassified`` is paired with each
    entry of its ``top_n_classes`` list. Each pair is stored with its two
    names sorted, so ``(a, b)`` and ``(b, a)`` collapse into one entry.

    Args:
        val_classifier_json: Path to the validation report JSON file.

    Returns:
        list[tuple]: Sorted list of unique 2-tuples of class names.
    """
    confusions = get_top_misclassified(val_classifier_json)
    unique_pairs = {
        tuple(sorted([class_name, other]))
        for class_name, others in confusions.items()
        for other in others
    }
    return sorted(unique_pairs)
def outputDirectory(class_pairs, synth_path, metadata_path):
    """Create the output folders for generated images and their metadata.

    One folder ``{synth_path}/{name}`` is created for every entry of
    ``class_pairs``, plus the ``metadata_path`` folder. Folders that already
    exist are left untouched.

    Args:
        class_pairs: Iterable of folder names to create under ``synth_path``.
        synth_path: Root directory for the per-class image folders.
        metadata_path: Directory that will hold the metadata files.

    Raises:
        FileExistsError: If one of the target paths exists but is not a
            directory.
    """
    for class_name in class_pairs:
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(f"{synth_path}/{class_name}", exist_ok=True)
    os.makedirs(metadata_path, exist_ok=True)
    print("Info: Output directory ready.")
def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """Load a Stable Diffusion img2img pipeline with optional speed-ups.

    Args:
        model_path: Model identifier or path passed to
            ``StableDiffusionImg2ImgPipeline.from_pretrained``.
        device: Device the pipeline is moved to when ``cpu_offload`` is off.
        apply_optimization: When true, enable DeepCache on the pipeline.
        use_torchcompile: When true, wrap ``pipe.unet`` with ``torch.compile``.
        ci_cb: ``(cache_interval, cache_branch_id)`` forwarded to
            ``DeepCacheSDHelper.set_params``.
        use_safetensors: Forwarded unchanged to ``from_pretrained``.
        cpu_offload: When true, call ``enable_model_cpu_offload()`` instead of
            placing the whole pipeline on ``device``.
        scheduler: Scheduler instance to use; when ``None`` an
            ``LMSDiscreteScheduler`` with the settings below is created.

    Returns:
        The configured pipeline object.
    """
    if scheduler is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    )
    if cpu_offload:
        # Offloading manages device placement itself; moving the complete
        # pipeline to the accelerator first would only add a full-size
        # allocation that is immediately undone.
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(device)
    if apply_optimization:
        cache_interval, cache_branch_id = ci_cb
        helper = DeepCacheSDHelper(pipe=pipe)
        helper.set_params(
            cache_interval=cache_interval, cache_branch_id=cache_branch_id
        )
        helper.enable()
    if use_torchcompile:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    return pipe
def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
):
    """Build the two positive prompts (and optional negatives) for a class pair.

    Args:
        class_name_pairs: Sequence whose first two items are the class names.
        prompt_structure: Template containing the ``<class_name>`` placeholder.
            ``None`` selects ``"a photo of a <class_name>"``.
        use_default_negative_prompt: When true, the built-in negative prompt
            is used and replaces any ``negative_prompt`` argument.
        negative_prompt: Negative prompt text, or ``None`` for no negatives.

    Returns:
        tuple: ``(prompts, negative_prompts)`` where ``prompts`` holds one
        prompt per class and ``negative_prompts`` is either ``None`` or a
        list repeating the negative prompt once per prompt.

    Raises:
        ValueError: If a custom ``prompt_structure`` lacks ``<class_name>``.
    """
    placeholder = "<class_name>"
    if prompt_structure is None:
        prompt_structure = "a photo of a <class_name>"
    elif placeholder not in prompt_structure:
        raise ValueError(
            "The prompt structure must contain the <class_name> placeholder."
        )

    if use_default_negative_prompt:
        negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )

    first_class, second_class = class_name_pairs[0], class_name_pairs[1]
    prompts = [
        prompt_structure.replace(placeholder, class_name)
        for class_name in (first_class, second_class)
    ]

    if negative_prompt is not None:
        return prompts, [negative_prompt] * len(prompts)
    print("Info: Negative prompt not provided, returning as None.")
    return prompts, None
def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """Interpolate between prompt embeddings and describe every step.

    The prompts are tokenized and encoded with ``pipeline.tokenizer`` and
    ``pipeline.text_encoder``. Consecutive prompt embeddings are joined by
    spherical interpolation (``num_interpolation_steps`` points each). Of the
    first segment only the centred window of ``sample_mid_interpolation``
    steps is kept, optionally with its ``remove_n_middle`` central steps cut
    out again.

    Args:
        prompts: List of at least two prompt strings.
        pipeline: Object exposing ``tokenizer`` (with ``model_max_length``)
            and ``text_encoder``.
        num_interpolation_steps: Number of points per interpolated segment.
        sample_mid_interpolation: Size of the centred window kept from the
            first segment.
        remove_n_middle: Number of central steps to drop from that window.
        device: Device used for encoding and for the returned tensor.

    Returns:
        tuple: ``(embeds, metadata)``. ``embeds`` is the concatenation of the
        kept steps of the first segment followed by any further segments.
        ``metadata`` maps every step index of the first segment to its cosine
        similarity to both end points, the derived relative values, the
        nearest class (0 or 1) and whether the step is part of ``embeds``.

    Raises:
        ValueError: If fewer than two prompts are given, or if
            ``sample_mid_interpolation`` / ``remove_n_middle`` are outside the
            range of available steps.
    """

    def slerp(v0, v1, num, t0=0, t1=1, dot_threshold=0.9995):
        """Return ``num`` spherically interpolated points from v0 to v1."""
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()
        # The angle between the end points does not depend on t, so it is
        # computed once instead of once per interpolation step.
        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        nearly_parallel = np.abs(dot) > dot_threshold
        if not nearly_parallel:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)

        def point(t):
            if nearly_parallel:
                # Fall back to linear interpolation: dividing by
                # sin(theta_0) is numerically unstable for tiny angles.
                return (1 - t) * v0 + t * v1
            theta_t = theta_0 * t
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = np.sin(theta_t) / sin_theta_0
            return s0 * v0 + s1 * v1

        return torch.tensor(np.array([point(t) for t in np.linspace(t0, t1, num)]))

    def middle_window(length, n):
        """Return the index range of the ``n`` central items of a sequence."""
        if n % 2 == 0:
            middle_index = length // 2 - 1
            start = middle_index - n // 2 + 1
        else:
            middle_index = length // 2
            start = middle_index - n // 2
        return range(start, middle_index + n // 2 + 1)

    def middle_cut(length, n):
        """Return ``(lo, hi)`` so that items ``lo:hi`` are the ``n`` central ones."""
        if n < 0 or n > length:
            raise ValueError(
                "Invalid value for n. It should be non-negative and not exceed "
                "the number of sampled steps"
            )
        middle = length // 2
        return middle - n // 2, middle + n // 2 + n % 2

    if len(prompts) < 2:
        raise ValueError("At least two prompts are required for interpolation.")
    if not 0 <= sample_mid_interpolation <= num_interpolation_steps:
        raise ValueError(
            "sample_mid_interpolation must be between 0 and num_interpolation_steps."
        )

    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Only the values are needed (slerp detaches them anyway), so skip
    # building an autograd graph for the text encoder.
    with torch.no_grad():
        prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]

    interpolated_prompt_embeds = [
        slerp(first, second, num_interpolation_steps)
        for first, second in zip(prompts_embeds[:-1], prompts_embeds[1:])
    ]
    full_first_segment = interpolated_prompt_embeds[0]

    sample_range = middle_window(len(full_first_segment), sample_mid_interpolation)
    selected_steps = list(sample_range)
    selected_embeds = full_first_segment[sample_range.start : sample_range.stop]

    if remove_n_middle > 0:
        lo, hi = middle_cut(len(selected_steps), remove_n_middle)
        # torch.cat is required here: "+" on tensors adds element-wise
        # instead of joining the two remaining parts.
        selected_embeds = torch.cat([selected_embeds[:lo], selected_embeds[hi:]], dim=0)
        selected_steps = selected_steps[:lo] + selected_steps[hi:]
    interpolated_prompt_embeds[0] = selected_embeds

    kept_steps = set(selected_steps)
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    class1_anchor = full_first_segment[0]
    class2_anchor = full_first_segment[num_interpolation_steps - 1]
    prompt_metadata = dict()
    for i in range(num_interpolation_steps):
        step_embed = full_first_segment[i]
        class1_sim = similarity(class1_anchor, step_embed).mean().item()
        class2_sim = similarity(class2_anchor, step_embed).mean().item()
        # Share of the total similarity that belongs to class 1; a value
        # below 0.5 means the step resembles class 2 more.
        relative_distance = class1_sim / (class1_sim + class2_sim)
        prompt_metadata[i] = {
            "selected": i in kept_steps,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            "nearest_class": int(relative_distance < 0.5),
        }

    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
    return interpolated_prompt_embeds, prompt_metadata
def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """Run the pipeline once for a single prompt embedding.

    A leading batch axis is added to ``pos_embed`` (and to ``neg_embed`` when
    it is not ``None``) before the call; every other argument is forwarded to
    ``pipeline`` as a keyword.

    Args:
        pipeline: Callable pipeline whose result exposes ``.images``.
        pos_embed: Prompt embedding without batch axis.
        neg_embed: Negative prompt embedding without batch axis, or ``None``.
        input_image: Passed as ``image``.
        generator: Passed as ``generator``.
        latents: Passed as ``latents``.
        num_imgs: Passed as ``num_images_per_prompt``.
        height: Passed as ``height``.
        width: Passed as ``width``.
        num_inference_steps: Passed as ``num_inference_steps``.
        guidance_scale: Passed as ``guidance_scale``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    negative_batch = None if neg_embed is None else neg_embed[None, ...]
    positive_batch = pos_embed[None, ...]
    result = pipeline(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=positive_batch,
        negative_prompt_embeds=negative_batch,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    )
    return result.images[0]
def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """Collect the settings of one generation run and optionally save them.

    Args:
        class_pairs: Class names of the run; also joined with ``_`` to form
            the JSON file name.
        path: Stored under ``"path"``.
        seed: Stored under ``"seed"`` and used in the JSON file name.
        guidance_scale: Stored as ``params["CFG"]``.
        num_inference_steps: Stored as ``params["inferenceSteps"]``.
        num_interpolation_steps: Stored as ``params["interpolationSteps"]``.
        sample_mid_interpolation: Stored as ``params["sampleMidInterpolation"]``.
        height: Stored as ``params["height"]``.
        width: Stored as ``params["width"]``.
        prompts: Prompt texts, stored as ``prompt_text_{i}``.
        negative_prompts: Negative prompt texts stored as
            ``negative_prompt_text_{i}``, or ``None`` to skip them.
        pipeline: Object whose ``config`` is stored as a plain dict.
        prompt_metadata: Stored under ``"prompt_embed_similarity"``.
        negative_prompt_metadata: Stored under
            ``"negative_prompt_embed_similarity"``.
        ssim_metadata: Optional value stored under ``"ssim_scores"``.
        save_json: When true, write the metadata to
            ``{save_path}/{class_pairs joined by '_'}_{seed}.json``.
        save_path: Directory of the JSON file.

    Returns:
        dict: The assembled metadata.
    """
    metadata = {
        "class_pairs": class_pairs,
        "path": path,
        "seed": seed,
        "params": {
            "CFG": guidance_scale,
            "inferenceSteps": num_inference_steps,
            "interpolationSteps": num_interpolation_steps,
            "sampleMidInterpolation": sample_mid_interpolation,
            "height": height,
            "width": width,
        },
    }
    for index, prompt_text in enumerate(prompts):
        metadata[f"prompt_text_{index}"] = prompt_text
        if negative_prompts is not None:
            metadata[f"negative_prompt_text_{index}"] = negative_prompts[index]

    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata

    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata

    if save_json:
        file_name = f"{'_'.join(class_pairs)}_{seed}.json"
        with open(os.path.join(save_path, file_name), "w") as handle:
            json.dump(metadata, handle, indent=4)
    return metadata
def groupbyInterpolation(dir_to_classfolder):
    """Move the files of a class folder into one sub-folder per interpolation step.

    The step is read from the file name: the text between the first ``_`` and
    the following ``.`` (for example ``12-3_5.jpg`` goes to ``5/``). Entries
    that are not regular files, or whose name does not contain such a step,
    are left where they are, so calling the function on an already grouped
    folder is a no-op.

    Args:
        dir_to_classfolder: Directory whose files are grouped in place.
    """
    for file_name in os.listdir(dir_to_classfolder):
        file_path = os.path.join(dir_to_classfolder, file_name)
        if not os.path.isfile(file_path):
            # Sub-folders (e.g. from a previous grouping) are not images.
            continue
        name_parts = file_name.split(sep="_")
        if len(name_parts) < 2:
            continue
        interpolation_step = name_parts[1].split(sep=".")[0]
        if not interpolation_step:
            continue
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        os.makedirs(new_dir, exist_ok=True)
        os.rename(file_path, os.path.join(new_dir, file_name))
def ungroupInterpolation(dir_to_classfolder):
    """Flatten a class folder by dissolving its direct sub-folders.

    Every entry of each direct sub-directory is moved up into
    ``dir_to_classfolder`` and the emptied sub-directory is removed. Entries
    of ``dir_to_classfolder`` that are not directories are left untouched.

    Args:
        dir_to_classfolder: Directory that is flattened in place.
    """
    for entry in os.listdir(dir_to_classfolder):
        step_dir = os.path.join(dir_to_classfolder, entry)
        if not os.path.isdir(step_dir):
            continue
        for file_name in os.listdir(step_dir):
            source = os.path.join(dir_to_classfolder, entry, file_name)
            target = os.path.join(dir_to_classfolder, file_name)
            os.rename(source, target)
        # The sub-folder is empty now, so a plain rmdir is sufficient.
        os.rmdir(step_dir)
def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """Apply grouping or ungrouping to every class folder under ``data_path``.

    Args:
        data_path: Directory whose direct sub-directories are processed in
            sorted order; other entries are ignored.
        group: Selects ``fn_group`` when true, otherwise ``fn_ungroup``.
        fn_group: Callable invoked with the class folder path when grouping.
        fn_ungroup: Callable invoked with the class folder path when
            ungrouping.
    """
    action = fn_group if group else fn_ungroup
    for class_name in sorted(os.listdir(data_path)):
        class_path = os.path.join(data_path, class_name)
        if not os.path.isdir(class_path):
            continue
        action(class_path)
        print(f"Processed {class_name}")
def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """Split shuffled indices ``0..subset_len-1`` into equally sized groups.

    The indices are shuffled with a NumPy generator seeded by ``seed`` and
    cut into ``total_pair_count`` groups of ``ceil(subset_len /
    total_pair_count)`` items. When the indices do not fill all groups, the
    tail is padded with the unshuffled indices ``0, 1, 2, ...`` (wrapping
    around if necessary), so some indices can occur twice.

    Args:
        subset_len: Number of indices to distribute.
        total_pair_count: Number of groups to return.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        list[list[int]]: Exactly ``total_pair_count`` groups of equal length.

    Raises:
        ValueError: If ``total_pair_count`` is smaller than 1 or
            ``subset_len`` is negative.
    """
    if total_pair_count < 1:
        raise ValueError("total_pair_count must be at least 1.")
    if subset_len < 0:
        raise ValueError("subset_len must be non-negative.")
    if subset_len == 0:
        return [[] for _ in range(total_pair_count)]

    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    rng.shuffle(numbers)

    # Pad up to the exact number of slots; deriving the padding from
    # subset_len % group_size can leave the last groups missing.
    missing = group_size * total_pair_count - subset_len
    numbers.extend(i % subset_len for i in range(missing))

    groups = np.array(numbers).reshape(total_pair_count, group_size)
    return groups.tolist()
def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """Generate one image per dataset sample and score it against its input.

    For each class in ``class_pairs`` the last index group is popped from
    ``class_iterables[class_id]``. Every index in that group selects an input
    image from ``img_subsets[class_id]`` and, through a seeded shuffled
    mapping, one interpolation step. The image is generated with
    ``genClassImg`` and compared with its input using SSIM.

    Args:
        img_subsets: Mapping ``class_id -> dataset``; indexing a dataset
            yields ``(image_tensor, target)``.
        class_iterables: Mapping ``class_id -> list of index groups``; one
            group is removed per call.
        pipeline: Pipeline handed to ``genClassImg``; ``unet.config.in_channels``
            is read for the latent shape.
        interpolated_prompt_embeds: Sequence of prompt embeddings, one per
            interpolation step.
        interpolated_negative_prompts_embeds: Negative embeddings aligned
            with the positive ones, or ``None`` to generate without them.
        num_inference_steps: Forwarded to ``genClassImg``.
        guidance_scale: Forwarded to ``genClassImg``.
        height: Forwarded to ``genClassImg``; also sets the latent height.
        width: Forwarded to ``genClassImg``; also sets the latent width.
        seed: Seed for the torch and NumPy generators; drawn randomly when
            ``None``.
        save_path: Root folder; images go to ``{save_path}/{class_id}/``.
        class_pairs: The (at most two) class ids to process, in order.
        save_image: When true, write every generated image to disk.
        image_type: File extension; ``"jpg"`` and ``"png"`` select explicit
            save formats.
        interpolate_range: ``"nearest"`` gives the first class the first half
            of the steps and the second class the second half,
            ``"furthest"`` swaps the halves, anything else uses all steps for
            both classes.
        device: Device for the initial latents.
        return_images: When true, also return the generated images.

    Returns:
        ``class_ssim`` (``class_id -> step -> {"ssim_sum", "ssim_count",
        "ssim_avg"}``), preceded by ``class_images`` (``class_id -> list of
        images``) when ``return_images`` is true.
    """
    if seed is None:
        seed = torch.Generator().seed()
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    # Generating initial U-Net latent vectors from a random normal distribution.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)

    embed_len = len(interpolated_prompt_embeds)
    if interpolated_negative_prompts_embeds is None:
        # genClassImg accepts a missing negative embedding, so pair every
        # positive embedding with None instead of failing inside zip().
        embed_pairs_list = [(embed, None) for embed in interpolated_prompt_embeds]
    else:
        embed_pairs_list = list(
            zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
        )

    half = embed_len // 2
    if interpolate_range == "nearest":
        steps_range = (range(0, half), range(half, embed_len))
        multiplier = 2
    elif interpolate_range == "furthest":
        # uses opposite class of images of the text interpolation
        steps_range = (range(half, embed_len), range(0, half))
        multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1

    to_tensor = transforms.ToTensor()
    class_images = dict()
    class_ssim = dict()
    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        class_ssim[class_id] = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0} for i in range(embed_len)
        }
        subset_len = len(img_subsets[class_id])
        # Repeated step list that is indexed by image id after shuffling.
        # The construction and the single shuffle must stay exactly as they
        # are so that the same seed reproduces the same image-to-step mapping.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(group_map)

        iter_indices = class_iterables[class_id].pop()
        for image_id in iter_indices:
            img, _ = img_subsets[class_id][image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            pred_image = to_tensor(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            step_stats = class_ssim[class_id][interpolate_step]
            step_stats["ssim_sum"] += ssim_score
            step_stats["ssim_count"] += 1

            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                output_file = (
                    f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
                )
                if image_type == "jpg":
                    generated_image.save(output_file, format="JPEG", quality=95)
                elif image_type == "png":
                    generated_image.save(output_file, format="PNG")
                else:
                    generated_image.save(output_file)

        for step_stats in class_ssim[class_id].values():
            if step_stats["ssim_count"] > 0:
                step_stats["ssim_avg"] = step_stats["ssim_sum"] / step_stats["ssim_count"]

    if return_images:
        return class_images, class_ssim
    return class_ssim
def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """Record, per sample, which input maps to which output file and step.

    For each class in ``class_pairs`` the last index group is popped from
    ``class_iterables[class_id]`` and every index in it is assigned an
    interpolation step through a seeded shuffled mapping. No image is
    generated; only the bookkeeping columns are produced.

    Args:
        prompts: ``prompts[0]`` is recorded as ``pos_prompt_text`` and
            ``prompts[1]`` as ``neg_prompt_text``.
            NOTE(review): presumably a (positive, negative) pair - confirm
            against the caller.
        img_subsets: Mapping ``class_id -> subset`` supporting ``len()`` and
            exposing ``dataset.samples`` whose items start with a file path.
        class_iterables: Mapping ``class_id -> list of index groups``; one
            group is removed per call.
        interpolated_prompt_embeds: Tensor of prompt embeddings, one per
            interpolation step.
        interpolated_negative_prompts_embeds: Tensor of negative embeddings
            aligned with the positive ones, or ``None``.
        subset_indices: Mapping ``class_id -> sequence``; its first item is
            added to the image id to index ``dataset.samples``.
        seed: Seed for the NumPy generator; drawn randomly when ``None``.
        save_path: Root folder used to build the output file paths.
        class_pairs: The (at most two) class ids to process, in order.
        image_type: File extension used in the output file paths.
        interpolate_range: ``"nearest"`` gives the first class the first half
            of the steps and the second class the second half,
            ``"furthest"`` swaps the halves, anything else uses all steps for
            both classes.
        save_prompt_embeds: When true, store the ``(positive, negative)``
            NumPy embeddings of the chosen step in the trace, else ``None``.

    Returns:
        dict: Column name -> list with one entry per processed sample.
    """
    trace_dict = {
        "class_pairs": list(),
        "class_id": list(),
        "image_id": list(),
        "interpolation_step": list(),
        "embed_len": list(),
        "pos_prompt_text": list(),
        "neg_prompt_text": list(),
        "input_file_path": list(),
        "output_file_path": list(),
        "input_prompts_embed": list(),
    }

    if seed is None:
        seed = torch.Generator().seed()
    rng = np.random.default_rng(seed)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs_list = None
    if save_prompt_embeds:
        # The NumPy copies are only needed when they end up in the trace.
        positive_np = interpolated_prompt_embeds.cpu().numpy()
        if interpolated_negative_prompts_embeds is None:
            # Runs without negative prompts are valid; record None for them.
            embed_pairs_list = [(embed, None) for embed in positive_np]
        else:
            embed_pairs_list = list(
                zip(positive_np, interpolated_negative_prompts_embeds.cpu().numpy())
            )

    half = embed_len // 2
    if interpolate_range == "nearest":
        steps_range = (range(0, half), range(half, embed_len))
        multiplier = 2
    elif interpolate_range == "furthest":
        # uses opposite class of images of the text interpolation
        steps_range = (range(half, embed_len), range(0, half))
        multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1

    pos_prompt = prompts[0]
    neg_prompt = prompts[1]
    for class_iter, class_id in enumerate(class_pairs):
        class_ds = img_subsets[class_id]
        subset_len = len(class_ds)
        # Repeated step list that is indexed by image id after shuffling.
        # The construction and the single shuffle must stay exactly as they
        # are so that the same seed reproduces the same image-to-step mapping.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(group_map)

        iter_indices = class_iterables[class_id].pop()
        for image_id in iter_indices:
            interpolate_step = group_map[image_id]
            sample_count = subset_indices[class_id][0] + image_id
            input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0])
            output_file = (
                f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
            )
            if save_prompt_embeds:
                input_prompts_embed = embed_pairs_list[interpolate_step]
            else:
                input_prompts_embed = None

            trace_dict["class_pairs"].append(class_pairs)
            trace_dict["class_id"].append(class_id)
            trace_dict["image_id"].append(image_id)
            trace_dict["interpolation_step"].append(interpolate_step)
            trace_dict["embed_len"].append(embed_len)
            trace_dict["pos_prompt_text"].append(pos_prompt)
            trace_dict["neg_prompt_text"].append(neg_prompt)
            trace_dict["input_file_path"].append(input_file)
            trace_dict["output_file_path"].append(output_file)
            trace_dict["input_prompts_embed"].append(input_prompts_embed)
    return trace_dict