"""Flask HTTP front-end for the Self-Forcing video generation API."""

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import os
import json
import logging

from api_endpoint import VideoGenerationAPI, download_models

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Directory that /video/<filename> serves from. A relative value is resolved
# against the current working directory when a request is handled.
# NOTE(review): presumably this is where VideoGenerationAPI writes its
# results -- confirm against api_endpoint.
OUTPUT_DIR = "outputs"

# Global API instance, created lazily by get_api_instance().
api_instance = None


def get_api_instance():
    """Return the shared VideoGenerationAPI instance, creating it on first use.

    On first call (or after a previous failed attempt) this runs
    ``download_models()`` and then constructs ``VideoGenerationAPI()``.

    Returns:
        The shared instance, or ``None`` when initialization failed. A failure
        is logged with its traceback and is not cached, so the next call
        retries.
    """
    global api_instance
    if api_instance is None:
        logger.info("Initializing Video Generation API...")
        try:
            # Download models first
            download_models()
            api_instance = VideoGenerationAPI()
            logger.info("API instance created successfully")
        except Exception as e:
            # logger.exception keeps the traceback, which the bare message lost.
            logger.exception("Failed to initialize API: %s", e)
            api_instance = None
    return api_instance


def _parse_prompt_request():
    """Read and validate the JSON body of a prompt-carrying request.

    Returns:
        A ``(data, error)`` pair. On success ``data`` is the decoded JSON
        object and ``error`` is ``None``; otherwise ``data`` is ``None`` and
        ``error`` is the message to send back with a 400 status.
    """
    # silent=True yields None for a missing, non-JSON or malformed body
    # instead of raising an HTTP exception that the endpoint-level
    # ``except Exception`` would misreport as a 500.
    data = request.get_json(silent=True)
    # Only a JSON object can carry named parameters; a list such as
    # ["prompt"] would pass a bare ``in`` test and then fail on lookup.
    if not isinstance(data, dict) or 'prompt' not in data:
        return None, "Missing 'prompt' in request body"

    prompt = data['prompt']
    if not isinstance(prompt, str) or not prompt.strip():
        return None, "Prompt must be a non-empty string"

    return data, None


def _resolve_video_path(filename):
    """Map a requested file name to an absolute path inside OUTPUT_DIR.

    Returns:
        The absolute path when it lies inside OUTPUT_DIR and names an existing
        regular file, otherwise ``None``.
    """
    root = os.path.abspath(OUTPUT_DIR)
    candidate = os.path.abspath(os.path.join(root, filename))
    try:
        inside_root = os.path.commonpath([root, candidate]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared (for example
        # different drives on Windows); treat that as "outside".
        inside_root = False

    if inside_root and os.path.isfile(candidate):
        return candidate
    return None


@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the generation API is initialized.

    Calls get_api_instance(), so the first health check triggers
    initialization. Always responds with status 200; the JSON body carries
    ``status``, ``device`` and ``models_loaded``.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "status": "unhealthy",
            "device": "unknown",
            "models_loaded": False
        })
    return jsonify({
        "status": "healthy",
        "device": api.device,
        "models_loaded": api.pipe is not None
    })


