Spaces:
Sleeping
Sleeping
File size: 6,681 Bytes
1bc0a1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import json
import logging
import os
import threading

from flask import Flask, jsonify, request, send_file
from flask_cors import CORS

from api_endpoint import VideoGenerationAPI, download_models
# Lazily-created VideoGenerationAPI singleton; stays None until
# get_api_instance() succeeds.
api_instance = None

# Process-wide logging: INFO level, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# The Flask application, with CORS enabled for every route.
app = Flask(import_name=__name__)
CORS(app)
# Serializes first-time construction of the API singleton. Flask's
# development server handles requests in threads by default, so without this
# two concurrent first requests could both download and load the models.
_api_init_lock = threading.Lock()


def get_api_instance():
    """Return the shared VideoGenerationAPI, creating it on first use.

    When no instance exists yet (first call, or any call after a failed
    initialization) this runs ``download_models()`` and then constructs
    ``VideoGenerationAPI()``. Construction happens under a lock so that
    concurrent callers initialize the models at most once.

    Returns:
        The shared ``VideoGenerationAPI`` instance, or ``None`` if
        initialization raised. The failure is logged with its traceback and
        the next call retries.
    """
    global api_instance
    # Fast path: already initialized, no locking required.
    if api_instance is not None:
        return api_instance
    with _api_init_lock:
        # Re-check under the lock: another thread may have completed
        # initialization while this one was waiting.
        if api_instance is None:
            logger.info("Initializing Video Generation API...")
            try:
                # Download models first
                download_models()
                api_instance = VideoGenerationAPI()
                logger.info("API instance created successfully")
            except Exception as e:
                logger.exception("Failed to initialize API: %s", e)
                api_instance = None
    return api_instance
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the video generation backend is usable.

    Calling this triggers lazy initialization through ``get_api_instance()``
    if the API has not been created yet.

    Returns:
        JSON with ``status`` ("healthy"/"unhealthy"), ``device`` and
        ``models_loaded``. The HTTP status is 200 when healthy and 503 when
        the API could not be initialized, so probes and load balancers that
        look only at the status code also see the failure.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "status": "unhealthy",
            "device": "unknown",
            "models_loaded": False,
        }), 503
    return jsonify({
        "status": "healthy",
        "device": api.device,
        "models_loaded": api.pipe is not None,
    })
