# Hugging Face Space (page status when captured: Sleeping)
import io
import os
import threading

import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from flask import Flask, request, send_file
from flask_cors import CORS
from PIL import Image
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# Flask application object; CORS(app) registers flask_cors on the whole
# application with the library's default settings.
app = Flask(import_name=__name__)
CORS(app=app)
# Initialize Real-ESRGAN upsampler
def initialize_enhancer():
    """Build the Real-ESRGAN upsampler for the ``realesr-general-x4v3`` weights.

    ``realesr-general-x4v3`` is a compact VGG-style network
    (``SRVGGNetCompact``), not an ``RRDBNet``. Pairing the checkpoint with
    ``RRDBNet`` makes ``RealESRGANer``'s strict state-dict load fail, so the
    architecture parameters below mirror the official Real-ESRGAN inference
    script for this model.

    Returns:
        A ``RealESRGANer`` running on the CPU, or ``None`` if construction
        fails (e.g. missing weights file). The error is printed rather than
        raised so the module still imports and the request handlers can
        report that the model is unavailable.
    """
    try:
        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,  # x4v3 ("S" size) uses 32 conv layers
            upscale=4,
            act_type='prelu',
        )
        # Force CPU usage for Hugging Face compatibility
        device = torch.device('cpu')
        return RealESRGANer(
            scale=4,
            model_path='weights/realesr-general-x4v3.pth',
            model=model,
            tile=0,  # 0 = whole image in one pass; raise to bound memory on large images
            tile_pad=10,
            pre_pad=0,
            half=False,  # no fp16 inference on CPU
            device=device,
        )
    except Exception as e:
        print(f"Initialization error: {str(e)}")
        return None
# Shared model instance, built once when the module is imported. It is None
# when initialize_enhancer() failed; the request handlers check for that.
upsampler = initialize_enhancer()
# RealESRGANer.enhance keeps intermediate tensors on the instance
# (self.img / self.output), so calls on the shared upsampler must not
# overlap when the server handles requests on several threads.
_enhance_lock = threading.Lock()


@app.route('/enhance', methods=['POST'])
def enhance_image():
    """Upscale an uploaded image 4x and return it as a JPEG.

    Expects a multipart form upload in the ``file`` field.

    Returns:
        The enhanced image (``image/jpeg``) on success. Otherwise a JSON
        error body with status 400 (missing, empty, or undecodable upload)
        or 500 (model unavailable or enhancement failure).
    """
    if upsampler is None:
        return {'error': 'Model failed to initialize'}, 500
    if 'file' not in request.files:
        return {'error': 'No file uploaded'}, 400

    file = request.files['file']
    if not file.filename:
        return {'error': 'Empty file submitted'}, 400

    try:
        img = Image.open(file.stream).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # OSError covers PIL.UnidentifiedImageError and truncated data: the
        # upload is not a usable image, which is a client error, not a 500.
        return {'error': f'Invalid image: {str(e)}'}, 400

    try:
        # RealESRGANer follows OpenCV conventions (BGR in, BGR out), so flip
        # the channel axis before enhancing and again afterwards.
        bgr_in = np.ascontiguousarray(np.array(img)[:, :, ::-1])
        with _enhance_lock:
            bgr_out, _ = upsampler.enhance(
                bgr_in,
                outscale=4,  # 4x super-resolution
                alpha_upsampler='realesrgan',
            )
        rgb_out = np.ascontiguousarray(bgr_out[:, :, ::-1])

        # Convert to JPEG bytes
        img_byte_arr = io.BytesIO()
        Image.fromarray(rgb_out).save(img_byte_arr, format='JPEG', quality=95)
        img_byte_arr.seek(0)
        return send_file(img_byte_arr, mimetype='image/jpeg')
    except Exception as e:
        return {'error': f'Processing error: {str(e)}'}, 500
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the upsampler loaded.

    Always answers 200. The ``status`` field is ``'ready'`` when the model
    initialized and ``'unavailable'`` otherwise.
    """
    status = 'ready' if upsampler is not None else 'unavailable'
    return {'status': status}, 200
@app.route('/', methods=['GET'])
def home():
    """Describe the API and list the endpoints it serves."""
    return {
        'message': 'Image Enhancement API',
        'endpoints': {
            'POST /enhance': 'Process images (4x upscale)',
            'GET /health': 'Service status check'
        }
    }, 200
if __name__ == '__main__':
    # Development server on all interfaces. The port can be overridden with
    # the PORT environment variable and otherwise defaults to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', '5000')))