import glob
from os import path
from paths import get_file_name, FastStableDiffusionPaths
from pathlib import Path


class _lora_info:
    """Bookkeeping record for one LoRA loaded into the current pipeline.

    The diffusers function ``get_active_adapters()`` returns a list of
    adapter names but not their weights, so this module keeps its own
    record of each loaded LoRA's weight, to be re-applied whenever a new
    LoRA is loaded.

    Attributes:
        path: path of the LoRA weights file.
        adapter_name: adapter name derived from *path* via get_file_name().
        weight: current weight of this LoRA adapter.
    """

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        self.adapter_name = get_file_name(path)
        self.weight = weight

    # NOTE: the former __del__ only reset attributes to None on an object
    # that was already being destroyed; it had no observable effect and
    # has been removed.

    def __repr__(self):
        return (
            f"{type(self).__name__}(path={self.path!r}, weight={self.weight!r})"
        )


# Module-level tracking state:
# - _current_pipeline: the pipeline the tracked LoRAs were loaded into
#   (None until load_lora_weight() is first called).
# - _loaded_loras: _lora_info records for that pipeline, in load order.
_current_pipeline = None
_loaded_loras: list = list()


def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """Load the LoRA at ``lcm_diffusion_setting.lora.path`` into *pipeline*.

    Multiple LoRAs can be loaded by calling this function more than once
    with a different ``lora.path`` setting. If you plan to load several
    LoRAs and change their weights dynamically, you may want to set
    ``lora.fuse`` to False (when true, ``pipeline.fuse_lora()`` is called
    after loading).

    When *pipeline* differs from the one seen on the previous call, all
    previously tracked LoRAs are forgotten. The LoRA is loaded and tracked
    only when ``lora.enabled`` is true.

    Raises:
        Exception: if the LoRA path is empty or does not exist.
    """
    lora_path = lcm_diffusion_setting.lora.path
    if not lora_path:
        raise Exception("Empty lora model path")

    if not path.exists(lora_path):
        raise Exception("Lora model path is invalid")

    global _loaded_loras
    global _current_pipeline
    # A rebuilt pipeline no longer holds the adapters tracked so far, so
    # start over. Identity (not equality) is what matters here.
    if pipeline is not _current_pipeline:
        _loaded_loras = []
        _current_pipeline = pipeline

    if not lcm_diffusion_setting.lora.enabled:
        return

    current_lora = _lora_info(
        lora_path,
        lcm_diffusion_setting.lora.weight,
    )
    print(f"LoRA adapter name : {current_lora.adapter_name}")
    # Load from the file's own directory: LoRA files may live in
    # subdirectories of the models folder (they are discovered recursively),
    # and using the models root would miss them or pick up a different file
    # with the same name.
    lora_file = Path(lora_path)
    pipeline.load_lora_weights(
        str(lora_file.parent),
        weight_name=lora_file.name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Track the adapter only once it really exists in the pipeline:
    # update_lora_weights() hands every tracked name to set_adapters(), so
    # a failed or skipped load must not leave a phantom entry behind.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    if lcm_diffusion_setting.lora.fuse:
        pipeline.fuse_lora()


def get_lora_models(root_dir: str):
    """Find every ``.safetensors`` file under *root_dir*, recursively.

    Returns:
        dict mapping the name returned by get_file_name() for each file to
        that file's path. Files for which get_file_name() returns None are
        skipped. If two files map to the same name, the one whose path sorts
        last wins.
    """
    # Escape root_dir so characters such as "[", "*" or "?" in the directory
    # path are matched literally rather than treated as glob wildcards.
    pattern = f"{glob.escape(str(root_dir))}/**/*.safetensors"
    lora_models_map = {}
    # glob returns files in filesystem-dependent order; sort for a stable
    # listing and a deterministic winner among duplicate names.
    for file_path in sorted(glob.glob(pattern, recursive=True)):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map


def get_active_lora_weights():
    """Return the tracked LoRAs as a list of ``(adapter_name, weight)`` tuples.

    Entries appear in the order the LoRAs were loaded, which is the order
    update_lora_weights() expects for its ``lora_weights`` argument.
    """
    return [(info.adapter_name, info.weight) for info in _loaded_loras]


def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """Apply the tracked LoRA weights, optionally updated, to *pipeline*.

    Args:
        pipeline: must be the pipeline most recently passed to
            load_lora_weight(); otherwise a warning is printed and nothing
            is changed.
        lcm_diffusion_setting: settings object; when ``use_lcm_lora`` is
            true, the "lcm" adapter is also activated with weight 1.0.
        lora_weights: optional iterable of ``(adapter_name, weight)`` pairs
            in the same order as get_active_lora_weights(). A pair whose name
            does not match the tracked LoRA at the same position is skipped
            with a warning; pairs beyond the number of tracked LoRAs are
            ignored with a warning.
    """
    # Compare by identity: a rebuilt pipeline does not hold our adapters.
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return

    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            if idx >= len(_loaded_loras):
                # Previously this raised IndexError.
                print("More LoRA weights than loaded LoRAs, ignoring the rest!")
                break
            if _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")