from flask import Flask, request, render_template_string
from PIL import Image
import numpy as np
import io
import base64
import torch
import torchvision.transforms as transforms
import os
import sys
# Flask application object; the upload route is registered on it further down.
app = Flask(__name__)
# Jinja template rendered by index() via render_template_string(). It reads the
# variables status, error, original_image and result_image; as written, the two
# image variables are only used in a truthiness test.
# NOTE(review): the markup visible here contains no <form> (so no upload control
# that would POST an "image" field) and no <img> tags (so the base64 images are
# never displayed). The HTML may have been stripped — confirm against the
# deployed template.
HTML_TEMPLATE = """
🌟 ZeroIG: Zero-Shot Low-Light Enhancement
CVPR 2024 - Upload a low-light image for professional enhancement!
{% if status %}
{{ status }}
{% endif %}
{% if error %}
{{ error }}
{% endif %}
{% if original_image and result_image %}
Results:
Original (Low-light)
{% endif %}
About ZeroIG
Zero-shot illumination-guided joint denoising and adaptive enhancement for low-light images.
Features: No training data required • Joint denoising & enhancement • Prevents over-enhancement
📄 Research Paper |
💻 Source Code
"""
class ZeroIGProcessor:
    """Load a ZeroIG model once and use it to enhance PIL images.

    If the ZeroIG modules cannot be imported or the model fails to build,
    ``self.model`` is ``None`` and ``enhance_image`` falls back to
    :meth:`simple_enhance`.
    """

    # Longest side, in pixels, of the image fed to the network. Larger inputs
    # are downscaled before inference to bound memory use; the result is
    # resized back to the caller's size afterwards.
    max_size = 800

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Which ZeroIG class load_model() built: "finetune", "network", or None
        # if loading failed. Recorded explicitly so inference does not have to
        # guess the model's output format from its attributes.
        self.model_kind = None
        self.model = self.load_model()
        print(f"ZeroIG initialized on {self.device}")

    def load_model(self):
        """Import and build the ZeroIG model on ``self.device`` in eval mode.

        Uses ``Finetunemodel`` with the weights at ``./weights/model.pt`` when
        that file exists, otherwise a randomly initialised ``Network``. On
        success, also sets ``self.model_kind``.

        Returns:
            The model, or ``None`` if importing or constructing it failed (the
            error is printed, not raised).
        """
        try:
            # Import your uploaded ZeroIG files (the local ``model`` module).
            from model import Finetunemodel, Network

            model_path = "./weights/model.pt"
            if os.path.exists(model_path):
                print(f"Found model weights at {model_path}")
                model = Finetunemodel(model_path)
                kind = "finetune"
                print("✅ Loaded ZeroIG Finetunemodel with trained weights")
            else:
                print("No trained weights found, using Network with random initialization")
                model = Network()
                kind = "network"
                print("⚠️ Using ZeroIG Network with random weights")
            model.to(self.device)
            model.eval()
            print(f"Model moved to {self.device}")
            self.model_kind = kind
            return model
        except ImportError as e:
            print(f"❌ Could not import ZeroIG modules: {e}")
            print("Make sure you have uploaded: model.py, loss.py, utils.py")
            return None
        except Exception as e:
            print(f"❌ Could not load ZeroIG model: {e}")
            return None

    def enhance_image(self, image, max_size=None):
        """Enhance *image* with the ZeroIG model.

        Args:
            image: PIL image. When a model is loaded, non-RGB input is
                converted to RGB before inference.
            max_size: longest side (pixels) passed to the network; defaults to
                the class attribute ``max_size``. Larger images are downscaled
                for inference.

        Returns:
            ``(enhanced_image, status_message)``. A successful result always
            has the same size as the input. If no model is loaded, or inference
            raises, :meth:`simple_enhance` is applied to the full-resolution
            input instead and the status message says so.
        """
        if self.model is None:
            return self.simple_enhance(image), "❌ ZeroIG model not available - using simple enhancement"

        # Convert so ToTensor() yields a 3-channel (1, 3, H, W) batch (ZeroIG
        # presumably expects RGB — TODO confirm). The fallback below also uses
        # this full-size copy rather than the downscaled one.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        original_size = rgb_image.size
        if max_size is None:
            max_size = self.max_size

        try:
            model_input = self._downscale(rgb_image, max_size)
            input_tensor = transforms.ToTensor()(model_input).unsqueeze(0).to(self.device)
            print(f"Input tensor shape: {input_tensor.shape}")

            result_tensor, status = self._run_model(input_tensor)

            # Convert back to PIL
            result_tensor = result_tensor.squeeze(0).cpu().clamp(0, 1)
            enhanced_image = transforms.ToPILImage()(result_tensor)
            print(f"Output image size: {enhanced_image.size}")

            # Hand back an image the same size as the input, whether the size
            # difference came from our downscale or from the network itself.
            if enhanced_image.size != original_size:
                enhanced_image = enhanced_image.resize(original_size, Image.Resampling.LANCZOS)
                print(f"Resized back to original size: {enhanced_image.size}")
            return enhanced_image, status
        except Exception as e:
            print(f"ZeroIG enhancement error: {e}")
            import traceback
            traceback.print_exc()
            return self.simple_enhance(rgb_image), f"⚠️ ZeroIG failed, using simple enhancement: {str(e)}"

    @staticmethod
    def _downscale(image, max_size):
        """Return *image* shrunk (aspect preserved) so its longest side is <= *max_size*."""
        longest = max(image.size)
        if longest <= max_size:
            return image
        ratio = max_size / longest
        # max(1, ...) keeps very thin images from collapsing to a 0-pixel side.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        resized = image.resize(new_size, Image.Resampling.LANCZOS)
        print(f"Resized image from {image.size} to {resized.size}")
        return resized

    def _run_model(self, input_tensor):
        """Run the model on *input_tensor* and pick its final image output.

        Returns:
            ``(result_tensor, status_message)``.
        """
        kind = getattr(self, 'model_kind', None)
        if kind is None:
            # Model was set without going through load_model(); fall back to
            # the old attribute-based detection.
            looks_finetuned = hasattr(self.model, 'enhance') and hasattr(self.model, 'denoise_1')
            kind = "finetune" if looks_finetuned else "network"

        with torch.no_grad():
            outputs = self.model(input_tensor)

        if kind == "finetune":
            # Finetunemodel returns (enhanced, denoised); use the denoised output.
            _enhanced, result_tensor = outputs
            print("Used Finetunemodel")
            return result_tensor, "✅ Enhanced with ZeroIG Finetunemodel"

        # Network returns many intermediates. Index 13 is taken as the final
        # denoised result (H3, per the original comment) — confirm against model.py.
        print("Used Network model")
        return outputs[13], "✅ Enhanced with ZeroIG Network model"

    def simple_enhance(self, image, gain=1.8):
        """Fallback enhancement: multiply every pixel value by *gain*, clipped to 0-255.

        Returns a new PIL image; the input image is not modified.
        """
        arr = np.array(image).astype(np.float32)
        enhanced = np.clip(arr * gain, 0, 255).astype(np.uint8)
        return Image.fromarray(enhanced)
