# inference-pipeline/scripts/download-weights.py
# Sebastian Semeniuc
# feat: add models locally from_pretrained
# 8aaec62
# Run this script before deploying to Replicate. Otherwise, the weights are
# downloaded from the internet every time the model runs, which takes a long
# time.
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
# better_vae = AutoencoderKL.from_pretrained(
# "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
# )
# pipe = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-1.0",
# vae=better_vae,
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# )
# pipe.save_pretrained("./sdxl-cache", safe_serialization=True)
# Hugging Face Hub id of the SDXL Canny-edge ControlNet, and the local
# directory its weights are written to.
_CANNY_REPO_ID = "diffusers/controlnet-canny-sdxl-1.0"
_CANNY_CACHE_DIR = "./cn-canny-edge-cache"

# Load the ControlNet from its safetensors checkpoint, cast to half precision,
# then persist it locally (again as safetensors) so it can later be loaded from
# disk instead of being fetched from the network.
controlnet = ControlNetModel.from_pretrained(
    _CANNY_REPO_ID,
    use_safetensors=True,
    torch_dtype=torch.float16,
)
controlnet.save_pretrained(
    _CANNY_CACHE_DIR,
    safe_serialization=True,
)
# pipe = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-refiner-1.0",
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# )
# # TODO - we don't need to save all of this and in fact should save just the unet, tokenizer, and config.
# pipe.save_pretrained("./refiner-cache", safe_serialization=True)
# safety = StableDiffusionSafetyChecker.from_pretrained(
# "CompVis/stable-diffusion-safety-checker",
# torch_dtype=torch.float16,
# )
# safety.save_pretrained("./safety-cache")