|
import csv
import io
import os
import tempfile
import time
import zipfile

import gradio as gr
import numpy as np
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
from transformers import pipeline
|
|
|
|
|
# Run the models on the GPU when CUDA is usable, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
def load_model(scale):
    """Create a RealESRGAN model for *scale* and load its weights.

    The weights are expected at ``weights/RealESRGAN_x{scale}.pth``.  The
    first attempt allows downloading them; if that attempt raises, the error
    is printed and loading is retried once without downloading.  An error
    from the retry propagates to the caller.

    Args:
        scale: Upscaling factor; forwarded to ``RealESRGAN`` and used to
            build the weights file name.

    Returns:
        The ``RealESRGAN`` instance with weights loaded.
    """
    upscaler = RealESRGAN(device, scale=scale)
    checkpoint = f'weights/RealESRGAN_x{scale}.pth'
    try:
        upscaler.load_weights(checkpoint, download=True)
        print(f"Weights for scale {scale} loaded successfully.")
    except Exception as err:
        print(f"Error loading weights for scale {scale}: {err}")
        # Fall back to whatever is already on disk.
        upscaler.load_weights(checkpoint, download=False)
    return upscaler
|
|
|
|
|
# Build one upscaler per supported factor when the module is loaded; each
# load_model() call may download weights on first use.
model2, model4, model8 = (load_model(factor) for factor in (2, 4, 8))

# Image captioning pipeline used by generate_description().
description_generator = pipeline(
    task="image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
)
|
|
|
|
|
def enhance_image(image, scale):
    """Upscale *image* with the model selected by *scale*.

    ``'2x'`` selects ``model2`` and ``'4x'`` selects ``model4``; any other
    value falls through to ``model8``.  This is a best-effort step: if
    anything raises, the error is printed and the input image is returned
    unchanged.

    Args:
        image: Object exposing ``convert('RGB')`` (a PIL image in practice).
        scale: Scale label such as ``'2x'``, ``'4x'`` or ``'8x'``.

    Returns:
        A new PIL image built from the model output, or *image* itself when
        enhancement failed.
    """
    try:
        pixels = np.array(image.convert('RGB'))
        if scale == '2x':
            upscaler = model2
        elif scale == '4x':
            upscaler = model4
        else:
            upscaler = model8
        return Image.fromarray(np.uint8(upscaler.predict(pixels)))
    except Exception as err:
        print(f"Error enhancing image: {err}")
        return image
|
|
|
|
|
def generate_description(image):
    """Return a caption for *image* produced by ``description_generator``.

    The caption is the ``'generated_text'`` field of the first result.  If
    the call or the lookup raises, the error is printed and the fixed text
    ``"Description unavailable."`` is returned instead.
    """
    try:
        outputs = description_generator(image)
        return outputs[0]['generated_text']
    except Exception as err:
        print(f"Error generating description: {err}")
    return "Description unavailable."
|
|
|
|
|
def muda_dpi(input_image, dpi):
    """Re-encode *input_image* as a JPEG tagged with the given DPI.

    ("muda DPI" is Portuguese for "change DPI".)  The encoding happens in an
    in-memory buffer, so no temporary file is left on disk and the returned
    image does not depend on any open file.

    Args:
        input_image: Pixel array; it is cast to ``uint8`` and interpreted as
            RGB.
        dpi: Density applied to both axes.  Rounded to the nearest integer,
            since UI number fields may deliver floats.

    Returns:
        The JPEG-decoded PIL image, fully loaded, carrying the DPI metadata.
    """
    dpi_value = int(round(dpi))
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')

    buffer = io.BytesIO()
    image.save(buffer, format='JPEG', dpi=(dpi_value, dpi_value))
    buffer.seek(0)

    result = Image.open(buffer)
    # Image.open is lazy; decode now so the result is self-contained.
    result.load()
    return result
|
|
|
|
|
def resize_image(input_image, width, height):
    """Resize *input_image* to ``width`` x ``height`` and re-encode as JPEG.

    The encoding happens in an in-memory buffer, so no temporary file is
    left on disk and the returned image does not depend on any open file.

    Args:
        input_image: Pixel array; it is cast to ``uint8`` and interpreted as
            RGB.
        width: Target width in pixels.  Rounded to the nearest integer,
            since UI number fields may deliver floats.
        height: Target height in pixels, handled like *width*.

    Returns:
        The JPEG-decoded PIL image, fully loaded.

    Raises:
        ValueError: If the rounded width or height is not positive.
    """
    # Validate before doing any image work so bad sizes fail fast and clearly.
    target_size = (int(round(width)), int(round(height)))
    if min(target_size) <= 0:
        raise ValueError(
            f"width and height must be positive, got {width} x {height}"
        )

    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    resized_image = image.resize(target_size)

    buffer = io.BytesIO()
    resized_image.save(buffer, format='JPEG')
    buffer.seek(0)

    result = Image.open(buffer)
    # Image.open is lazy; decode now so the result is self-contained.
    result.load()
    return result
|
|
|
|
|
def _build_keywords(description, max_length=45):
    """Join the unique words of *description* with ', ' within *max_length*.

    Words keep their first-seen order, and a word that would not fit is
    dropped whole instead of being cut in half.
    """
    kept = []
    used = 0
    for word in dict.fromkeys(description.split()):
        needed = len(word) + (2 if kept else 0)
        if used + needed > max_length:
            break
        kept.append(word)
        used += needed
    return ", ".join(kept)


