"""
Helper functions for generating synthetic images using a diffusion model.

Functions:
    - get_top_misclassified
    - get_class_list
    - generateClassPairs
    - outputDirectory
    - pipe_img
    - createPrompts
    - interpolatePrompts
        - slerp
        - get_middle_elements
        - remove_middle
    - genClassImg
    - getMetadata
    - groupbyInterpolation
    - ungroupInterpolation
    - groupAllbyInterpolation
    - getPairIndices
    - generateImagesFromDataset
    - generateTrace
"""

import json
import os

import numpy as np
import pandas as pd
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import (
    LMSDiscreteScheduler,
    StableDiffusionImg2ImgPipeline,
)
from torch import nn
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from torchvision import transforms


def get_top_misclassified(val_classifier_json):
    """
    Retrieves the top misclassified classes from a validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        dict: A dictionary mapping each class name to the list of classes it is
              most frequently confused with (its "top_n_classes" entry).
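
    Example (illustrative; the file path and class names are hypothetical):
        class_dict = get_top_misclassified("val_metrics.json")
        # e.g. {"cat": ["dog", "fox"], "dog": ["cat", "wolf"]}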
    """
    with open(val_classifier_json) as f:
        val_output = json.load(f)
    val_metrics_df = pd.DataFrame.from_dict(
        val_output["val_metrics_details"], orient="index"
    )
    return val_metrics_df["top_n_classes"].to_dict()


def get_class_list(val_classifier_json):
    """
    Retrieves the list of classes from the given validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: A sorted list of class names extracted from the JSON file.
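
    Example (illustrative file path):
        classes = get_class_list("val_metrics.json")
        # e.g. ["cat", "dog", "fox"]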
    """
    with open(val_classifier_json, "r") as f:
        data = json.load(f)
    return sorted(list(data["val_metrics_details"].keys()))


def generateClassPairs(val_classifier_json):
    """
    Generate pairs of misclassified classes from the given validation classifier JSON.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: A sorted list of pairs of misclassified classes.
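
    Example (illustrative; class names are hypothetical):
        generateClassPairs("val_metrics.json")
        # e.g. [("cat", "dog"), ("cat", "fox")]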
    """
    pairs = set()
    misclassified_classes = get_top_misclassified(val_classifier_json)
    for key, value in misclassified_classes.items():
        for v in value:
            pairs.add(tuple(sorted([key, v])))
    return sorted(list(pairs))


def outputDirectory(class_pairs, synth_path, metadata_path):
    """
    Creates the output directory structure for the synthesized data.

    Args:
        class_pairs (list): A list of class pairs.
        synth_path (str): The path to the directory where the synthesized data will be stored.
        metadata_path (str): The path to the directory where the metadata will be stored.

    Returns:
        None
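
    Example (illustrative paths):
        outputDirectory(["cat_dog"], "data/synth", "data/metadata")
        # creates data/synth/cat_dog/ and data/metadata/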
    """
    for pair_id in class_pairs:
        class_folder = f"{synth_path}/{pair_id}"
        os.makedirs(class_folder, exist_ok=True)
    os.makedirs(metadata_path, exist_ok=True)
    print("Info: Output directory ready.")


def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """
    Creates and returns an image-to-image pipeline for stable diffusion.

    Args:
        model_path (str): The path to the pretrained model.
        device (str, optional): The device to use for computation. Defaults to "cuda".
        apply_optimization (bool, optional): Whether to apply optimization techniques. Defaults to True.
        use_torchcompile (bool, optional): Whether to compile the UNet with torch.compile. Defaults to False.
        ci_cb (tuple, optional): A tuple containing the cache interval and cache branch ID. Defaults to (5, 1).
        use_safetensors (bool, optional): Whether to use safetensors. Defaults to None.
        cpu_offload (bool, optional): Whether to enable CPU offloading. Defaults to False.
        scheduler (LMSDiscreteScheduler, optional): The scheduler for the pipeline. Defaults to None.

    Returns:
        StableDiffusionImg2ImgPipeline: The image-to-image pipeline for stable diffusion.
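
    Example (a minimal sketch; the model id is a placeholder):
        pipe = pipe_img(
            "runwayml/stable-diffusion-v1-5",
            device="cuda",
            ci_cb=(3, 0),
        )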
    """
    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################
    if scheduler is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    ).to(device)
    if cpu_offload:
        pipe.enable_model_cpu_offload()
    if apply_optimization:
        # tomesd.apply_patch(pipe, ratio=0.5)
        helper = DeepCacheSDHelper(pipe=pipe)
        cache_interval, cache_branch_id = ci_cb
        helper.set_params(
            cache_interval=cache_interval, cache_branch_id=cache_branch_id
        )  # a larger interval (and smaller branch id) is faster but lowers quality
        helper.enable()
        # if torch.cuda.is_available():
        #     pipe.to("cuda")
        #     pipe.enable_xformers_memory_efficient_attention()
        if use_torchcompile:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    return pipe


def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
):
    """
    Create prompts for image generation.

    Args:
        class_name_pairs (list): A list of two class names.
        prompt_structure (str, optional): The structure of the prompt. Defaults to "a photo of a <class_name>".
        use_default_negative_prompt (bool, optional): Whether to use the default negative prompt. Defaults to False.
        negative_prompt (str, optional): The negative prompt to steer the generation away from certain features. Defaults to None.

    Returns:
        tuple: A tuple containing two lists - prompts and negative_prompts.
            prompts (list): Text prompts that describe the desired output image.
            negative_prompts (list): Negative prompts that can be used to steer the generation away from certain features.
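
    Example (illustrative class names):
        prompts, negative_prompts = createPrompts(
            ["cat", "dog"], use_default_negative_prompt=True
        )
        # prompts == ["a photo of a cat", "a photo of a dog"]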
    """
    if prompt_structure is None:
        prompt_structure = "a photo of a <class_name>"
    elif "<class_name>" not in prompt_structure:
        raise ValueError(
            "The prompt structure must contain the <class_name> placeholder."
        )
    if use_default_negative_prompt:
        default_negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )
        negative_prompt = default_negative_prompt

    class1 = class_name_pairs[0]
    class2 = class_name_pairs[1]
    prompt1 = prompt_structure.replace("<class_name>", class1)
    prompt2 = prompt_structure.replace("<class_name>", class2)
    prompts = [prompt1, prompt2]
    if negative_prompt is None:
        print("Info: Negative prompt not provided, returning as None.")
        return prompts, None
    else:
        # Negative prompts that can be used to steer the generation away from certain features.
        negative_prompts = [negative_prompt] * len(prompts)
        return prompts, negative_prompts


def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """
    Interpolates prompts by generating intermediate embeddings between pairs of prompts.

    Args:
        prompts (List[str]): A list of prompts to be interpolated.
        pipeline: The pipeline object containing the tokenizer and text encoder.
        num_interpolation_steps (int): The number of interpolation steps between each pair of prompts.
        sample_mid_interpolation (int): The number of intermediate embeddings to sample from the middle of the interpolated prompts.
        remove_n_middle (int, optional): The number of middle embeddings to remove from the interpolated prompts. Defaults to 0.
        device (str, optional): The device to run the interpolation on. Defaults to "cuda".

    Returns:
        interpolated_prompt_embeds (torch.Tensor): The interpolated prompt embeddings.
        prompt_metadata (dict): Metadata about the interpolation process, including similarity scores and nearest class information.

    e.g. if num_interpolation_steps = 10, sample_mid_interpolation = 6, remove_n_middle = 2
    Interpolated: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    Sampled:            [2, 3, 4, 5, 6, 7]
    Removed:                   x  x
    Returns:            [2, 3,       6, 7]
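
    Example (a sketch; assumes `pipe` and `prompts` from the functions above):
        embeds, meta = interpolatePrompts(
            prompts, pipe, num_interpolation_steps=16, sample_mid_interpolation=8
        )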
    """

    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################

    def slerp(v0, v1, num, t0=0, t1=1):
        """
        Performs spherical linear interpolation between two vectors.

        Args:
            v0 (torch.Tensor): The starting vector.
            v1 (torch.Tensor): The ending vector.
            num (int): The number of interpolation points.
            t0 (float, optional): The starting time. Defaults to 0.
            t1 (float, optional): The ending time. Defaults to 1.

        Returns:
            torch.Tensor: The interpolated vectors.

        """
        ###############################
        # Reference:
        # Karpathy, A. (2022) hacky stablediffusion code for generating videos, Gist. Available at: https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355 (Accessed: 4 June 2024).
        ###############################
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

        def interpolation(t, v0, v1, DOT_THRESHOLD=0.9995):
            """helper function to spherically interpolate two arrays v1 v2"""
            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1
            return v2

        t = np.linspace(t0, t1, num)

        v3 = torch.tensor(np.array([interpolation(t[i], v0, v1) for i in range(num)]))

        return v3

    def get_middle_elements(lst, n):
        """
        Returns a tuple containing a sublist of the middle elements of the given list `lst` and a range of indices of those elements.

        Args:
            lst (list): The list from which to extract the middle elements.
            n (int): The number of middle elements to extract.

        Returns:
            tuple: A tuple containing the sublist of middle elements and a range of indices.

        Raises:
            None

        Examples:
            lst = [1, 2, 3, 4, 5]
            get_middle_elements(lst, 3)
            ([2, 3, 4], range(1, 4))
        """
        if n % 2 == 0:  # Even number of elements
            middle_index = len(lst) // 2 - 1
            start = middle_index - n // 2 + 1
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)
        else:  # Odd number of elements
            middle_index = len(lst) // 2
            start = middle_index - n // 2
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)

    def remove_middle(data, n):
        """
        Remove the middle n elements from a list.

        Args:
            data (list): The input list.
            n (int): The number of elements to remove from the middle of the list.

        Returns:
            list: The modified list with the middle n elements removed.

        Raises:
            ValueError: If n is negative or greater than the length of the list.
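
        Example (illustrative):
            remove_middle([1, 2, 3, 4, 5, 6], 2)
            # [1, 2, 5, 6]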

        """
        if n < 0 or n > len(data):
            raise ValueError(
                "Invalid value for n. It should be non-negative and not greater than the list length"
            )

        # Find the middle index
        middle = len(data) // 2

        # Create slices to exclude the middle n elements
        if n == 1:
            return data[:middle] + data[middle + 1 :]
        elif n % 2 == 0:
            return data[: middle - n // 2] + data[middle + n // 2 :]
        else:
            return data[: middle - n // 2] + data[middle + n // 2 + 1 :]

    batch_size = len(prompts)

    # Tokenizing and encoding prompts into embeddings.
    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]

    # Interpolating between embeddings pairs for the given number of interpolation steps.
    interpolated_prompt_embeds = []

    for i in range(batch_size - 1):
        interpolated_prompt_embeds.append(
            slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
        )

    full_interpolated_prompt_embeds = interpolated_prompt_embeds[:]
    interpolated_prompt_embeds[0], sample_range = get_middle_elements(
        interpolated_prompt_embeds[0], sample_mid_interpolation
    )

    if remove_n_middle > 0:
        interpolated_prompt_embeds[0] = remove_middle(
            interpolated_prompt_embeds[0], remove_n_middle
        )

    prompt_metadata = dict()
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    for i in range(num_interpolation_steps):
        class1_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][0],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        class2_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][num_interpolation_steps - 1],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        relative_distance = class1_sim / (class1_sim + class2_sim)

        prompt_metadata[i] = {
            "selected": i in sample_range,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            "nearest_class": int(relative_distance < 0.5),
        }

    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
    return interpolated_prompt_embeds, prompt_metadata


def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """
    Generate class image using the given inputs.

    Args:
        pipeline: The pipeline object used for image generation.
        pos_embed: The positive embedding for the class.
        neg_embed: The negative embedding for the class (optional).
        input_image: The input image for guidance (optional).
        generator: The torch.Generator used for reproducible image generation.
        latents: The initial latent tensors used for image generation.
        num_imgs: The number of images to generate (default is 1).
        height: The height of the generated images (default is 512).
        width: The width of the generated images (default is 512).
        num_inference_steps: The number of inference steps for image generation (default is 25).
        guidance_scale: The scale factor for guidance (default is 7.5).

    Returns:
        PIL.Image.Image: The first generated image.
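
    Example (a sketch; embeds, input_image, generator and latents assumed
    prepared as in the functions above):
        img = genClassImg(
            pipe, embeds[0], None, input_image, generator, latents,
            num_inference_steps=25, guidance_scale=7.5,
        )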
    """

    if neg_embed is not None:
        npe = neg_embed[None, ...]
    else:
        npe = None

    return pipeline(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=pos_embed[None, ...],
        negative_prompt_embeds=npe,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    ).images[0]


def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """
    Generate metadata for the given parameters.

    Args:
        class_pairs (list): List of class pairs.
        path (str): Path to the data.
        seed (int): Seed value for randomization.
        guidance_scale (float): Scale factor for guidance.
        num_inference_steps (int): Number of inference steps.
        num_interpolation_steps (int): Number of interpolation steps.
        sample_mid_interpolation (int): Number of embeddings sampled from the middle of the interpolation.
        height (int): Height of the image.
        width (int): Width of the image.
        prompts (list): List of prompts.
        negative_prompts (list): List of negative prompts.
        pipeline (object): Pipeline object.
        prompt_metadata (dict): Metadata for prompts.
        negative_prompt_metadata (dict): Metadata for negative prompts.
        ssim_metadata (dict, optional): SSIM scores metadata. Defaults to None.
        save_json (bool, optional): Flag to save metadata as JSON. Defaults to True.
        save_path (str, optional): Path to save the JSON file. Defaults to ".".

    Returns:
        dict: Generated metadata.
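
    Example (illustrative values; assumes objects from the functions above):
        meta = getMetadata(
            ("cat", "dog"), "data/synth", 42, 7.5, 25, 16, 8, 512, 512,
            prompts, negative_prompts, pipe, prompt_meta, neg_prompt_meta,
        )
        # writes ./cat_dog_42.json when save_json=True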
    """

    metadata = dict()

    metadata["class_pairs"] = class_pairs
    metadata["path"] = path
    metadata["seed"] = seed
    metadata["params"] = {
        "CFG": guidance_scale,
        "inferenceSteps": num_inference_steps,
        "interpolationSteps": num_interpolation_steps,
        "sampleMidInterpolation": sample_mid_interpolation,
        "height": height,
        "width": width,
    }
    for i in range(len(prompts)):
        metadata[f"prompt_text_{i}"] = prompts[i]
        if negative_prompts is not None:
            metadata[f"negative_prompt_text_{i}"] = negative_prompts[i]
    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata
    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata
    if save_json:
        with open(
            os.path.join(save_path, f"{'_'.join(i for i in class_pairs)}_{seed}.json"),
            "w",
        ) as f:
            json.dump(metadata, f, indent=4)
    return metadata


def groupbyInterpolation(dir_to_classfolder):
    """
    Group files in a directory by interpolation step.

    Args:
        dir_to_classfolder (str): The path to the directory containing the files.

    Returns:
        None
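
    Example (illustrative; filenames follow "{seed}-{image_id}_{step}.{ext}"):
        groupbyInterpolation("data/synth/cat_dog")
        # moves 42-0_3.jpg into data/synth/cat_dog/3/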
    """
    files = [
        (f.split(sep="_")[1].split(sep=".")[0], os.path.join(dir_to_classfolder, f))
        for f in os.listdir(dir_to_classfolder)
    ]
    # create a subfolder for each step of the interpolation
    for interpolation_step, file_path in files:
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)
        os.rename(file_path, os.path.join(new_dir, os.path.basename(file_path)))


def ungroupInterpolation(dir_to_classfolder):
    """
    Moves all files from subdirectories within `dir_to_classfolder` to `dir_to_classfolder` itself,
    and then removes the subdirectories.

    Args:
        dir_to_classfolder (str): The path to the directory containing the subdirectories.

    Returns:
        None
    """
    for interpolation_step in os.listdir(dir_to_classfolder):
        if os.path.isdir(os.path.join(dir_to_classfolder, interpolation_step)):
            for f in os.listdir(os.path.join(dir_to_classfolder, interpolation_step)):
                os.rename(
                    os.path.join(dir_to_classfolder, interpolation_step, f),
                    os.path.join(dir_to_classfolder, f),
                )
            os.rmdir(os.path.join(dir_to_classfolder, interpolation_step))


def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """
    Group or ungroup all data classes by interpolation.

    Args:
        data_path (str): The path to the data.
        group (bool, optional): Whether to group the data. Defaults to True.
        fn_group (function, optional): The function to use for grouping. Defaults to groupbyInterpolation.
        fn_ungroup (function, optional): The function to use for ungrouping. Defaults to ungroupInterpolation.
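
    Example (illustrative path):
        groupAllbyInterpolation("data/synth", group=True)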
    """
    data_classes = sorted(os.listdir(data_path))
    if group:
        fn = fn_group
    else:
        fn = fn_ungroup
    for c in data_classes:
        c_path = os.path.join(data_path, c)
        if os.path.isdir(c_path):
            fn(c_path)
            print(f"Processed {c}")


def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """
    Generate shuffled groups of dataset indices for a given subset length.

    Args:
        subset_len (int): The length of the subset.
        total_pair_count (int, optional): The total number of pairs to generate. Defaults to 1.
        seed (int, optional): The seed value for the random number generator. Defaults to None.

    Returns:
        list: A list of `total_pair_count` index groups, each of length
              ceil(subset_len / total_pair_count); some indices may repeat as
              padding so the groups divide evenly.
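
    Example (illustrative; actual grouping depends on the seed):
        getPairIndices(5, total_pair_count=2, seed=0)
        # two groups of 3 indices covering 0-4, one index repeated as padding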

    """
    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    numbers_selection = list(range(subset_len))
    rng.shuffle(numbers)
    for i in range(group_size - subset_len % group_size):
        numbers.append(numbers_selection[i])
    numbers = np.array(numbers)
    groups = numbers[: group_size * total_pair_count].reshape(-1, group_size)
    return groups.tolist()


def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """
    Generates images from a dataset using the given parameters.

    Args:
        img_subsets (dict): A dictionary containing image subsets for each class.
        class_iterables (dict): A dictionary containing iterable objects for each class.
        pipeline (object): The pipeline object used for image generation.
        interpolated_prompt_embeds (list): A list of interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (list): A list of interpolated negative prompt embeddings.
        num_inference_steps (int): The number of inference steps for image generation.
        guidance_scale (float): The scale factor for guidance loss during image generation.
        height (int, optional): The height of the generated images. Defaults to 512.
        width (int, optional): The width of the generated images. Defaults to 512.
        seed (int, optional): The seed value for random number generation. Defaults to None.
        save_path (str, optional): The path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): A tuple containing pairs of class identifiers. Defaults to ("0", "1").
        save_image (bool, optional): Whether to save the generated images. Defaults to True.
        image_type (str, optional): The file format of the saved images. Defaults to "jpg".
        interpolate_range (str, optional): The range of interpolation for prompt embeddings.
            Possible values are "full", "nearest", or "furthest". Defaults to "full".
        device (str, optional): The device to use for image generation. Defaults to "cuda".
        return_images (bool, optional): Whether to return the generated images. Defaults to False.

    Returns:
        dict or tuple: If return_images is True, returns a tuple (class_images, class_ssim),
            where class_images maps each class to its generated images and class_ssim maps
            each class and interpolation step to SSIM statistics (sum, count, average).
            If return_images is False, returns class_ssim only.
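
    Example (a sketch; all inputs assumed prepared with the functions above):
        ssim_scores = generateImagesFromDataset(
            subsets, iterables, pipe, pos_embeds, neg_embeds,
            num_inference_steps=25, guidance_scale=7.5,
            save_path="data/synth", class_pairs=("cat", "dog"),
        )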
    """
    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False

    if seed is None:
        seed = torch.Generator().seed()
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    # Generating initial U-Net latent vectors from a random normal distribution.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
    embed_pairs_list = list(embed_pairs)
    if return_images:
        class_images = dict()
    class_ssim = dict()

    if nearest_half or furthest_half:
        if nearest_half:
            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
        else:
            # uses the opposite class of images for the text interpolation
            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
        multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1

    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        class_ssim[class_id] = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0} for i in range(embed_len)
        }
        subset_len = len(img_subsets[class_id])
        # group_map randomizes which interpolation step each image in the class uses
        # group_map: index is the image id, element is the interpolation step
        # steps_range[class_iter] restricts the steps available to this class,
        # e.g. range(0, 8) for the first class and range(8, 16) for the second when embed_len is 16
        # the list is tiled with multiplier so it covers the whole subset plus a remainder
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(
            group_map
        )  # shuffle so each image id is mapped to a random step within its class's range

        iter_indices = class_iterables[class_id].pop()
        # generate images for each image in the class, randomly selecting an interpolated step
        for image_id in iter_indices:
            img, _ = img_subsets[class_id][image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            pred_image = transforms.ToTensor()(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            class_ssim[class_id][interpolate_step]["ssim_sum"] += ssim_score
            class_ssim[class_id][interpolate_step]["ssim_count"] += 1
            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                if image_type == "jpg":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="JPEG",
                        quality=95,
                    )
                elif image_type == "png":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="PNG",
                    )
                else:
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
                    )

        # calculate ssim avg for the class
        for i_step in range(embed_len):
            if class_ssim[class_id][i_step]["ssim_count"] > 0:
                class_ssim[class_id][i_step]["ssim_avg"] = (
                    class_ssim[class_id][i_step]["ssim_sum"]
                    / class_ssim[class_id][i_step]["ssim_count"]
                )

    if return_images:
        return class_images, class_ssim
    else:
        return class_ssim


def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """
    Generate a trace dictionary containing information about the generated images.

    Args:
        prompts (list): List of prompt texts; index 0 is the positive prompt and index 1 is the negative prompt.
        img_subsets (dict): Dictionary containing image subsets for each class.
        class_iterables (dict): Dictionary containing iterable objects for each class.
        interpolated_prompt_embeds (torch.Tensor): Tensor containing interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (torch.Tensor): Tensor containing interpolated negative prompt embeddings.
        subset_indices (dict): Dictionary containing indices of subsets for each class.
        seed (int, optional): Seed value for random number generation. Defaults to None.
        save_path (str, optional): Path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): Tuple containing class pairs. Defaults to ("0", "1").
        image_type (str, optional): File extension used in the output file paths. Defaults to "jpg".
        interpolate_range (str, optional): Range of interpolation. Defaults to "full".
        save_prompt_embeds (bool, optional): Flag to save prompt embeddings. Defaults to False.

    Returns:
        dict: Trace dictionary containing information about the generated images.
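
    Example (a sketch; mirrors generateImagesFromDataset without running the pipeline):
        trace = generateTrace(
            [pos_prompt, neg_prompt], subsets, iterables,
            pos_embeds, neg_embeds, subset_indices,
            seed=42, save_path="data/synth", class_pairs=("cat", "dog"),
        )
        df = pd.DataFrame(trace)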
    """
    trace_dict = {
        "class_pairs": list(),
        "class_id": list(),
        "image_id": list(),
        "interpolation_step": list(),
        "embed_len": list(),
        "pos_prompt_text": list(),
        "neg_prompt_text": list(),
        "input_file_path": list(),
        "output_file_path": list(),
        "input_prompts_embed": list(),
    }

    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False

    if seed is None:
        seed = torch.Generator().seed()
    rng = np.random.default_rng(seed)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(
        interpolated_prompt_embeds.cpu().numpy(),
        interpolated_negative_prompts_embeds.cpu().numpy(),
    )
    embed_pairs_list = list(embed_pairs)

    if nearest_half or furthest_half:
        if nearest_half:
            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
        else:
            # uses the opposite class of images for the text interpolation
            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
        multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1

    for class_iter, class_id in enumerate(class_pairs):

        subset_len = len(img_subsets[class_id])
        # group_map randomizes which interpolation step each image in the class uses
        # group_map: index is the image id, element is the interpolation step
        # steps_range[class_iter] restricts the steps available to this class,
        # e.g. range(0, 8) for the first class and range(8, 16) for the second when embed_len is 16
        # the list is tiled with multiplier so it covers the whole subset plus a remainder
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(
            group_map
        )  # shuffle so each image id is mapped to a random step within its class's range

        iter_indices = class_iterables[class_id].pop()
        # generate images for each image in the class, randomly selecting an interpolated step
        for image_id in iter_indices:
            class_ds = img_subsets[class_id]
            interpolate_step = group_map[image_id]
            sample_count = subset_indices[class_id][0] + image_id
            input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0])
            pos_prompt = prompts[0]
            neg_prompt = prompts[1]
            output_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
            if save_prompt_embeds:
                input_prompts_embed = embed_pairs_list[interpolate_step]
            else:
                input_prompts_embed = None

            trace_dict["class_pairs"].append(class_pairs)
            trace_dict["class_id"].append(class_id)
            trace_dict["image_id"].append(image_id)
            trace_dict["interpolation_step"].append(interpolate_step)
            trace_dict["embed_len"].append(embed_len)
            trace_dict["pos_prompt_text"].append(pos_prompt)
            trace_dict["neg_prompt_text"].append(neg_prompt)
            trace_dict["input_file_path"].append(input_file)
            trace_dict["output_file_path"].append(output_file)
            trace_dict["input_prompts_embed"].append(input_prompts_embed)

    return trace_dict