import gradio as gr
import torch
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import tempfile
import os
import cv2
import numpy as np
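# Note: as a rough assumption, a Space running this app would also need a
# requirements.txt along the lines of: gradio, torch, transformers, accelerate,
# opencv-python-headless, numpy (accelerate is what backs device_map="auto" below).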
# Load model and processor
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"

# Initialize processor and model with error handling
try:
    processor = LlavaNextVideoProcessor.from_pretrained(model_id)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
except Exception as e:
    print(f"Error loading model: {e}")
    processor = None
    model = None
def process_video_file(video_path, max_frames=8):
    """Read a video file and return a uniformly sampled stack of RGB frames.

    LLaVA-NeXT-Video is normally fed a small, fixed number of frames (the
    Hugging Face examples use 8); passing every frame of a real video would
    blow up memory and context length.
    """
    try:
        # Read video using OpenCV
        cap = cv2.VideoCapture(video_path)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR (OpenCV default) to RGB
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()

        if not frames:
            return None

        # Uniformly sample up to max_frames frames across the clip
        indices = np.linspace(0, len(frames) - 1, min(max_frames, len(frames))).astype(int)
        video_frames = np.stack([frames[i] for i in indices])
        return video_frames
    except Exception as e:
        print(f"Error processing video: {e}")
        return None
def analyze_dance_video(video_file, pose_scores="0.85", music_info="Unknown"):
    """
    Analyze a dance video using the pose scores and music information supplied by the user.
    """
    if model is None or processor is None:
        return "Error: Model not loaded properly. Please check the logs."
    if video_file is None:
        return "Please upload a video file."

    try:
        # Process the video file
        video_frames = process_video_file(video_file)
        if video_frames is None:
            return "Error: Could not process video file."

        # Prepare the prompt; the <video> placeholder marks where the processor
        # inserts the video tokens.
        prompt = f"""USER: <video>
You are an expert dance instructor. Analyze this dance performance video.

Additional data:
- MediaPipe pose scores: {pose_scores}
- Music information: {music_info}

Combine what you see in the video with the pose scores and the music details to give precise, realistic feedback. Please cover:
1. Timing and synchronization with the music's tempo and beat
2. Pose accuracy and technique, taking the pose scores into account
3. Flow and smoothness of transitions between moves
4. Specific areas for improvement
5. Overall performance rating (1-10)

Give constructive feedback in a friendly, encouraging tone.
ASSISTANT:"""
        # Tokenize the prompt and preprocess the video frames together
        inputs = processor(
            text=prompt,
            videos=[video_frames],  # the processor expects a list of videos
            return_tensors="pt",
        )

        # Move tensor inputs to the same device as the model
        inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        # Generate the response
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=True,
                temperature=0.7,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
            )

        # Decode and keep only the generated part (after "ASSISTANT:")
        response = processor.decode(output[0], skip_special_tokens=True)
        if "ASSISTANT:" in response:
            response = response.split("ASSISTANT:")[-1].strip()
        return response

    except Exception as e:
        return f"Error analyzing video: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="AI Dance Instructor") as demo:
    gr.Markdown("# 🕺 AI Dance Instructor")
    gr.Markdown("Upload your dance video along with pose scores for detailed feedback!")

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Dance Video")
            pose_scores = gr.Textbox(
                label="MediaPipe Pose Scores",
                placeholder="Enter pose score data from MediaPipe (e.g., 0.85)...",
                value="0.85",
                lines=3,
            )
            music_info = gr.Textbox(
                label="Music Information",
                placeholder="BPM, genre, rhythm details...",
                value="120 BPM, Pop music",
                lines=3,
            )
            analyze_btn = gr.Button("Analyze Dance", variant="primary")
        with gr.Column():
            feedback_output = gr.Textbox(
                label="Dance Feedback",
                lines=15,
                interactive=False,
            )

    # Wire the button to the analysis function
    analyze_btn.click(
        fn=analyze_dance_video,
        inputs=[video_input, pose_scores, music_info],
        outputs=[feedback_output],
    )
    # Add API usage notes
    gr.Markdown("### API Usage")
    gr.Markdown("""
**For Next.js integration, call the Space's prediction endpoint.** The exact route and payload shape depend on your Gradio version, so confirm them on the Space's "Use via API" page. A legacy-style REST call looks roughly like this (the video must be sent as a base64 data URL):
```javascript
const response = await fetch('https://your-space-name.hf.space/api/predict', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ data: [videoAsDataUrl, poseScores, musicInfo] })
});
const result = await response.json();
```
""")
if __name__ == "__main__":
    demo.launch()