Aditi755 committed
Commit 7a9faec · verified · 1 Parent(s): 9b3e431

Update app.py

Files changed (1)
  1. app.py +78 -26
app.py CHANGED

@@ -3,24 +3,69 @@ import torch
 from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
 import tempfile
 import os
+import cv2
+import numpy as np
 
 # Load model and processor
 model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
-processor = LlavaNextVideoProcessor.from_pretrained(model_id)
-model = LlavaNextVideoForConditionalGeneration.from_pretrained(
-    model_id,
-    torch_dtype=torch.float16,
-    device_map="auto"
-)
 
-def analyze_dance_video(video_file, pose_scores = 0.85, music_info="Unknown"):
+# Initialize processor and model with error handling
+try:
+    processor = LlavaNextVideoProcessor.from_pretrained(model_id)
+    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        low_cpu_mem_usage=True
+    )
+except Exception as e:
+    print(f"Error loading model: {e}")
+    processor = None
+    model = None
+
+def process_video_file(video_path):
+    """Convert video file to the format expected by the model"""
+    try:
+        # Read video using OpenCV
+        cap = cv2.VideoCapture(video_path)
+        frames = []
+
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            # Convert BGR to RGB
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frames.append(frame)
+
+        cap.release()
+
+        # Convert to numpy array and normalize
+        video_frames = np.array(frames)
+        return video_frames
+
+    except Exception as e:
+        print(f"Error processing video: {e}")
+        return None
+
+def analyze_dance_video(video_file, pose_scores="0.85", music_info="Unknown"):
     """
     Analyze dance video with pose scores and music information
     """
+    if model is None or processor is None:
+        return "Error: Model not loaded properly. Please check the logs."
+
+    if video_file is None:
+        return "Please upload a video file."
+
     try:
+        # Process the video file
+        video_frames = process_video_file(video_file)
+        if video_frames is None:
+            return "Error: Could not process video file."
+
         # Prepare the prompt
-        prompt = f"""
-You are an expert dance instructor. Analyze this dance performance video.
+        prompt = f"""USER: You are an expert dance instructor. Analyze this dance performance video.
 
 Additional Data:
 - MediaPipe Pose Scores: {pose_scores}
@@ -37,16 +82,20 @@ def analyze_dance_video(video_file, pose_scores = 0.85, music_info="Unknown"):
 7. The accuracy and technique of the dancer's poses, considering the pose scores.
 8. The fluidity and smoothness of transitions between moves.
 9. Specific areas where the dancer can improve.
+
 
 Give constructive feedback in a friendly, encouraging tone.
-        """
+ASSISTANT:"""
 
-        # Process video
+        # Process video with the model
         inputs = processor(
            text=prompt,
-            videos=video_file,
+            videos=[video_frames],  # Note: videos expects a list
            return_tensors="pt"
-        ).to(model.device, torch.float16)
+        )
+
+        # Move inputs to the same device as model
+        inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
 
         # Generate response
         with torch.no_grad():
@@ -55,14 +104,16 @@ def analyze_dance_video(video_file, pose_scores = 0.85, music_info="Unknown"):
                max_new_tokens=500,
                do_sample=True,
                temperature=0.7,
-                pad_token_id=processor.tokenizer.eos_token_id
+                pad_token_id=processor.tokenizer.eos_token_id,
+                eos_token_id=processor.tokenizer.eos_token_id
            )
 
         # Decode response
         response = processor.decode(output[0], skip_special_tokens=True)
 
-        # Extract just the generated part (after the prompt)
-        response = response.split("Give constructive feedback in a friendly, encouraging tone.")[-1].strip()
+        # Extract just the generated part (after ASSISTANT:)
+        if "ASSISTANT:" in response:
+            response = response.split("ASSISTANT:")[-1].strip()
 
         return response
 
@@ -77,17 +128,18 @@ with gr.Blocks(title="AI Dance Instructor") as demo:
     with gr.Row():
         with gr.Column():
             video_input = gr.Video(
-                label="Upload Dance Video",
-                format="mp4"
+                label="Upload Dance Video"
             )
             pose_scores = gr.Textbox(
                label="MediaPipe Pose Scores",
-                placeholder="Enter pose scores data from MediaPipe...",
-                lines=5
+                placeholder="Enter pose scores data from MediaPipe (e.g., 0.85)...",
+                value="0.85",
+                lines=3
             )
             music_info = gr.Textbox(
                label="Music Information",
                placeholder="BPM, genre, rhythm details...",
+                value="120 BPM, Pop music",
                lines=3
             )
             analyze_btn = gr.Button("Analyze Dance", variant="primary")
@@ -106,17 +158,17 @@ with gr.Blocks(title="AI Dance Instructor") as demo:
         outputs=[feedback_output]
     )
 
-    # Add API endpoint for Next.js integration
+    # Add API endpoint info
     gr.Markdown("### API Usage")
     gr.Markdown("""
-    **For Next.js Integration:**
+    **For Next.js Integration, use the API endpoint:**
     ```javascript
+    const formData = new FormData();
+    formData.append('data', JSON.stringify([videoFile, poseScores, musicInfo]));
+
    const response = await fetch('https://your-space-name.hf.space/api/predict', {
        method: 'POST',
-        headers: { 'Content-Type': 'application/json' },
-        body: JSON.stringify({
-            data: [videoFile, poseScores, musicInfo]
-        })
+        body: formData
    });
    ```
    """)