import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor

# Load model and processor
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"

# Number of frames sampled uniformly from each uploaded video. LLaVA-NeXT-Video
# is meant to see a handful of frames (the model card samples 8); feeding every
# frame of a clip exhausts GPU memory and the model's context window.
NUM_FRAMES = 8

# Upper bound on the length of the generated feedback, in tokens.
MAX_NEW_TOKENS = 500

# Initialize processor and model with error handling so the UI can still start
# and report the failure instead of crashing at import time.
try:
    processor = LlavaNextVideoProcessor.from_pretrained(model_id)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
except Exception as e:
    print(f"Error loading model: {e}")
    processor = None
    model = None


def _read_rgb_frames(cap, keep=None):
    """Decode frames from an open ``cv2.VideoCapture`` and return them as RGB arrays.

    ``keep`` is an optional set of 0-based frame indices. Frames outside it are
    only advanced past with ``grab()`` (no decode), so they are never held in memory.
    """
    frames = []
    index = 0
    while cap.grab():
        if keep is None or index in keep:
            ok, frame = cap.retrieve()
            if ok:
                # OpenCV decodes to BGR; the processor expects RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        index += 1
    return frames


def process_video_file(video_path, num_frames=NUM_FRAMES):
    """Read a video file and return uniformly sampled RGB frames for the model.

    Args:
        video_path: Path of the video file to read.
        num_frames: Number of frames to sample evenly across the clip.
            ``None`` keeps every frame.

    Returns:
        ``np.ndarray`` of stacked RGB frames ``(frames, height, width, 3)``, or
        ``None`` if the file cannot be opened, yields no decodable frames, or
        any error occurs (the error is printed).
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error processing video: cannot open {video_path!r}")
            return None

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if num_frames is not None and total > 0:
            # Choose the frame indices up front so only those frames are decoded.
            keep = set(np.linspace(0, total - 1, num_frames).round().astype(int).tolist())
            frames = _read_rgb_frames(cap, keep)
        else:
            # Frame count unknown (some containers report 0) or all frames
            # requested: read everything, then subsample.
            frames = _read_rgb_frames(cap)
            if num_frames is not None and len(frames) > num_frames:
                picks = np.linspace(0, len(frames) - 1, num_frames).round().astype(int)
                frames = [frames[i] for i in picks]

        if not frames:
            # An empty result would otherwise slip past the caller's None check
            # and fail later inside the processor.
            print(f"Error processing video: no frames decoded from {video_path!r}")
            return None
        return np.stack(frames)
    except Exception as e:
        print(f"Error processing video: {e}")
        return None
    finally:
        # Always release the capture handle, even when decoding fails.
        cap.release()


def _to_model_device(inputs):
    """Move processor outputs onto the model's device.

    Floating-point tensors (the video pixels) are also cast to the model's dtype
    (float16 here). Integer tensors such as ``input_ids`` keep their dtype.
    """
    moved = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            if value.is_floating_point():
                value = value.to(model.device, dtype=model.dtype)
            else:
                value = value.to(model.device)
        moved[key] = value
    return moved


def analyze_dance_video(video_file, pose_scores="0.85", music_info="Unknown"):
    """Generate dance-instructor feedback for an uploaded video.

    Args:
        video_file: Path to the uploaded video (the value ``gr.Video`` passes in).
        pose_scores: Free-text MediaPipe pose scores inserted into the prompt.
        music_info: Free-text music description (BPM, genre, ...) inserted into
            the prompt.

    Returns:
        The model's feedback text, or a human-readable error message. Errors are
        returned as strings rather than raised so they appear in the UI textbox.
    """
    if model is None or processor is None:
        return "Error: Model not loaded properly. Please check the logs."

    if video_file is None:
        return "Please upload a video file."

    try:
        # Process the video file
        video_frames = process_video_file(video_file)
        if video_frames is None:
            return "Error: Could not process video file."

        # The "<video>" placeholder is required: the processor expands it into
        # the video feature tokens, and generation fails when it is missing.
        prompt = (
            "USER: <video>\n"
            "You are an expert dance instructor. Analyze this dance performance video.\n"
            "\n"
            "Additional Data:\n"
            f"- MediaPipe Pose Scores: {pose_scores}\n"
            f"- Music Information: {music_info}\n"
            "\n"
            "When analyzing, combine what you see in the video with the pose scores "
            "and the music details to provide precise, realistic feedback.\n"
            "\n"
            "Please provide detailed feedback on:\n"
            "1. Timing: how well the dancer's moves are synchronized with the "
            "music's tempo and beat.\n"
            "2. Pose accuracy and technique, considering the pose scores.\n"
            "3. Movement flow: the fluidity and smoothness of transitions between moves.\n"
            "4. Specific areas where the dancer can improve.\n"
            "5. Overall performance rating (1-10).\n"
            "\n"
            "Give constructive feedback in a friendly, encouraging tone. ASSISTANT:"
        )

        # Process video with the model (videos expects a list, one entry per prompt).
        inputs = processor(
            text=prompt,
            videos=[video_frames],
            return_tensors="pt",
        )
        inputs = _to_model_device(inputs)
        prompt_length = inputs["input_ids"].shape[-1]

        # Generate response
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.7,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation; decode only the new tokens so
        # the prompt is dropped without fragile string splitting.
        response = processor.decode(output[0][prompt_length:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        return f"Error analyzing video: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="AI Dance Instructor") as demo:
    gr.Markdown("# 🕺 AI Dance Instructor")
    gr.Markdown("Upload your dance video along with pose scores for detailed feedback!")

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Dance Video"
            )
            pose_scores = gr.Textbox(
                label="MediaPipe Pose Scores",
                placeholder="Enter pose scores data from MediaPipe (e.g., 0.85)...",
                value="0.85",
                lines=3
            )
            music_info = gr.Textbox(
                label="Music Information",
                placeholder="BPM, genre, rhythm details...",
                value="120 BPM, Pop music",
                lines=3
            )
            analyze_btn = gr.Button("Analyze Dance", variant="primary")

        with gr.Column():
            feedback_output = gr.Textbox(
                label="Dance Feedback",
                lines=15,
                interactive=False
            )

    # Set up the analysis function
    analyze_btn.click(
        fn=analyze_dance_video,
        inputs=[video_input, pose_scores, music_info],
        outputs=[feedback_output]
    )

    # Add API endpoint info
    # NOTE(review): this snippet is likely inaccurate. JSON.stringify cannot
    # serialize a browser File object, and recent Gradio versions expose
    # endpoints via @gradio/client / `/call/<api_name>` rather than
    # `/api/predict`. Verify against the installed Gradio version.
    gr.Markdown("### API Usage")
    gr.Markdown("""
    **For Next.js Integration, use the API endpoint:**
    ```javascript
    const formData = new FormData();
    formData.append('data', JSON.stringify([videoFile, poseScores, musicInfo]));

    const response = await fetch('https://your-space-name.hf.space/api/predict', {
        method: 'POST',
        body: formData
    });
    ```
    """)

if __name__ == "__main__":
    demo.launch()