@app.route('/generate', methods=['POST'])
def generate_video():
    """Generate a video from the text prompt in the JSON request body.

    The body must be a JSON object with a non-empty string ``prompt``. The
    optional keys ``negative_prompt``, ``seed``, ``cfg_scale``,
    ``clip_length``, ``motion_scale``, ``fps``, ``enhance_prompt_flag`` and
    ``num_inference_steps`` fall back to the defaults below and are passed
    through to the API unvalidated.

    Responses:
        200 with the API result, 400 for an invalid body, 500 when the API is
        not initialized, the result contains an ``error`` key, or generation
        raises.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "video_path": None
        }), 500

    data, error = _parse_prompt_request()
    if error is not None:
        return jsonify({
            "error": error,
            "video_path": None
        }), 400

    try:
        result = api.generate_video(
            prompt=data['prompt'],
            negative_prompt=data.get('negative_prompt', ''),
            seed=data.get('seed', -1),
            cfg_scale=data.get('cfg_scale', 7.0),
            clip_length=data.get('clip_length', 64),
            motion_scale=data.get('motion_scale', 0.5),
            fps=data.get('fps', 15.0),
            enhance_prompt_flag=data.get('enhance_prompt_flag', True),
            num_inference_steps=data.get('num_inference_steps', 50)
        )

        # An "error" key in the result marks a failed generation.
        if "error" in result:
            return jsonify(result), 500
        return jsonify(result)

    except Exception as e:
        logger.exception("Error in generate_video endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "video_path": None
        }), 500


@app.route('/video/<filename>', methods=['GET'])
def serve_video(filename):
    """Serve a generated video file from OUTPUT_DIR as ``video/mp4``.

    Args:
        filename: Final URL segment naming the file to send.

    Responses:
        200 with the file, 404 when no regular file of that name exists inside
        OUTPUT_DIR, 500 on an unexpected error.
    """
    try:
        video_path = _resolve_video_path(filename)
        if video_path is None:
            return jsonify({"error": "Video file not found"}), 404
        # The path is absolute, so the existence check above and send_file
        # refer to the same file regardless of how relative paths are resolved.
        return send_file(video_path, mimetype='video/mp4')
    except Exception as e:
        logger.exception("Error serving video: %s", e)
        return jsonify({"error": str(e)}), 500


@app.route('/enhance_prompt', methods=['POST'])
def enhance_prompt():
    """Enhance the text prompt in the JSON request body.

    The body must be a JSON object with a non-empty string ``prompt``.

    Responses:
        200 with ``original_prompt`` and ``enhanced_prompt``, 400 for an
        invalid body, 500 when the API is not initialized or enhancement
        raises.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "enhanced_prompt": None
        }), 500

    data, error = _parse_prompt_request()
    if error is not None:
        return jsonify({
            "error": error,
            "enhanced_prompt": None
        }), 400

    try:
        prompt = data['prompt']
        enhanced_prompt = api.enhance_prompt(prompt)
        return jsonify({
            "original_prompt": prompt,
            "enhanced_prompt": enhanced_prompt
        })

    except Exception as e:
        logger.exception("Error in enhance_prompt endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "enhanced_prompt": None
        }), 500


@app.route('/', methods=['GET'])
def index():
    """Return a JSON description of the API and an example request."""
    return jsonify({
        "name": "Self-Forcing Video Generation API",
        "version": "1.0.0",
        "description": "Generate high-quality videos from text descriptions using the Self-Forcing model",
        "endpoints": {
            "GET /": "API documentation",
            "GET /health": "Health check",
            "POST /generate": "Generate video from text prompt",
            "POST /enhance_prompt": "Enhance text prompt using LLM",
            "GET /video/<filename>": "Serve generated video files"
        },
        "example_request": {
            "url": "/generate",
            "method": "POST",
            "body": {
                "prompt": "A cat playing with a ball in a sunny garden",
                "negative_prompt": "blurry, low quality, distorted",
                "seed": -1,
                "cfg_scale": 7.0,
                "clip_length": 64,
                "motion_scale": 0.5,
                "fps": 15.0,
                "enhance_prompt_flag": True,
                "num_inference_steps": 50
            }
        }
    })


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Self-Forcing Video Generation Flask API")
    parser.add_argument('--port', type=int, default=5000, help="Port to run the API on")
    parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the API to")
    parser.add_argument('--debug', action='store_true', help="Run in debug mode")

    args = parser.parse_args()

    # Debug mode enables the interactive debugger, which can execute code for
    # anyone who can reach the port; warn when it is not bound to loopback.
    if args.debug and args.host not in ('127.0.0.1', 'localhost', '::1'):
        logger.warning(
            "Debug mode is enabled on non-loopback host %s; "
            "do not expose this server to untrusted networks",
            args.host,
        )

    logger.info(f"Starting Flask API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)