"""Flask front end for ZeroIG zero-shot low-light image enhancement.

Users upload an image through the web form. The app runs it through the
ZeroIG model (imported from the local ``model`` module) and shows the
original next to the result. If the model cannot be loaded, or fails during
inference, a simple brightness boost is used instead.
"""

import base64
import io
import os
import sys
import traceback

import numpy as np
import torch
import torchvision.transforms as transforms
from flask import Flask, request, render_template_string
from PIL import Image

app = Flask(__name__)

# The upload form must post a multipart field named "image" (read by
# index()). Images are rendered as inline base64 PNGs, which matches the
# output of image_to_base64().
HTML_TEMPLATE = """<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>ZeroIG Enhancement</title>
    <style>
        .results { display: flex; gap: 1em; flex-wrap: wrap; }
        .results img { max-width: 100%; height: auto; }
        .error { color: #b00020; }
    </style>
</head>
<body>
    <h1>🌟 ZeroIG: Zero-Shot Low-Light Enhancement</h1>
    <p>CVPR 2024 - Upload a low-light image for professional enhancement!</p>

    <form method="post" enctype="multipart/form-data">
        <input type="file" name="image" accept="image/*" required>
        <input type="submit" value="Enhance">
    </form>

    {% if status %}
    <div class="status">{{ status }}</div>
    {% endif %}
    {% if error %}
    <div class="error">{{ error }}</div>
    {% endif %}
    {% if original_image and result_image %}
    <h2>Results:</h2>
    <div class="results">
        <div>
            <h3>Original (Low-light)</h3>
            <img src="data:image/png;base64,{{ original_image }}" alt="Original">
        </div>
        <div>
            <h3>Enhanced (ZeroIG)</h3>
            <img src="data:image/png;base64,{{ result_image }}" alt="Enhanced">
        </div>
    </div>
    {% endif %}

    <h2>About ZeroIG</h2>
    <p>Zero-shot illumination-guided joint denoising and adaptive enhancement for low-light images.</p>
    <p>Features: No training data required • Joint denoising &amp; enhancement • Prevents over-enhancement</p>
    <p>📄 Research Paper | 💻 Source Code</p>
</body>
</html>
"""


