2nzi commited on
Commit
867f506
·
verified ·
1 Parent(s): 33a7b22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -129
app.py CHANGED
@@ -1,130 +1,130 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- import cv2
3
- import torch
4
- import pandas as pd
5
- from PIL import Image
6
- from transformers import AutoImageProcessor, AutoModelForImageClassification
7
- from tqdm import tqdm
8
- import json
9
- import shutil
10
- from fastapi.middleware.cors import CORSMiddleware
11
-
12
- app = FastAPI()
13
-
14
- # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
18
- allow_credentials=True,
19
- allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.)
20
- allow_headers=["*"], # Allows all headers (such as Content-Type, Authorization, etc.)
21
- )
22
-
23
- # Charger le processor et le modèle fine-tuné depuis le chemin local
24
- local_model_path = r'.\vit-finetuned-ucf101'
25
- processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
26
- model = AutoModelForImageClassification.from_pretrained(local_model_path)
27
- model.eval()
28
-
29
- # Fonction pour classifier une image
30
- def classifier_image(image):
31
- image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
32
- inputs = processor(images=image_pil, return_tensors="pt")
33
- with torch.no_grad():
34
- outputs = model(**inputs)
35
- logits = outputs.logits
36
- predicted_class_idx = logits.argmax(-1).item()
37
- predicted_class = model.config.id2label[predicted_class_idx]
38
- return predicted_class
39
-
40
- # Fonction pour traiter la vidéo et identifier les séquences de "Surfing"
41
- def identifier_sequences_surfing(video_path, intervalle=0.5):
42
- cap = cv2.VideoCapture(video_path)
43
- frame_rate = cap.get(cv2.CAP_PROP_FPS)
44
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
45
- frame_interval = int(frame_rate * intervalle)
46
-
47
- resultats = []
48
- sequences_surfing = []
49
- frame_index = 0
50
- in_surf_sequence = False
51
- start_timestamp = None
52
-
53
- with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
54
- success, frame = cap.read()
55
- while success:
56
- if frame_index % frame_interval == 0:
57
- timestamp = round(frame_index / frame_rate, 2) # Maintain precision to the centisecond level
58
- classe = classifier_image(frame)
59
- resultats.append({"Timestamp": timestamp, "Classe": classe})
60
-
61
- if classe == "Surfing" and not in_surf_sequence:
62
- in_surf_sequence = True
63
- start_timestamp = timestamp
64
-
65
- elif classe != "Surfing" and in_surf_sequence:
66
- # Vérifier l'image suivante pour confirmer si c'était une erreur ponctuelle
67
- success_next, frame_next = cap.read()
68
- next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
69
- classe_next = None
70
-
71
- if success_next:
72
- classe_next = classifier_image(frame_next)
73
- resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})
74
-
75
- # Si l'image suivante est "Surfing", on ignore l'erreur ponctuelle
76
- if classe_next == "Surfing":
77
- success = success_next
78
- frame = frame_next
79
- frame_index += frame_interval
80
- pbar.update(frame_interval)
81
- continue
82
- else:
83
- # Sinon, terminer la séquence "Surfing"
84
- in_surf_sequence = False
85
- end_timestamp = timestamp
86
- sequences_surfing.append((start_timestamp, end_timestamp))
87
-
88
- success, frame = cap.read()
89
- frame_index += 1
90
- pbar.update(1)
91
-
92
- # Si on est toujours dans une séquence "Surfing" à la fin de la vidéo
93
- if in_surf_sequence:
94
- sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))
95
-
96
- cap.release()
97
- dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
98
- return dataframe_sequences
99
-
100
- # Fonction pour convertir les séquences en format JSON
101
- def convertir_sequences_en_json(dataframe):
102
- events = []
103
- blocks = []
104
- for idx, row in dataframe.iterrows():
105
- block = {
106
- "id": f"Surfing{idx + 1}",
107
- "start": round(row["Début"], 2),
108
- "end": round(row["Fin"], 2)
109
- }
110
- blocks.append(block)
111
- event = {
112
- "event": "Surfing",
113
- "blocks": blocks
114
- }
115
- events.append(event)
116
- return events
117
-
118
- @app.post("/analyze_video/")
119
- async def analyze_video(file: UploadFile = File(...)):
120
- with open("uploaded_video.mp4", "wb") as buffer:
121
- shutil.copyfileobj(file.file, buffer)
122
-
123
- dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
124
- json_result = convertir_sequences_en_json(dataframe_sequences)
125
- return json_result
126
-
127
- # Lancer l'application avec uvicorn (command line)
128
- # uvicorn main:app --reload
129
- # http://localhost:8000/docs#/
130
  # (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ import cv2
