from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    PNDMScheduler,
    UniPCMultistepScheduler,
)

SCHEDULER_MAPPING = {
    "DDIM": DDIMScheduler,
    "DDPMScheduler": DDPMScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
    "EulerAncestralDiscrete": EulerAncestralDiscreteScheduler,
    "EulerDiscrete": EulerDiscreteScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KDPM2AncestralDiscrete": KDPM2AncestralDiscreteScheduler,
    "KDPM2Discrete": KDPM2DiscreteScheduler,
    "PNDMScheduler": PNDMScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
}


def get_scheduler(pipe, scheduler):
    if scheduler in SCHEDULER_MAPPING:
        SchedulerClass = SCHEDULER_MAPPING[scheduler]
        pipe.scheduler = SchedulerClass.from_config(pipe.scheduler.config)
    else:
        raise ValueError(f"Invalid scheduler name {scheduler}")

    return pipe