# Author: arupchakraborty2004
# Commit c5c7526 (verified): "Update app.py"
import gradio as gr
from PIL import Image
import torch
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
    StableDiffusionUpscalePipeline,
)
# Load pre-trained models for CPU (float32 weights; inference will be slow).
print("Loading models...")

SD_MODEL_ID = "CompVis/stable-diffusion-v1-4"

# Base Stable Diffusion v1-4 pipeline. Its components (VAE, text encoder,
# tokenizer, UNet, scheduler, safety checker) are reused by the other SD v1-4
# pipelines so the same weights are only loaded into memory once.
colorization_pipeline = StableDiffusionPipeline.from_pretrained(SD_MODEL_ID, torch_dtype=torch.float32)
colorization_pipeline = colorization_pipeline.to("cpu")

# Img2Img pipeline for denoising, built from the already-loaded components
# instead of a second from_pretrained() call (which duplicated every weight).
# NOTE: the scheduler object is shared too, so the pipelines should not be run
# concurrently.
denoising_pipeline = StableDiffusionImg2ImgPipeline(**colorization_pipeline.components)

# The x4 upscaler checkpoint is designed for StableDiffusionUpscalePipeline;
# loading it with the text-to-image StableDiffusionPipeline class does not give
# a pipeline that can upscale an input image.
upscaling_pipeline = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float32
)
upscaling_pipeline = upscaling_pipeline.to("cpu")
# Define image processing functions
def colorize_image(grayscale_image, strength: float = 0.6):
    """Return a colorized version of ``grayscale_image`` using SD img2img.

    ``colorization_pipeline`` is a text-to-image pipeline: it does not take an
    input image, so calling it with ``image=`` produced a picture unrelated to
    the upload. An img2img pipeline is assembled from its already-loaded
    components (no weights are reloaded) so the upload actually conditions
    the generation.

    Args:
        grayscale_image: PIL image to colorize. It is converted to RGB and
            resized to 512x512 before generation.
        strength: img2img strength in [0, 1]. Lower values keep more of the
            input's structure; higher values let the model change more.

    Returns:
        The first image produced by the img2img pipeline.
    """
    img2img = StableDiffusionImg2ImgPipeline(**colorization_pipeline.components)
    grayscale_image = grayscale_image.convert("RGB").resize((512, 512))
    prompt = "A colorful and realistic version of this grayscale photo."
    result = img2img(prompt, image=grayscale_image, strength=strength)
    return result.images[0]
def denoise_image(noisy_image, strength: float = 0.7):
    """Clean up ``noisy_image`` by running it through ``denoising_pipeline``.

    The image is normalized to a 512x512 RGB picture and passed as the
    img2img input together with a fixed quality-oriented prompt.

    Args:
        noisy_image: PIL image to denoise.
        strength: forwarded unchanged as the pipeline's ``strength`` argument.

    Returns:
        The first image of the pipeline result.
    """
    prepared = noisy_image.convert("RGB").resize((512, 512))
    outputs = denoising_pipeline(
        "A high-quality version of this image.",
        image=prepared,
        strength=strength,
    )
    return outputs.images[0]
def inpaint_image(base_image, mask_image):
    """Fill the masked region of ``base_image`` with Stable Diffusion inpainting.

    The original code sent ``mask_image`` to the text-to-image
    ``colorization_pipeline``, which does not inpaint, so the mask had no
    effect. A dedicated inpainting pipeline is now built from that pipeline's
    already-loaded components (no extra weights are loaded).

    Args:
        base_image: PIL image to repair; converted to RGB and resized to 512x512.
        mask_image: PIL mask image, resized to 512x512. Per the diffusers
            inpainting docs, white pixels are regenerated and black pixels are
            kept.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        ValueError: If ``mask_image`` is None (no mask uploaded).
    """
    if mask_image is None:
        raise ValueError("Mask image is required for inpainting.")
    inpainter = StableDiffusionInpaintPipeline(**colorization_pipeline.components)
    base_image = base_image.convert("RGB").resize((512, 512))
    # Single-channel mask; the pipeline treats it as grayscale anyway.
    mask_image = mask_image.convert("L").resize((512, 512))
    prompt = "A completed and visually realistic image with the missing parts filled in."
    result = inpainter(prompt, image=base_image, mask_image=mask_image)
    return result.images[0]
def upscale_image(low_res_image):
    """Upscale ``low_res_image`` with ``upscaling_pipeline``.

    The input is first forced to a 256x256 RGB image, because the upscaler is
    meant to receive small inputs, and then passed along with a fixed prompt.

    Returns:
        The first image of the pipeline result.
    """
    prompt = "An upscaled version of this image."
    small = low_res_image.convert("RGB").resize((256, 256))  # Upscaler expects smaller inputs
    output = upscaling_pipeline(prompt, image=small)
    return output.images[0]
# Main function to process based on the selected task
def process_image(task, image, mask=None):
    """Dispatch the UI request to the restoration function for ``task``.

    Args:
        task: One of "Colorize", "Denoise", "Inpaint", "Upscale".
        image: The uploaded PIL image, or None when nothing was uploaded.
        mask: Optional PIL mask; only used by "Inpaint".

    Returns:
        Whatever the selected restoration function returns.

    Raises:
        ValueError: If ``task`` is unknown or no input image was provided.
    """
    if task not in ("Colorize", "Denoise", "Inpaint", "Upscale"):
        raise ValueError("Invalid task selected.")
    # An empty image input arrives as None; fail clearly here instead of with an
    # AttributeError on None.convert() inside the helpers.
    if image is None:
        raise ValueError("Please upload an image to process.")
    if task == "Colorize":
        return colorize_image(image)
    if task == "Denoise":
        return denoise_image(image)
    if task == "Inpaint":
        return inpaint_image(image, mask)
    return upscale_image(image)
# Gradio Interface: build the input widgets first, then wire them into the app.
_TASK_CHOICES = ["Colorize", "Denoise", "Inpaint", "Upscale"]

_DESCRIPTION = (
    "Select an image restoration task from the dropdown menu:\n\n"
    "- **Colorize**: Convert grayscale images to realistic color.\n"
    "- **Denoise**: Enhance image quality by removing noise.\n"
    "- **Inpaint**: Fill missing parts of an image using a mask.\n"
    "- **Upscale**: Increase image resolution while preserving details."
)

_task_selector = gr.Dropdown(_TASK_CHOICES, label="Select Task", value="Colorize")
_image_input = gr.Image(type="pil", label="Upload Image")
# Starts empty; a mask is only needed for the Inpaint task.
_mask_input = gr.Image(type="pil", label="Upload Mask (for Inpainting)", value=None)

interface = gr.Interface(
    fn=process_image,
    inputs=[_task_selector, _image_input, _mask_input],
    outputs=gr.Image(label="Processed Image"),
    title="AI-Powered Image Restoration",
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    interface.launch(share=True)