# Initialize ZeroIG processor
# Built once at import time, so the model is loaded before the first request.
# index() reuses this single module-level instance for every upload.
# NOTE(review): if the server handles requests concurrently, they all share one
# model instance — confirm that is safe for the deployment.
print("🚀 Loading ZeroIG processor...")
zeroig = ZeroIGProcessor()
def image_to_base64(image):
    """Encode *image* as PNG and return those bytes as a base64 ``str``."""
    with io.BytesIO() as png_bytes:
        image.save(png_bytes, format='PNG')
        encoded = base64.b64encode(png_bytes.getvalue())
    return encoded.decode()
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload page; on POST, enhance the uploaded image.

    The upload is read from the ``image`` form field. The original and enhanced
    images are passed to the template as base64-encoded PNG strings, together
    with the enhancement status message or an error message.
    """
    original_image = None
    result_image = None
    status = None
    error = None
    if request.method == 'POST':
        # .get() rather than [] so a request without the field yields a clear
        # message instead of a "400 Bad Request" KeyError text in the page.
        file = request.files.get('image')
        # Werkzeug's FileStorage is falsy when its filename is empty, i.e. the
        # form was submitted without choosing a file.
        if not file:
            error = "No image uploaded - please choose an image file."
        else:
            try:
                print(f"Processing uploaded image: {file.filename}")
                # Decode the upload; the with-block releases the opened source
                # image once the RGB copy has been made.
                with Image.open(file.stream) as uploaded:
                    image = uploaded.convert('RGB')
                print(f"Image size: {image.size}")
                # Store original for comparison
                original_image = image_to_base64(image)
                # Enhance with your ZeroIG model
                enhanced_image, enhancement_status = zeroig.enhance_image(image)
                # Convert result to base64
                result_image = image_to_base64(enhanced_image)
                status = enhancement_status
            except Exception as e:
                error = f"Error processing image: {str(e)}"
                print(f"Error: {e}")
                import traceback
                traceback.print_exc()
    return render_template_string(HTML_TEMPLATE,
                                  original_image=original_image,
                                  result_image=result_image,
                                  status=status,
                                  error=error)
if __name__ == '__main__':
    # Listen on all interfaces. The port defaults to 7860 but can be overridden
    # with the PORT environment variable (e.g. on a hosting platform).
    port = int(os.environ.get('PORT', '7860'))
    print("🚀 Starting ZeroIG Flask app...")
    app.run(host='0.0.0.0', port=port)