from huggingface_hub import hf_hub_download

# Fetch the InstantIR checkpoints from the Hub before anything tries to load
# them. Every filename carries a "models/" prefix and is downloaded relative
# to the current working directory.
for _checkpoint in (
    "models/adapter.pt",
    "models/aggregator.pt",
    "models/previewer_lora_weights.bin",
):
    hf_hub_download(repo_id="InstantX/InstantIR", filename=_checkpoint, local_dir=".")
import torch
from PIL import Image

from diffusers import DDPMScheduler
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler

from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline

# Directory that holds the InstantIR checkpoints read below.
instantir_path = './models'

# Base model repo, used for both the pipeline weights and its scheduler config.
SDXL_REPO_ID = 'stabilityai/stable-diffusion-xl-base-1.0'

# Build the restoration pipeline on top of the SDXL base weights in half precision.
pipe = InstantIRPipeline.from_pretrained(SDXL_REPO_ID, torch_dtype=torch.float16)

# Attach the adapter checkpoint, with DINOv2-large as the image encoder.
load_adapter_to_pipe(
    pipe,
    f"{instantir_path}/adapter.pt",
    image_encoder_or_path='facebook/dinov2-large',
)

# NOTE(review): presumably picks up previewer_lora_weights.bin from this
# directory — confirm against InstantIRPipeline.prepare_previewers.
pipe.prepare_previewers(instantir_path)

# Main scheduler comes from the SDXL repo; the single-step LCM scheduler for
# the previewer is derived from the same config.
pipe.scheduler = DDPMScheduler.from_pretrained(SDXL_REPO_ID, subfolder="scheduler")
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoint was saved from; the weights are moved to the GPU just below.
# NOTE: torch.load unpickles the file. The checkpoint comes from the
# InstantX/InstantIR repo; consider weights_only=True if it holds only tensors.
pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt", map_location="cpu")
pipe.aggregator.load_state_dict(pretrained_state_dict)

# Move everything to the GPU in fp16. The aggregator is moved explicitly as
# well, in case pipe.to() does not cover it.
pipe.to(device='cuda', dtype=torch.float16)
pipe.aggregator.to(device='cuda', dtype=torch.float16)
def infer(input_image):
    """Restore a low-quality image with the InstantIR pipeline.

    Args:
        input_image: Filesystem path of the image to restore.

    Returns:
        The first image of the pipeline output.

    Raises:
        ValueError: If no image path was supplied.
    """
    # An empty upload reaches us as None (or an empty path); fail with a
    # readable message instead of an opaque error from inside PIL.
    if not input_image:
        raise ValueError("Please provide a low-quality image to restore.")

    # convert() returns an independent, fully loaded copy, so the underlying
    # file can be closed as soon as the block exits.
    with Image.open(input_image) as source:
        low_quality_image = source.convert("RGB")

    image = pipe(
        image=low_quality_image,
        previewer_scheduler=lcm_scheduler,
    ).images[0]

    return image
import gradio as gr

# Minimal UI: an image input with its trigger button, plus a second image
# component for the restored result.
# NOTE(review): the container nesting below is the conventional reading of
# this layout — confirm against the deployed app.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                # type="filepath" makes the component hand infer() a path.
                lq_img = gr.Image(label="Low-quality image", type="filepath")
                submit_btn = gr.Button("InstantIR magic!")
            output_img = gr.Image(label="InstantIR restored")

    # Wire the button to the restoration function.
    submit_btn.click(fn=infer, inputs=[lq_img], outputs=[output_img])

# show_error=True surfaces exceptions raised while handling a request.
demo.launch(show_error=True)