import gradio as gr
import torch
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import tempfile
import os
import cv2
import numpy as np
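# OpenCV and NumPy are used to decode the uploaded clip into the RGB frame arrays
# that are handed to the video-language model's processor further below.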
# Load model and processor
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
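# Note: the 7B checkpoint weighs roughly 14 GB in float16, so device_map="auto" below
# places it on whatever accelerator(s) are available and offloads the rest if needed.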
# Initialize processor and model with error handling
try:
    processor = LlavaNextVideoProcessor.from_pretrained(model_id)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True
    )
except Exception as e:
    print(f"Error loading model: {e}")
    processor = None
    model = None
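# Leaving both as None lets the Space still build its UI; analyze_dance_video reports
# the load failure to the user instead of crashing at startup.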
def process_video_file(video_path, num_frames=8):
    """Decode the video and uniformly sample frames in the format expected by the model"""
    try:
        # Read video using OpenCV
        cap = cv2.VideoCapture(video_path)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR (OpenCV default) to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
        cap.release()
        if not frames:
            return None
        # Uniformly sample a small, fixed number of frames; passing every frame of a
        # longer clip would blow up the visual token count and memory use.
        indices = np.linspace(0, len(frames) - 1, num=min(num_frames, len(frames)), dtype=int)
        video_frames = np.array([frames[i] for i in indices])
        return video_frames
    except Exception as e:
        print(f"Error processing video: {e}")
        return None
def analyze_dance_video(video_file, pose_scores="0.85", music_info="Unknown"):
    """
    Analyze dance video with pose scores and music information
    """
    if model is None or processor is None:
        return "Error: Model not loaded properly. Please check the logs."
    if video_file is None:
        return "Please upload a video file."
    try:
        # Process the video file
        video_frames = process_video_file(video_file)
        if video_frames is None:
            return "Error: Could not process video file."
prompt = f"""USER: You are an expert dance instructor. Analyze this dance performance video.
Additional Data:
- MediaPipe Pose Scores: {pose_scores}
- Music Information: {music_info}
- When analyzing, combine what you see in the video with the pose scores and the music details to provide precise, realistic feedback.
Please provide detailed feedback on:
1. Timing and synchronization with music
2. Pose accuracy and technique
3. Movement flow and transitions
4. Areas for improvement
5. Overall performance rating (1-10)
6. How well the dancer's moves are synchronized with the music's tempo and beat.
7. The accuracy and technique of the dancer's poses, considering the pose scores.
8. The fluidity and smoothness of transitions between moves.
9. Specific areas where the dancer can improve.
Give constructive feedback in a friendly, encouraging tone.
ASSISTANT:"""
        # Process video with the model
        inputs = processor(
            text=prompt,
            videos=[video_frames],  # Note: videos expects a list
            return_tensors="pt"
        )
        # Move inputs to the same device as the model
        inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
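        # Sampling (do_sample=True, temperature=0.7) makes repeated analyses vary slightly;
        # max_new_tokens=500 caps the length of the critique.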
        # Generate response
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=True,
                temperature=0.7,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id
            )
        # Decode response
        response = processor.decode(output[0], skip_special_tokens=True)
        # Extract just the generated part (after ASSISTANT:)
        if "ASSISTANT:" in response:
            response = response.split("ASSISTANT:")[-1].strip()
        return response
    except Exception as e:
        return f"Error analyzing video: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="AI Dance Instructor") as demo:
    gr.Markdown("# 🕺 AI Dance Instructor")
    gr.Markdown("Upload your dance video along with pose scores for detailed feedback!")
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Dance Video"
            )
            pose_scores = gr.Textbox(
                label="MediaPipe Pose Scores",
                placeholder="Enter pose scores data from MediaPipe (e.g., 0.85)...",
                value="0.85",
                lines=3
            )
            music_info = gr.Textbox(
                label="Music Information",
                placeholder="BPM, genre, rhythm details...",
                value="120 BPM, Pop music",
                lines=3
            )
            analyze_btn = gr.Button("Analyze Dance", variant="primary")
        with gr.Column():
            feedback_output = gr.Textbox(
                label="Dance Feedback",
                lines=15,
                interactive=False
            )
    # Set up the analysis function
    analyze_btn.click(
        fn=analyze_dance_video,
        inputs=[video_input, pose_scores, music_info],
        outputs=[feedback_output]
    )
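    # The click event above is also what a hosted copy of this Space exposes over its
    # HTTP API (see the usage notes rendered below).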
    # Add API endpoint info
    gr.Markdown("### API Usage")
    gr.Markdown("""
**For Next.js Integration, use the API endpoint:**
```javascript
const formData = new FormData();
formData.append('data', JSON.stringify([videoFile, poseScores, musicInfo]));
const response = await fetch('https://your-space-name.hf.space/api/predict', {
    method: 'POST',
    body: formData
});
```
""")
if __name__ == "__main__":
    demo.launch()
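# A minimal client-side sketch (not executed here), assuming the `gradio_client`
# package and a deployed copy of this Space; the Space id is a placeholder, and the
# api_name assumes Gradio's default of naming the endpoint after the wired function:
#
#   from gradio_client import Client, handle_file
#   client = Client("your-username/your-space-name")
#   feedback = client.predict(
#       handle_file("dance.mp4"),      # video_input
#       "0.85",                        # pose_scores
#       "120 BPM, Pop music",          # music_info
#       api_name="/analyze_dance_video",
#   )
#   print(feedback)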