3
+ import torch
4
+ import pandas as pd
5
+ from PIL import Image
6
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
7
+ from tqdm import tqdm
8
+ import json
9
+ import shutil
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+
12
+ app = FastAPI()
13
+
14
+ # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
18
+ allow_credentials=True,
19
+ allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.)
20
+ allow_headers=["*"], # Allows all headers (such as Content-Type, Authorization, etc.)
21
+ )
22
+
23
+ # Charger le processor et le modèle fine-tuné depuis le chemin local
24
+ local_model_path = r'./vit-finetuned-ucf101'
25
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
26
+ model = AutoModelForImageClassification.from_pretrained(local_model_path)
27
+ model.eval()
28
+
29
+ # Fonction pour classifier une image
30
+ def classifier_image(image):
31
+ image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
32
+ inputs = processor(images=image_pil, return_tensors="pt")
33
+ with torch.no_grad():
34
+ outputs = model(**inputs)
35
+ logits = outputs.logits
36
+ predicted_class_idx = logits.argmax(-1).item()
37
+ predicted_class = model.config.id2label[predicted_class_idx]
38
+ return predicted_class
39
+
40
+ # Fonction pour traiter la vidéo et identifier les séquences de "Surfing"
41
+ def identifier_sequences_surfing(video_path, intervalle=0.5):
42
+ cap = cv2.VideoCapture(video_path)
43
+ frame_rate = cap.get(cv2.CAP_PROP_FPS)
44
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
45
+ frame_interval = int(frame_rate * intervalle)
46
+
47
+ resultats = []
48
+ sequences_surfing = []
49
+ frame_index = 0
50
+ in_surf_sequence = False
51
+ start_timestamp = None
52
+
53
+ with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
54
+ success, frame = cap.read()
55
+ while success:
56
+ if frame_index % frame_interval == 0:
57
+ timestamp = round(frame_index / frame_rate, 2) # Maintain precision to the centisecond level
58
+ classe = classifier_image(frame)
59
+ resultats.append({"Timestamp": timestamp, "Classe": classe})
60
+
61
+ if classe == "Surfing" and not in_surf_sequence:
62
+ in_surf_sequence = True
63
+ start_timestamp = timestamp
64
+
65
+ elif classe != "Surfing" and in_surf_sequence:
66
+ # Vérifier l'image suivante pour confirmer si c'était une erreur ponctuelle
67
+ success_next, frame_next = cap.read()
68
+ next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
69
+ classe_next = None
70
+
71
+ if success_next:
72
+ classe_next = classifier_image(frame_next)
73
+ resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})
74
+
75
+ # Si l'image suivante est "Surfing", on ignore l'erreur ponctuelle
76
+ if classe_next == "Surfing":
77
+ success = success_next
78
+ frame = frame_next
79
+ frame_index += frame_interval
80
+ pbar.update(frame_interval)
81
+ continue
82
+ else:
83
+ # Sinon, terminer la séquence "Surfing"
84
+ in_surf_sequence = False
85
+ end_timestamp = timestamp
86
+ sequences_surfing.append((start_timestamp, end_timestamp))
87
+
88
+ success, frame = cap.read()
89
+ frame_index += 1
90
+ pbar.update(1)
91
+
92
+ # Si on est toujours dans une séquence "Surfing" à la fin de la vidéo
93
+ if in_surf_sequence:
94
+ sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))
95
+
96
+ cap.release()
97
+ dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
98
+ return dataframe_sequences
99
+
100
+ # Fonction pour convertir les séquences en format JSON
101
+ def convertir_sequences_en_json(dataframe):
102
+ events = []
103
+ blocks = []
104
+ for idx, row in dataframe.iterrows():
105
+ block = {
106
+ "id": f"Surfing{idx + 1}",
107
+ "start": round(row["Début"], 2),
108
+ "end": round(row["Fin"], 2)
109
+ }
110
+ blocks.append(block)
111
+ event = {
112
+ "event": "Surfing",
113
+ "blocks": blocks
114
+ }
115
+ events.append(event)
116
+ return events
117
+
118
+ @app.post("/analyze_video/")
119
+ async def analyze_video(file: UploadFile = File(...)):
120
+ with open("uploaded_video.mp4", "wb") as buffer:
121
+ shutil.copyfileobj(file.file, buffer)
122
+
123
+ dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
124
+ json_result = convertir_sequences_en_json(dataframe_sequences)
125
+ return json_result
126
+
127
+ # Lancer l'application avec uvicorn (command line)
128
+ # uvicorn main:app --reload
129
+ # http://localhost:8000/docs#/
130
  # (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1