import logging
from typing import Any, Optional

import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)


def load_pipeline(
    model_name: str,
    device: torch.device,
    hf_token: Optional[str] = None,
    vae: Optional[AutoencoderKL] = None,
) -> Any:
    """Load the Stable Diffusion XL pipeline."""
    try:
        # Single-file checkpoints (.safetensors) and Hub repos use different loaders.
        loader = (
            StableDiffusionXLPipeline.from_single_file
            if model_name.endswith(".safetensors")
            else StableDiffusionXLPipeline.from_pretrained
        )
        pipe = loader(
            model_name,
            vae=vae,
            torch_dtype=torch.float16,
            custom_pipeline="lpw_stable_diffusion_xl",
            use_safetensors=True,
            add_watermarker=False,
        )
        pipe.to(device)
        return pipe
    except Exception as e:
        logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
        raise
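A minimal usage sketch of how this loader might be called, assuming a GPU is available; the model id and the madebyollin/sdxl-vae-fp16-fix VAE are illustrative placeholders, not values taken from this Space.

# Usage sketch (assumptions: CUDA is available; the model id below is a
# stand-in -- substitute the checkpoint this Space actually uses).
from diffusers import AutoencoderKL
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A fp16-safe VAE is commonly paired with SDXL pipelines (illustrative choice).
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)

pipe = load_pipeline(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model id
    device,
    vae=vae,
)

image = pipe(
    prompt="a watercolor painting of a lighthouse at dawn",
    num_inference_steps=28,
).images[0]
image.save("output.png")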