# Run this script before deploying the model on Replicate; otherwise the
# model will download its weights from the internet on every run, which
# takes a long time.

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

# better_vae = AutoencoderKL.from_pretrained(
#     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
# )

# pipe = DiffusionPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-base-1.0",
#     vae=better_vae,
#     torch_dtype=torch.float16,
#     use_safetensors=True,
#     variant="fp16",
# )

# pipe.save_pretrained("./sdxl-cache", safe_serialization=True)

# Pre-download the SDXL canny-edge ControlNet weights and serialize them
# to a local cache directory so deployed runs can load from disk.
CN_MODEL_ID = "diffusers/controlnet-canny-sdxl-1.0"
CN_CACHE_DIR = "./cn-canny-edge-cache"

controlnet = ControlNetModel.from_pretrained(
    CN_MODEL_ID,
    torch_dtype=torch.float16,  # half precision to shrink the cached weights
    use_safetensors=True,
)
controlnet.save_pretrained(CN_CACHE_DIR, safe_serialization=True)

# pipe = DiffusionPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-refiner-1.0",
#     torch_dtype=torch.float16,
#     use_safetensors=True,
#     variant="fp16",
# )

# # TODO - we don't need to save all of this and in fact should save just the unet, tokenizer, and config.
# pipe.save_pretrained("./refiner-cache", safe_serialization=True)


# safety = StableDiffusionSafetyChecker.from_pretrained(
#     "CompVis/stable-diffusion-safety-checker",
#     torch_dtype=torch.float16,
# )

# safety.save_pretrained("./safety-cache")