import json

import torch
from flask import Flask, request, jsonify

from handler import EndpointHandler
# Flask application object; the view function in this module is registered on it.
app = Flask(import_name=__name__)

# Model handler, constructed once when the module is imported and then
# reused by every request that the view function serves.
handler = EndpointHandler()
def _parse_json_list(raw, field):
    """Parse the form value *raw* as a JSON array and return it as a list.

    Raises:
        ValueError: if *raw* is not valid JSON (including pathologically deep
            nesting) or does not decode to a JSON array. The message names *field*
            so it can be returned to the client as-is.
    """
    try:
        value = json.loads(raw)
    except (ValueError, RecursionError) as exc:
        # json.JSONDecodeError subclasses ValueError; RecursionError covers
        # absurdly nested input from a hostile client.
        raise ValueError(f'{field} must be a JSON array') from exc
    if not isinstance(value, list):
        raise ValueError(f'{field} must be a JSON array')
    return value


# NOTE(review): no route was registered for this view, so it was unreachable.
# The path and method are inferred from the function name; confirm them against clients.
@app.route('/predict', methods=['POST'])
def predict():
    """Run the model handler on an uploaded image and return its result as JSON.

    Expects a multipart/form-data request with:

    * ``file`` (required): the image upload. Its raw bytes are passed to the handler.
    * ``point_coords`` / ``point_labels`` (optional, must be sent together):
      JSON arrays such as ``"[[500, 375]]"`` and ``"[1]"``. When both are given,
      the handler receives a dict with keys ``image``, ``point_coords`` and
      ``point_labels``. Otherwise it receives the raw image bytes alone.

    Returns:
        ``jsonify(result)`` on success.
        ``({'error': ...}, 400)`` for a missing or empty upload, or for
        malformed or unpaired prompt fields.
        ``({'error': ...}, 500)`` when the handler, or JSON serialisation of
        its result, raises.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    # `not` also covers a None filename, not just the empty string.
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400

    image_bytes = file.read()

    raw_coords = request.form.get('point_coords')
    raw_labels = request.form.get('point_labels')

    # A prompt needs both parts. Reject a half-specified prompt instead of
    # silently running without it.
    if bool(raw_coords) != bool(raw_labels):
        return jsonify({'error': 'point_coords and point_labels must be provided together'}), 400

    payload = image_bytes
    if raw_coords and raw_labels:
        # These strings come straight from the client: parse them as data and
        # never eval() them, which would allow arbitrary code execution.
        try:
            point_coords = _parse_json_list(raw_coords, 'point_coords')
            point_labels = _parse_json_list(raw_labels, 'point_labels')
        except ValueError as e:
            return jsonify({'error': str(e)}), 400
        payload = {
            'image': image_bytes,
            'point_coords': point_coords,
            'point_labels': point_labels,
        }

    try:
        result = handler(payload)
        return jsonify(result)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # report the message to the client.
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Local development entry point only. debug=True enables the reloader and the
    # interactive Werkzeug debugger, so never expose this server to untrusted networks.
    app.run(port=5000, debug=True)