class ZeroIGProcessor:
    """Load the ZeroIG model once and use it to enhance PIL images.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``.
        max_size: longest side (in pixels) an input is shrunk to before
            inference, to bound memory use. ``None`` disables shrinking.
        model_kind: ``"finetune"`` or ``"network"``, recording which ZeroIG
            class ``load_model`` built, or ``None`` if loading failed.
        model: the loaded model in eval mode, or ``None`` if loading failed.
    """

    def __init__(self, max_size=800):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.max_size = max_size
        self.model_kind = None
        self.model = self.load_model()
        print(f"ZeroIG initialized on {self.device}")

    def load_model(self):
        """Build the ZeroIG model, preferring trained weights when present.

        Uses ``Finetunemodel`` when ``./weights/model.pt`` exists. Otherwise
        uses a randomly initialised ``Network``. Sets ``self.model_kind`` to
        match the class that was built.

        Returns:
            The model moved to ``self.device`` and switched to eval mode, or
            ``None`` if the ZeroIG modules cannot be imported or construction
            fails.
        """
        try:
            # Imported lazily so the app still starts (with the fallback
            # enhancer) when the ZeroIG sources are missing.
            from model import Finetunemodel, Network

            model_path = "./weights/model.pt"
            if os.path.exists(model_path):
                print(f"Found model weights at {model_path}")
                model = Finetunemodel(model_path)
                kind = "finetune"
                print("✅ Loaded ZeroIG Finetunemodel with trained weights")
            else:
                print("No trained weights found, using Network with random initialization")
                model = Network()
                kind = "network"
                print("⚠️ Using ZeroIG Network with random weights")

            model.to(self.device)
            model.eval()
            print(f"Model moved to {self.device}")
            self.model_kind = kind
            return model
        except ImportError as e:
            print(f"❌ Could not import ZeroIG modules: {e}")
            print("Make sure you have uploaded: model.py, loss.py, utils.py")
            return None
        except Exception as e:
            print(f"❌ Could not load ZeroIG model: {e}")
            return None

    def _downscale(self, image):
        """Return ``image`` shrunk so its longest side is at most ``self.max_size``."""
        longest = max(image.size)
        if self.max_size is None or longest <= self.max_size:
            return image
        ratio = self.max_size / longest
        # Clamp to 1 so extreme aspect ratios never produce a 0-pixel side.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        resized = image.resize(new_size, Image.Resampling.LANCZOS)
        print(f"Resized image from {image.size} to {resized.size}")
        return resized

    def _run_model(self, input_tensor):
        """Run the loaded model on a (1, C, H, W) tensor; return ``(result, status)``."""
        kind = getattr(self, 'model_kind', None)
        if kind is None:
            # The model was attached without going through load_model(), so
            # fall back to guessing the class from its attributes.
            is_finetune = hasattr(self.model, 'enhance') and hasattr(self.model, 'denoise_1')
            kind = "finetune" if is_finetune else "network"

        with torch.no_grad():
            if kind == "finetune":
                # Finetunemodel returns (enhanced, denoised); show the denoised one.
                _enhanced, denoised = self.model(input_tensor)
                print("Used Finetunemodel")
                return denoised, "✅ Enhanced with ZeroIG Finetunemodel"

            outputs = self.model(input_tensor)
            print("Used Network model")
            # Index 13 is treated as H3, the final denoised result
            # (per the original author's note) - confirm against model.py.
            return outputs[13], "✅ Enhanced with ZeroIG Network model"

    def enhance_image(self, image):
        """Enhance ``image`` with ZeroIG, falling back to ``simple_enhance``.

        Args:
            image: a PIL image. Non-RGB images are converted to RGB first.

        Returns:
            ``(enhanced_image, status)``. ``enhanced_image`` is a PIL image at
            the same size as the (RGB) input. ``status`` is a human-readable
            message saying which path produced it.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.model is None:
            return self.simple_enhance(image), "❌ ZeroIG model not available - using simple enhancement"

        original_size = image.size
        try:
            model_input = self._downscale(image)
            input_tensor = transforms.ToTensor()(model_input).unsqueeze(0).to(self.device)
            print(f"Input tensor shape: {input_tensor.shape}")

            result_tensor, status = self._run_model(input_tensor)

            result_tensor = result_tensor.squeeze(0).cpu().clamp(0, 1)
            enhanced_image = transforms.ToPILImage()(result_tensor)
            print(f"Output image size: {enhanced_image.size}")

            # Return the caller's resolution, whether the size changed in
            # _downscale or inside the model itself.
            if enhanced_image.size != original_size:
                enhanced_image = enhanced_image.resize(original_size, Image.Resampling.LANCZOS)
                print(f"Resized back to original size: {enhanced_image.size}")
            return enhanced_image, status
        except Exception as e:
            print(f"ZeroIG enhancement error: {e}")
            traceback.print_exc()
            # `image` is still the full-resolution input here, so the
            # fallback result matches the original size.
            return self.simple_enhance(image), f"⚠️ ZeroIG failed, using simple enhancement: {str(e)}"

    def simple_enhance(self, image):
        """Fallback enhancement: multiply every channel by 1.8, clipped to [0, 255]."""
        arr = np.asarray(image, dtype=np.float32)
        enhanced = np.clip(arr * 1.8, 0, 255).astype(np.uint8)
        return Image.fromarray(enhanced)


# Build the processor once at import time so every request reuses the model.
print("🚀 Loading ZeroIG processor...")
zeroig = ZeroIGProcessor()


def image_to_base64(image):
    """Encode a PIL image as PNG and return it as a base64 ASCII string."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('ascii')


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload page; on POST, enhance the uploaded image.

    The form field is named ``image``. The original and the enhanced image
    are passed to the template as base64-encoded PNG strings. Any failure is
    reported through the template's ``error`` slot.
    """
    original_image = None
    result_image = None
    status = None
    error = None

    if request.method == 'POST':
        # .get() avoids a KeyError when the field is absent. An empty file
        # input yields a falsy FileStorage.
        file = request.files.get('image')
        if not file:
            error = "Please choose an image to upload."
        else:
            try:
                print(f"Processing uploaded image: {file.filename}")
                image = Image.open(file.stream).convert('RGB')
                print(f"Image size: {image.size}")

                original_image = image_to_base64(image)
                enhanced_image, enhancement_status = zeroig.enhance_image(image)
                result_image = image_to_base64(enhanced_image)
                status = enhancement_status
            except Exception as e:
                error = f"Error processing image: {str(e)}"
                print(f"Error: {e}")
                traceback.print_exc()

    return render_template_string(
        HTML_TEMPLATE,
        original_image=original_image,
        result_image=result_image,
        status=status,
        error=error,
    )


if __name__ == '__main__':
    print("🚀 Starting ZeroIG Flask app...")
    app.run(host='0.0.0.0', port=7860)