2nzi committed
Commit f137487 · verified · 1 parent: a779f32

Update app.py

Files changed (1)
  1. app.py +149 -133
app.py CHANGED
@@ -1,135 +1,151 @@
- from fastapi import FastAPI, UploadFile, File
- import cv2
- import torch
- import pandas as pd
- from PIL import Image
- from transformers import AutoImageProcessor, AutoModelForImageClassification
- from tqdm import tqdm
- import json
- import shutil
- from fastapi.middleware.cors import CORSMiddleware
-
- app = FastAPI()
-
- # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["http://localhost:8080"],  # Replace with the URL of your Vue.js app
-     allow_credentials=True,
-     allow_methods=["*"],  # Allows all HTTP methods (GET, POST, etc.)
-     allow_headers=["*"],  # Allows all headers (such as Content-Type, Authorization, etc.)
- )
-
- # Load the processor and the fine-tuned model from the local path
- local_model_path = r'./vit-finetuned-ucf101'
- processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
- model = AutoModelForImageClassification.from_pretrained(local_model_path)
- model.eval()
-
- # Function to classify an image
- def classifier_image(image):
-     image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
-     inputs = processor(images=image_pil, return_tensors="pt")
-     with torch.no_grad():
-         outputs = model(**inputs)
-     logits = outputs.logits
-     predicted_class_idx = logits.argmax(-1).item()
-     predicted_class = model.config.id2label[predicted_class_idx]
-     return predicted_class
-
- # Function to process the video and identify "Surfing" sequences
- def identifier_sequences_surfing(video_path, intervalle=0.5):
-     cap = cv2.VideoCapture(video_path)
-     frame_rate = cap.get(cv2.CAP_PROP_FPS)
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-     frame_interval = int(frame_rate * intervalle)
-
-     resultats = []
-     sequences_surfing = []
-     frame_index = 0
-     in_surf_sequence = False
-     start_timestamp = None
-
-     with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
-         success, frame = cap.read()
-         while success:
-             if frame_index % frame_interval == 0:
-                 timestamp = round(frame_index / frame_rate, 2)  # Maintain precision to the centisecond level
-                 classe = classifier_image(frame)
-                 resultats.append({"Timestamp": timestamp, "Classe": classe})
-
-                 if classe == "Surfing" and not in_surf_sequence:
-                     in_surf_sequence = True
-                     start_timestamp = timestamp
-
-                 elif classe != "Surfing" and in_surf_sequence:
-                     # Check the next frame to confirm whether this was a one-off error
-                     success_next, frame_next = cap.read()
-                     next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
-                     classe_next = None
-
-                     if success_next:
-                         classe_next = classifier_image(frame_next)
-                         resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})
-
-                     # If the next frame is "Surfing", ignore the one-off error
-                     if classe_next == "Surfing":
-                         success = success_next
-                         frame = frame_next
-                         frame_index += frame_interval
-                         pbar.update(frame_interval)
-                         continue
-                     else:
-                         # Otherwise, end the "Surfing" sequence
-                         in_surf_sequence = False
-                         end_timestamp = timestamp
-                         sequences_surfing.append((start_timestamp, end_timestamp))
-
              success, frame = cap.read()
-             frame_index += 1
-             pbar.update(1)
-
-     # If we are still in a "Surfing" sequence at the end of the video
-     if in_surf_sequence:
-         sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))
-
-     cap.release()
-     dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
-     return dataframe_sequences
-
- # Function to convert the sequences to JSON format
- def convertir_sequences_en_json(dataframe):
-     events = []
-     blocks = []
-     for idx, row in dataframe.iterrows():
-         block = {
-             "id": f"Surfing{idx + 1}",
-             "start": round(row["Début"], 2),
-             "end": round(row["Fin"], 2)
          }
-         blocks.append(block)
-     event = {
-         "event": "Surfing",
-         "blocks": blocks
-     }
-     events.append(event)
-     return events
-
- @app.post("/analyze_video/")
- async def analyze_video(file: UploadFile = File(...)):
-     with open("uploaded_video.mp4", "wb") as buffer:
-         shutil.copyfileobj(file.file, buffer)
-
-     dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
-     json_result = convertir_sequences_en_json(dataframe_sequences)
-     return json_result
-
-
- @app.get("/", response_class=HTMLResponse, tags=["Introduction Endpoints"])
- async def index():
-     return (
-         "Hello world! This `/` is the most simple and default endpoint. "
-         "If you want to learn more, check out documentation of the API at "
-         "<a href='/docs'>/docs</a> or "
-         "<a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>."
-     )
+ from fastapi import FastAPI, UploadFile, File
+ import cv2
+ import torch
+ import pandas as pd
+ from PIL import Image
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
+ from tqdm import tqdm
+ import json
+ import shutil
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import HTMLResponse
+
+ app = FastAPI()
+
+ # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
+ app.add_middleware(
+     CORSMiddleware,
+     # allow_origins=["http://localhost:8080"],  # Replace with the URL of your Vue.js app
+     allow_origins=["http://localhost:8080"],  # Replace with the URL of your Vue.js app
+     allow_credentials=True,
+     allow_methods=["*"],  # Allows all HTTP methods (GET, POST, etc.)
+     allow_headers=["*"],  # Allows all headers (such as Content-Type, Authorization, etc.)
+ )
+
+ # Load the processor and the fine-tuned model from the local path
+ local_model_path = r'.\vit-finetuned-ucf101'
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
+ model = AutoModelForImageClassification.from_pretrained(local_model_path)
+ # model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101")
+ model.eval()
+
+ # Function to classify an image
+ def classifier_image(image):
+     image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+     inputs = processor(images=image_pil, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits = outputs.logits
+     predicted_class_idx = logits.argmax(-1).item()
+     predicted_class = model.config.id2label[predicted_class_idx]
+     return predicted_class
+
+ # Function to process the video and identify "Surfing" sequences
+ def identifier_sequences_surfing(video_path, intervalle=0.5):
+     cap = cv2.VideoCapture(video_path)
+     frame_rate = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     frame_interval = int(frame_rate * intervalle)
+
+     resultats = []
+     sequences_surfing = []
+     frame_index = 0
+     in_surf_sequence = False
+     start_timestamp = None
+
+     with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
          success, frame = cap.read()
+         while success:
+             if frame_index % frame_interval == 0:
+                 timestamp = round(frame_index / frame_rate, 2)  # Maintain precision to the centisecond level
+                 classe = classifier_image(frame)
+                 resultats.append({"Timestamp": timestamp, "Classe": classe})
+
+                 if classe == "Surfing" and not in_surf_sequence:
+                     in_surf_sequence = True
+                     start_timestamp = timestamp
+
+                 elif classe != "Surfing" and in_surf_sequence:
+                     # Check the next frame to confirm whether this was a one-off error
+                     success_next, frame_next = cap.read()
+                     next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
+                     classe_next = None
+
+                     if success_next:
+                         classe_next = classifier_image(frame_next)
+                         resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})
+
+                     # If the next frame is "Surfing", ignore the one-off error
+                     if classe_next == "Surfing":
+                         success = success_next
+                         frame = frame_next
+                         frame_index += frame_interval
+                         pbar.update(frame_interval)
+                         continue
+                     else:
+                         # Otherwise, end the "Surfing" sequence
+                         in_surf_sequence = False
+                         end_timestamp = timestamp
+                         sequences_surfing.append((start_timestamp, end_timestamp))
+
+             success, frame = cap.read()
+             frame_index += 1
+             pbar.update(1)
+
+     # If we are still in a "Surfing" sequence at the end of the video
+     if in_surf_sequence:
+         sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))
+
+     cap.release()
+     dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
+     return dataframe_sequences
+
+ # Function to convert the sequences to JSON format
+ def convertir_sequences_en_json(dataframe):
+     events = []
+     blocks = []
+     for idx, row in dataframe.iterrows():
+         block = {
+             "id": f"Surfing{idx + 1}",
+             "start": round(row["Début"], 2),
+             "end": round(row["Fin"], 2)
+         }
+         blocks.append(block)
+     event = {
+         "event": "Surfing",
+         "blocks": blocks
      }
+     events.append(event)
+     return events
+
+ @app.post("/analyze_video/")
+ async def analyze_video(file: UploadFile = File(...)):
+     with open("uploaded_video.mp4", "wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+
+     dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
+     json_result = convertir_sequences_en_json(dataframe_sequences)
+     return json_result
+
+ @app.get("/", response_class=HTMLResponse)
+ async def index():
+     return (
+         """
+         <html>
+             <body>
+                 <h1>Hello world!</h1>
+                 <p>This `/` is the most simple and default endpoint.</p>
+                 <p>If you want to learn more, check out the documentation of the API at
+                     <a href='/docs'>/docs</a> or
+                     <a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
+                 </p>
+             </body>
+         </html>
+         """
+     )
+
+
+ # Launch the application with uvicorn (command line):
+ # uvicorn main:app --reload
+ # http://localhost:8000/docs#/
+ # (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
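
For reference, a minimal client sketch for exercising the `/analyze_video/` endpoint after this change. It assumes the server is running locally on port 8000 as in the launch comments above; `surf_session.mp4` is a hypothetical input file, and the multipart field name `file` matches the `analyze_video` parameter:

import requests

# POST a video to the analysis endpoint (hypothetical local test).
url = "http://localhost:8000/analyze_video/"
with open("surf_session.mp4", "rb") as f:
    response = requests.post(url, files={"file": ("surf_session.mp4", f, "video/mp4")})

# Expected response shape, per convertir_sequences_en_json:
# [{"event": "Surfing", "blocks": [{"id": "Surfing1", "start": ..., "end": ...}, ...]}]
print(response.json())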