# Rightlight / app.py
# Author: mike23415 - commit 7f980fd ("Update app.py", verified)
# Size: 2.73 kB
from flask import Flask, request, send_file
from flask_cors import CORS
import numpy as np
from PIL import Image
import io
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
# The Flask application object; every route in this module registers on it.
app = Flask(import_name=__name__)
# Enable CORS on all routes using flask_cors's default settings, so the API
# can be called from browser pages served from other origins.
CORS(app=app)
# Initialize Real-ESRGAN upsampler
def initialize_enhancer():
try:
# Configuration for RealESRGAN x4v3
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # Critical parameter for x4v3 model
num_grow_ch=32,
scale=4
)
# Force CPU usage for Hugging Face compatibility
device = torch.device('cpu')
return RealESRGANer(
scale=4,
model_path='weights/realesr-general-x4v3.pth',
model=model,
tile=0, # Set to 0 for small images, increase for large images
tile_pad=10,
pre_pad=0,
half=False, # CPU doesn't support half precision
device=device
)
except Exception as e:
print(f"Initialization error: {str(e)}")
return None
# Global upsampler instance, built once when the module is imported.
# It is None when initialize_enhancer() failed; the route handlers below
# check for that before using it.
upsampler: "RealESRGANer | None" = initialize_enhancer()
@app.route('/enhance', methods=['POST'])
def enhance_image():
    """Upscale an uploaded image 4x with Real-ESRGAN and return it as JPEG.

    Expects a multipart form upload with the image in the ``file`` field.

    Returns:
        The enhanced image (``image/jpeg``, quality 95) on success.
        Otherwise a JSON ``{'error': ...}`` body with status:
          * 500 if the model failed to initialize or processing fails;
          * 400 if no file, an empty filename, or an undecodable image
            was submitted.
    """
    if upsampler is None:
        return {'error': 'Model failed to initialize'}, 500
    if 'file' not in request.files:
        return {'error': 'No file uploaded'}, 400

    file = request.files['file']
    if file.filename == '':
        return {'error': 'Empty file submitted'}, 400

    try:
        # convert('RGB') forces a full decode, so truncated or corrupt data
        # fails here too. PIL's UnidentifiedImageError subclasses OSError.
        img = Image.open(file.stream).convert('RGB')
    except OSError:
        # A bad upload is a client error, not a server failure.
        return {'error': 'Invalid or unsupported image file'}, 400

    try:
        # RealESRGANer.enhance follows the OpenCV convention (BGR in,
        # BGR out), while PIL yields RGB: flip channels on both sides so
        # the network sees the channel order it was trained on.
        bgr_input = np.ascontiguousarray(np.array(img)[:, :, ::-1])
        output, _ = upsampler.enhance(
            bgr_input,
            outscale=4,  # 4x super-resolution
            alpha_upsampler='realesrgan',
        )
        rgb_output = np.ascontiguousarray(output[:, :, ::-1])

        # Encode to JPEG in memory and stream it back.
        img_byte_arr = io.BytesIO()
        Image.fromarray(rgb_output).save(img_byte_arr, format='JPEG', quality=95)
        img_byte_arr.seek(0)
        return send_file(img_byte_arr, mimetype='image/jpeg')
    except Exception as e:
        # Request boundary: log the full traceback server-side, then report.
        app.logger.exception('Image enhancement failed')
        return {'error': f'Processing error: {str(e)}'}, 500
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the global upsampler was created.

    Always responds with HTTP 200. The JSON ``status`` field is ``'ready'``
    when the upsampler exists and ``'unavailable'`` otherwise.
    """
    if upsampler:
        return {'status': 'ready'}, 200
    return {'status': 'unavailable'}, 200
@app.route('/')
def home():
    """Describe the API and list its available endpoints (HTTP 200)."""
    endpoint_docs = {
        'POST /enhance': 'Process images (4x upscale)',
        'GET /health': 'Service status check',
    }
    body = {'message': 'Image Enhancement API', 'endpoints': endpoint_docs}
    return body, 200
if __name__ == '__main__':
    import os

    # Bind to all interfaces so the dev server is reachable from outside a
    # container. The port defaults to 5000 but can be overridden with the
    # PORT environment variable (many hosting platforms assign one).
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', '5000')))