# Aditi755's picture
# Update app.py
# 7a9faec verified
import gradio as gr
import torch
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import tempfile
import os
import cv2
import numpy as np
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"

# Begin in the "not loaded" state; the names are only rebound when loading
# succeeds. Any failure is printed and leaves both as None so the request
# handler can report the problem instead of the app crashing at import time.
processor = model = None
try:
    processor = LlavaNextVideoProcessor.from_pretrained(model_id)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
except Exception as e:
    print(f"Error loading model: {e}")
    # The processor may have loaded before the model failed; reset both so
    # they are always either both available or both None.
    processor = model = None
def process_video_file(video_path):
    """Decode a video file into a stacked array of RGB frames.

    Args:
        video_path: Path of the video file to open with OpenCV.

    Returns:
        A NumPy array built from the decoded frames (frames along the first
        axis, channels converted from OpenCV's BGR order to RGB), or ``None``
        when the file cannot be opened, yields no frames, or decoding raises.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        # VideoCapture does not raise for a missing or unsupported file; it
        # just reports "not opened", so check explicitly.
        if not cap.isOpened():
            print(f"Error processing video: could not open {video_path}")
            return None

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; the model pipeline expects RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # An empty result is a failure: return None so the caller's
        # `is None` check catches it rather than passing an empty array on.
        if not frames:
            print(f"Error processing video: no frames decoded from {video_path}")
            return None

        # Stack the frames into one array (no value normalization is done).
        return np.array(frames)
    except Exception as e:
        print(f"Error processing video: {e}")
        return None
    finally:
        # Always free the capture handle, including on the error paths.
        if cap is not None:
            cap.release()
def _sample_frames(video_frames, max_frames):
    """Return at most ``max_frames`` frames picked uniformly across the clip."""
    total = len(video_frames)
    if total <= max_frames:
        return video_frames
    # Evenly spaced indices from the first to the last frame (inclusive).
    indices = np.linspace(0, total - 1, num=max_frames).round().astype(int)
    return np.asarray(video_frames)[indices]


def analyze_dance_video(video_file, pose_scores="0.85", music_info="Unknown", max_frames=8):
    """Generate dance feedback for a video using the loaded video-language model.

    Args:
        video_file: Path to the uploaded video, or ``None`` if nothing was
            uploaded.
        pose_scores: Free-text pose score data inserted into the prompt.
        music_info: Free-text music description inserted into the prompt.
        max_frames: Upper bound on the number of frames handed to the model.
            Longer clips are sampled uniformly down to this many frames;
            values below 1 are treated as 1.

    Returns:
        The model's feedback text, or a human-readable error/help message
        (this function reports failures as strings rather than raising).
    """
    if model is None or processor is None:
        return "Error: Model not loaded properly. Please check the logs."
    if video_file is None:
        return "Please upload a video file."
    try:
        # Decode the upload; both None and an empty frame set mean failure.
        video_frames = process_video_file(video_file)
        if video_frames is None or len(video_frames) == 0:
            return "Error: Could not process video file."

        # Passing every decoded frame makes the visual input grow with clip
        # length; keep a small, uniformly spaced subset instead.
        video_frames = _sample_frames(video_frames, max(1, int(max_frames)))

        # The `<video>` placeholder marks where the processor attaches the
        # frames; without it the video is not bound to the text prompt.
        prompt = f"""USER: <video>
You are an expert dance instructor. Analyze this dance performance video.
Additional Data:
- MediaPipe Pose Scores: {pose_scores}
- Music Information: {music_info}
- When analyzing, combine what you see in the video with the pose scores and the music details to provide precise, realistic feedback.
Please provide detailed feedback on:
1. Timing and synchronization with music
2. Pose accuracy and technique
3. Movement flow and transitions
4. Areas for improvement
5. Overall performance rating (1-10)
6. How well the dancer's moves are synchronized with the music's tempo and beat.
7. The accuracy and technique of the dancer's poses, considering the pose scores.
8. The fluidity and smoothness of transitions between moves.
9. Specific areas where the dancer can improve.
Give constructive feedback in a friendly, encouraging tone.
ASSISTANT:"""

        # Tokenize the prompt and preprocess the frames together.
        inputs = processor(
            text=prompt,
            videos=[video_frames],  # Note: videos expects a list
            return_tensors="pt"
        )
        # Move tensor inputs to the model's device; leave non-tensors as-is.
        inputs = {
            k: v.to(model.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        # Sampling-based generation; no gradients are needed for inference.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=True,
                temperature=0.7,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id
            )

        response = processor.decode(output[0], skip_special_tokens=True)
        # The decoded text includes the prompt; keep only the model's answer.
        if "ASSISTANT:" in response:
            response = response.split("ASSISTANT:")[-1].strip()
        return response
    except Exception as e:
        # Report the failure as text so the UI shows a message, not a crash.
        return f"Error analyzing video: {str(e)}"
# Build the Gradio UI: inputs on the left, model feedback on the right.
with gr.Blocks(title="AI Dance Instructor") as demo:
    # The emoji in the heading was mis-encoded (mojibake); restored to 🕺.
    gr.Markdown("# 🕺 AI Dance Instructor")
    gr.Markdown("Upload your dance video along with pose scores for detailed feedback!")
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Dance Video"
            )
            pose_scores = gr.Textbox(
                label="MediaPipe Pose Scores",
                placeholder="Enter pose scores data from MediaPipe (e.g., 0.85)...",
                value="0.85",
                lines=3
            )
            music_info = gr.Textbox(
                label="Music Information",
                placeholder="BPM, genre, rhythm details...",
                value="120 BPM, Pop music",
                lines=3
            )
            analyze_btn = gr.Button("Analyze Dance", variant="primary")
        with gr.Column():
            feedback_output = gr.Textbox(
                label="Dance Feedback",
                lines=15,
                interactive=False
            )
    # Clicking the button runs the analysis on the three inputs and writes
    # the returned text into the feedback box.
    analyze_btn.click(
        fn=analyze_dance_video,
        inputs=[video_input, pose_scores, music_info],
        outputs=[feedback_output]
    )
    # Static help text shown below the form for API consumers.
    gr.Markdown("### API Usage")
    gr.Markdown("""
**For Next.js Integration, use the API endpoint:**
```javascript
const formData = new FormData();
formData.append('data', JSON.stringify([videoFile, poseScores, musicInfo]));
const response = await fetch('https://your-space-name.hf.space/api/predict', {
method: 'POST',
body: formData
});
```
""")

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()