import gradio as gr
import torch
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import cv2
import numpy as np

# Load model and processor
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"

# Initialize processor and model with error handling
try:
    processor = LlavaNextVideoProcessor.from_pretrained(model_id)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_id, 
        torch_dtype=torch.float16, 
        device_map="auto",
        low_cpu_mem_usage=True
    )
except Exception as e:
    print(f"Error loading model: {e}")
    processor = None
    model = None
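
# Optional: if GPU memory is tight, the model can also be loaded in 4-bit.
# This is a sketch, not part of the app's default path; it requires the `bitsandbytes` package.
#   from transformers import BitsAndBytesConfig
#   model = LlavaNextVideoForConditionalGeneration.from_pretrained(
#       model_id,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
#       device_map="auto",
#   )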

def process_video_file(video_path, num_frames=8):
    """Read a video and uniformly sample frames in the format expected by the model"""
    try:
        # Read video using OpenCV
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            cap.release()
            return None

        # Uniformly sample frame indices so long clips don't blow up memory or the model's context.
        # num_frames=8 follows the common LLaVA-NeXT-Video examples; adjust as needed.
        indices = set(np.linspace(0, total_frames - 1, num=min(num_frames, total_frames), dtype=int).tolist())

        frames = []
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx in indices:
                # OpenCV decodes frames as BGR; the model expects RGB
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_idx += 1

        cap.release()

        if not frames:
            return None

        # Stack into a (num_frames, height, width, 3) array; the processor handles normalization
        return np.stack(frames)

    except Exception as e:
        print(f"Error processing video: {e}")
        return None

def analyze_dance_video(video_file, pose_scores="0.85", music_info="Unknown"):
    """
    Analyze dance video with pose scores and music information
    """
    if model is None or processor is None:
        return "Error: Model not loaded properly. Please check the logs."
    
    if video_file is None:
        return "Please upload a video file."
    
    try:
        # Process the video file
        video_frames = process_video_file(video_file)
        if video_frames is None:
            return "Error: Could not process video file."
        
        # Prepare the prompt
        prompt = f"""USER: You are an expert dance instructor. Analyze this dance performance video.
        
        Additional Data:
        - MediaPipe Pose Scores: {pose_scores}
        - Music Information: {music_info}
        - When analyzing, combine what you see in the video with the pose scores and the music details to provide precise, realistic feedback.

        Please provide detailed feedback on:
        1. Timing and synchronization with music
        2. Pose accuracy and technique
        3. Movement flow and transitions
        4. Areas for improvement
        5. Overall performance rating (1-10)
        6. How well the dancer's moves are synchronized with the music's tempo and beat.
        7. The accuracy and technique of the dancer's poses, considering the pose scores.
        8. The fluidity and smoothness of transitions between moves.
        9. Specific areas where the dancer can improve.

        
        Give constructive feedback in a friendly, encouraging tone.
        ASSISTANT:"""
        
        # Process video with the model
        inputs = processor(
            text=prompt,
            videos=[video_frames],  # Note: videos expects a list
            return_tensors="pt"
        )
        
        # Move inputs to the same device as model
        inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        
        # Generate response
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=True,
                temperature=0.7,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id
            )
        
        # Decode response
        response = processor.decode(output[0], skip_special_tokens=True)
        
        # Extract just the generated part (after ASSISTANT:)
        if "ASSISTANT:" in response:
            response = response.split("ASSISTANT:")[-1].strip()
        
        return response
        
    except Exception as e:
        return f"Error analyzing video: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="AI Dance Instructor") as demo:
    gr.Markdown("# 🕺 AI Dance Instructor")
    gr.Markdown("Upload your dance video along with pose scores for detailed feedback!")
    
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Dance Video"
            )
            pose_scores = gr.Textbox(
                label="MediaPipe Pose Scores",
                placeholder="Enter pose scores data from MediaPipe (e.g., 0.85)...",
                value="0.85",
                lines=3
            )
            music_info = gr.Textbox(
                label="Music Information",
                placeholder="BPM, genre, rhythm details...",
                value="120 BPM, Pop music",
                lines=3
            )
            analyze_btn = gr.Button("Analyze Dance", variant="primary")
        
        with gr.Column():
            feedback_output = gr.Textbox(
                label="Dance Feedback",
                lines=15,
                interactive=False
            )
    
    # Set up the analysis function (api_name keeps the endpoint name stable for API clients)
    analyze_btn.click(
        fn=analyze_dance_video,
        inputs=[video_input, pose_scores, music_info],
        outputs=[feedback_output],
        api_name="analyze_dance_video"
    )
    
    # Add API endpoint info
    gr.Markdown("### API Usage")
    gr.Markdown("""
    **For Next.js Integration, use the API endpoint:**
    ```javascript
    const formData = new FormData();
    formData.append('data', JSON.stringify([videoFile, poseScores, musicInfo]));
    
    const response = await fetch('https://your-space-name.hf.space/api/predict', {
      method: 'POST',
      body: formData
    });
    ```
    """)

if __name__ == "__main__":
    # Queue requests so long-running video inference doesn't hit request timeouts
    demo.queue().launch()