2nzi committed (verified)
Commit b0dcef4 · 1 Parent(s): f137487

Update app.py

Files changed (1): app.py (+148 −148)
app.py CHANGED
from fastapi import FastAPI, UploadFile, File
import cv2
import torch
import pandas as pd
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from tqdm import tqdm
import json
import shutil
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse

app = FastAPI()

# Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8080"],  # Replace with the URL of your Vue.js app
    allow_credentials=True,
    allow_methods=["*"],   # Allow all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],   # Allow all headers (such as Content-Type, Authorization, etc.)
)

# Load the processor and the fine-tuned model from the local path
local_model_path = r'.\vit-finetuned-ucf101'
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained(local_model_path)
# model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101")
model.eval()

# Classify a single frame and return the predicted label
def classifier_image(image):
    image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    inputs = processor(images=image_pil, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = model.config.id2label[predicted_class_idx]
    return predicted_class

# Process the video and identify the "Surfing" sequences
def identifier_sequences_surfing(video_path, intervalle=0.5):
    cap = cv2.VideoCapture(video_path)
    frame_rate = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against a zero interval when frame_rate * intervalle < 1
    frame_interval = max(1, int(frame_rate * intervalle))

    resultats = []
    sequences_surfing = []
    frame_index = 0
    in_surf_sequence = False
    start_timestamp = None

    with tqdm(total=total_frames, desc="Processing video frames", unit="frame") as pbar:
        success, frame = cap.read()
        while success:
            if frame_index % frame_interval == 0:
                timestamp = round(frame_index / frame_rate, 2)  # keep centisecond precision
                classe = classifier_image(frame)
                resultats.append({"Timestamp": timestamp, "Classe": classe})

                if classe == "Surfing" and not in_surf_sequence:
                    in_surf_sequence = True
                    start_timestamp = timestamp

                elif classe != "Surfing" and in_surf_sequence:
                    # Check the next frame to confirm whether this was a one-off misclassification
                    success_next, frame_next = cap.read()
                    next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
                    classe_next = None

                    if success_next:
                        classe_next = classifier_image(frame_next)
                        resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})

                    # If the next frame is "Surfing" again, ignore the one-off error
                    if classe_next == "Surfing":
                        success = success_next
                        frame = frame_next
                        frame_index += frame_interval
                        pbar.update(frame_interval)
                        continue
                    else:
                        # Otherwise, close the "Surfing" sequence
                        in_surf_sequence = False
                        end_timestamp = timestamp
                        sequences_surfing.append((start_timestamp, end_timestamp))

            success, frame = cap.read()
            frame_index += 1
            pbar.update(1)

    # If we are still inside a "Surfing" sequence when the video ends
    if in_surf_sequence:
        sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))

    cap.release()
    dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
    return dataframe_sequences

# Convert the detected sequences into a JSON-serializable structure
def convertir_sequences_en_json(dataframe):
    events = []
    blocks = []
    for idx, row in dataframe.iterrows():
        block = {
            "id": f"Surfing{idx + 1}",
            "start": round(row["Début"], 2),
            "end": round(row["Fin"], 2)
        }
        blocks.append(block)
    # Build a single "Surfing" event holding all blocks (appending the event
    # inside the loop would emit one cumulative copy per row)
    event = {
        "event": "Surfing",
        "blocks": blocks
    }
    events.append(event)
    return events

@app.post("/analyze_video/")
async def analyze_video(file: UploadFile = File(...)):
    with open("uploaded_video.mp4", "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
    json_result = convertir_sequences_en_json(dataframe_sequences)
    return json_result

@app.get("/", response_class=HTMLResponse)
async def index():
    return (
        """
        <html>
            <body>
                <h1>Hello world!</h1>
                <p>This `/` is the most simple and default endpoint.</p>
                <p>If you want to learn more, check out the documentation of the API at
                    <a href='/docs'>/docs</a> or
                    <a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
                </p>
            </body>
        </html>
        """
    )

# Run the application with uvicorn (command line):
# uvicorn main:app --reload
# http://localhost:8000/docs#/
# (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
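For a quick end-to-end check of the /analyze_video/ endpoint, a short client script like the one below can be used. This is a minimal sketch, not part of the committed file: the local URL and the sample.mp4 filename are assumptions.

# Hypothetical client for /analyze_video/ -- the URL and "sample.mp4" are
# assumptions, not part of the committed app.py.
import requests

with open("sample.mp4", "rb") as f:
    response = requests.post(
        "http://localhost:8000/analyze_video/",
        # The multipart field name must be "file" to match the UploadFile
        # parameter of analyze_video
        files={"file": ("sample.mp4", f, "video/mp4")},
    )

response.raise_for_status()
# Expected shape, following convertir_sequences_en_json:
# [{"event": "Surfing", "blocks": [{"id": "Surfing1", "start": 0.0, "end": 4.5}, ...]}]
print(response.json())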