@app.route('/generate', methods=['POST'])
def generate_video():
    """Generate a video from a text prompt.

    Expects a JSON object body with a required non-empty string ``prompt``
    and the optional keys ``negative_prompt``, ``seed``, ``cfg_scale``,
    ``clip_length``, ``motion_scale``, ``fps``, ``enhance_prompt_flag`` and
    ``num_inference_steps`` (defaults below). The values are forwarded as-is
    to ``VideoGenerationAPI.generate_video``.

    Returns:
        The backend's result as JSON. Responds with 400 for a missing,
        malformed or non-object body, or an invalid prompt. Responds with 500
        if the API is not initialized, the result contains an ``"error"``
        key, or an unexpected exception occurs.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "video_path": None
        }), 500
    try:
        # silent=True: malformed JSON or a non-JSON Content-Type yields None
        # instead of raising, which the broad handler below would otherwise
        # report as a 500. Non-object bodies (lists, strings, numbers) are
        # rejected as well, since they cannot carry named parameters.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "error": "Missing 'prompt' in request body",
                "video_path": None
            }), 400
        # Extract parameters with defaults
        prompt = data['prompt']
        negative_prompt = data.get('negative_prompt', '')
        seed = data.get('seed', -1)
        cfg_scale = data.get('cfg_scale', 7.0)
        clip_length = data.get('clip_length', 64)
        motion_scale = data.get('motion_scale', 0.5)
        fps = data.get('fps', 15.0)
        enhance_prompt_flag = data.get('enhance_prompt_flag', True)
        num_inference_steps = data.get('num_inference_steps', 50)
        # Validate parameters
        if not isinstance(prompt, str) or len(prompt.strip()) == 0:
            return jsonify({
                "error": "Prompt must be a non-empty string",
                "video_path": None
            }), 400
        # Generate video
        result = api.generate_video(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            cfg_scale=cfg_scale,
            clip_length=clip_length,
            motion_scale=motion_scale,
            fps=fps,
            enhance_prompt_flag=enhance_prompt_flag,
            num_inference_steps=num_inference_steps
        )
        if "error" in result:
            return jsonify(result), 500
        return jsonify(result)
    except Exception as e:
        logger.exception("Error in generate_video endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "video_path": None
        }), 500
@app.route('/video/<filename>', methods=['GET'])
def serve_video(filename):
    """Serve a generated video from the ``outputs`` directory.

    ``outputs`` is resolved against the current working directory. Only a
    regular file located directly inside it is served. Names that would
    resolve elsewhere (``..``, ``.``, or on Windows backslash/drive-letter
    paths) are treated as not found.

    Args:
        filename: File name taken from the URL path.

    Returns:
        The file with mimetype ``video/mp4``. Otherwise a JSON 404 if no
        such file exists inside ``outputs``, or a JSON 500 on unexpected
        errors.
    """
    try:
        outputs_dir = os.path.abspath("outputs")
        video_path = os.path.abspath(os.path.join(outputs_dir, filename))
        # Reject anything that does not resolve to a plain file directly in
        # outputs_dir. Flask's default URL converter excludes "/" but lets
        # backslashes through, which are separators on Windows.
        if (os.path.dirname(video_path) != outputs_dir
                or not os.path.isfile(video_path)):
            return jsonify({"error": "Video file not found"}), 404
        # Pass the absolute path so the file sent is exactly the one checked
        # above (Flask may resolve relative paths against app.root_path).
        return send_file(video_path, mimetype='video/mp4')
    except Exception as e:
        logger.exception("Error serving video: %s", e)
        return jsonify({"error": str(e)}), 500
@app.route('/enhance_prompt', methods=['POST'])
def enhance_prompt():
    """Enhance a text prompt using the backend's ``enhance_prompt`` (LLM).

    Expects a JSON object body with a required non-empty string ``prompt``.

    Returns:
        JSON ``{"original_prompt": ..., "enhanced_prompt": ...}`` on success.
        Responds with 400 for a missing, malformed or non-object body, or an
        invalid prompt. Responds with 500 if the API is not initialized or
        enhancement raises.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "enhanced_prompt": None
        }), 500
    try:
        # silent=True: malformed JSON or a non-JSON Content-Type yields None
        # instead of raising, which would otherwise surface as a 500 below.
        # Non-object bodies are rejected because they cannot hold 'prompt'.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "error": "Missing 'prompt' in request body",
                "enhanced_prompt": None
            }), 400
        prompt = data['prompt']
        if not isinstance(prompt, str) or len(prompt.strip()) == 0:
            return jsonify({
                "error": "Prompt must be a non-empty string",
                "enhanced_prompt": None
            }), 400
        enhanced_prompt = api.enhance_prompt(prompt)
        return jsonify({
            "original_prompt": prompt,
            "enhanced_prompt": enhanced_prompt
        })
    except Exception as e:
        logger.exception("Error in enhance_prompt endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "enhanced_prompt": None
        }), 500
@app.route('/', methods=['GET'])
def index():
    """Return self-describing API documentation as JSON.

    The response lists the service name and version, every route with a
    short description, and a sample ``POST /generate`` request body.
    """
    routes = {
        "GET /": "API documentation",
        "GET /health": "Health check",
        "POST /generate": "Generate video from text prompt",
        "POST /enhance_prompt": "Enhance text prompt using LLM",
        "GET /video/<filename>": "Serve generated video files",
    }
    sample_body = {
        "prompt": "A cat playing with a ball in a sunny garden",
        "negative_prompt": "blurry, low quality, distorted",
        "seed": -1,
        "cfg_scale": 7.0,
        "clip_length": 64,
        "motion_scale": 0.5,
        "fps": 15.0,
        "enhance_prompt_flag": True,
        "num_inference_steps": 50,
    }
    sample_request = {
        "url": "/generate",
        "method": "POST",
        "body": sample_body,
    }
    docs = {
        "name": "Self-Forcing Video Generation API",
        "version": "1.0.0",
        "description": "Generate high-quality videos from text descriptions using the Self-Forcing model",
        "endpoints": routes,
        "example_request": sample_request,
    }
    return jsonify(docs)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Self-Forcing Video Generation Flask API")
    parser.add_argument('--port', type=int, default=5000, help="Port to run the API on")
    parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the API to")
    parser.add_argument('--debug', action='store_true', help="Run in debug mode")
    args = parser.parse_args()

    # Flask debug mode enables the Werkzeug interactive debugger, which can
    # execute arbitrary code. Warn when it is bound to a non-loopback
    # address, such as the default 0.0.0.0.
    if args.debug and args.host not in ('127.0.0.1', 'localhost', '::1'):
        logger.warning(
            "Debug mode is enabled on %s; the interactive debugger is reachable "
            "from the network. Do not use --debug on a publicly reachable host.",
            args.host,
        )
    logger.info(f"Starting Flask API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)
|