def _safe_stem(path, taken):
    """Return a sanitised, not-yet-used base name (no extension) for *path*.

    *taken* is a set of lower-cased names already handed out; the chosen
    name is added to it.
    """
    stem, _ = os.path.splitext(os.path.basename(path))
    stem = ''.join(ch for ch in stem if ch.isalnum() or ch in (' ', '_', '-'))
    stem = stem.strip().replace(' ', '_') or "image"

    candidate = stem
    counter = 1
    while candidate.lower() in taken:
        counter += 1
        candidate = f"{stem}_{counter}"
    taken.add(candidate.lower())
    return candidate


def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
    """Process uploaded images and bundle the results into a ZIP archive.

    Each image is converted to RGB, optionally enhanced, re-tagged with a
    DPI and resized (in that order), captioned, and saved as a JPEG.  A CSV
    with one ``Filename, Title, Keywords`` row per image is written next to
    the JPEGs, and everything is zipped.  All files live in a directory
    created for this call, so simultaneous calls cannot overwrite each
    other's output.

    Args:
        image_files: Uploaded files, each either a path or an object with a
            ``name`` attribute holding the path.  ``None`` or an empty list
            yields an archive containing only the CSV header.
        enhance: Whether to run ``enhance_image``.
        scale: Scale label forwarded to ``enhance_image``.
        adjust_dpi: Whether to apply *dpi*.
        dpi: Density used when *adjust_dpi* is true.
        resize: Whether to run ``resize_image``.
        width: Target width forwarded to ``resize_image``.
        height: Target height forwarded to ``resize_image``.

    Returns:
        A tuple ``(processed_images, zip_file_path, descriptions)``: the
        list of final PIL images, the path of the ZIP archive, and the list
        of captions in upload order.
    """
    processed_images = []
    file_paths = []
    descriptions = []
    used_names = set()

    work_dir = tempfile.mkdtemp(prefix="processed_images_")
    csv_file_path = os.path.join(work_dir, "image_descriptions.csv")

    with open(csv_file_path, mode="w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["Filename", "Title", "Keywords"])

        for image_file in image_files or []:
            # Uploads may be plain paths or wrappers exposing the path as .name.
            source_path = getattr(image_file, "name", image_file)

            with Image.open(image_file) as opened:
                pixels = np.array(opened.convert('RGB'))
            # Rebuilding from the raw pixels discards the source metadata.
            original_image = Image.fromarray(pixels)

            if enhance:
                original_image = enhance_image(original_image, scale)
            if adjust_dpi:
                original_image = muda_dpi(np.array(original_image), dpi)
            if resize:
                original_image = resize_image(np.array(original_image), width, height)

            description = generate_description(original_image)
            title = description
            keywords = _build_keywords(description)

            file_name = _safe_stem(source_path, used_names)
            output_path = os.path.join(work_dir, f"{file_name}.jpg")

            save_options = {"format": "JPEG"}
            if adjust_dpi:
                # The DPI is only written when passed to save(); without
                # this the requested density never reaches the output file.
                dpi_value = int(round(dpi))
                save_options["dpi"] = (dpi_value, dpi_value)
            original_image.save(output_path, **save_options)

            writer.writerow([file_name, title, keywords])

            processed_images.append(original_image)
            file_paths.append(output_path)
            descriptions.append(description)

    # The CSV is closed (and therefore flushed) before it is archived.
    zip_file_path = os.path.join(work_dir, "processed_images.zip")
    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
        for file_path in file_paths:
            zipf.write(file_path, arcname=os.path.basename(file_path))
        zipf.write(csv_file_path, arcname="image_descriptions.csv")

    # Everything is inside the archive now and the returned images are held
    # in memory, so remove the loose intermediate files.
    for leftover in [*file_paths, csv_file_path]:
        os.remove(leftover)

    return processed_images, zip_file_path, descriptions
|
|
|
|
|
# Web UI: the inputs map positionally onto the parameters of process_images
# and the outputs onto its three return values.
iface = gr.Interface(
    fn=process_images,
    inputs=[
        gr.Files(label="Upload Image Files"),
        gr.Checkbox(label="Enhance Images (ESRGAN)"),
        gr.Radio(['2x', '4x', '8x'], type="value", value='2x', label='Resolution model'),
        gr.Checkbox(label="Adjust DPI"),
        # precision=0 asks Gradio for integers: DPI and pixel sizes are
        # whole numbers, and fractional values would otherwise be passed on.
        gr.Number(label="DPI", value=300, precision=0),
        gr.Checkbox(label="Resize"),
        gr.Number(label="Width", value=512, precision=0),
        gr.Number(label="Height", value=512, precision=0),
    ],
    outputs=[
        gr.Gallery(label="Final Images"),
        gr.File(label="Download ZIP of Images and Descriptions"),
        gr.Textbox(label="Image Descriptions", lines=5),
    ],
    title="Multi-Image Enhancer with Hugging Face Descriptions",
    description="Upload multiple images, enhance, adjust DPI, resize, generate descriptions, and download the results and a ZIP archive.",
)

# share=True requests a public share link; debug=True is passed through to
# Gradio unchanged.
iface.launch(debug=True, share=True)
|
|