2nzi commited on
Commit
53c7c09
·
verified ·
1 Parent(s): 31216a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -4
app.py CHANGED
@@ -1,3 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, UploadFile, File
2
  import cv2
3
  import torch
@@ -118,14 +271,25 @@ def convertir_sequences_en_json(dataframe):
118
  events.append(event)
119
  return events
120
 
 
 
 
 
121
  @app.post("/analyze_video/")
122
  async def analyze_video(file: UploadFile = File(...)):
123
- with open("uploaded_video.mp4", "wb") as buffer:
124
- shutil.copyfileobj(file.file, buffer)
 
 
125
 
126
- dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
 
127
  json_result = convertir_sequences_en_json(dataframe_sequences)
128
- return json_result
 
 
 
 
129
 
130
  @app.get("/", response_class=HTMLResponse)
131
  async def index():
 
1
+ # from fastapi import FastAPI, UploadFile, File
2
+ # import cv2
3
+ # import torch
4
+ # import pandas as pd
5
+ # from PIL import Image
6
+ # from transformers import AutoImageProcessor, AutoModelForImageClassification
7
+ # from tqdm import tqdm
8
+ # import json
9
+ # import shutil
10
+ # from fastapi.middleware.cors import CORSMiddleware
11
+ # from fastapi.responses import HTMLResponse
12
+
13
+ # app = FastAPI()
14
+
15
+ # # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
16
+ # app.add_middleware(
17
+ # CORSMiddleware,
18
+ # # allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
19
+ # allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
20
+ # allow_credentials=True,
21
+ # allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.)
22
+ # allow_headers=["*"], # Allows all headers (such as Content-Type, Authorization, etc.)
23
+ # )
24
+
25
+ # # Charger le processor et le modèle fine-tuné depuis le chemin local
26
+ # local_model_path = r'./vit-finetuned-ucf101'
27
+ # processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
28
+ # model = AutoModelForImageClassification.from_pretrained(local_model_path)
29
+ # # model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101")
30
+ # model.eval()
31
+
32
+ # # Fonction pour classifier une image
33
+ # def classifier_image(image):
34
+ # image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
35
+ # inputs = processor(images=image_pil, return_tensors="pt")
36
+ # with torch.no_grad():
37
+ # outputs = model(**inputs)
38
+ # logits = outputs.logits
39
+ # predicted_class_idx = logits.argmax(-1).item()
40
+ # predicted_class = model.config.id2label[predicted_class_idx]
41
+ # return predicted_class
42
+
43
+ # # Fonction pour traiter la vidéo et identifier les séquences de "Surfing"
44
+ # def identifier_sequences_surfing(video_path, intervalle=0.5):
45
+ # cap = cv2.VideoCapture(video_path)
46
+ # frame_rate = cap.get(cv2.CAP_PROP_FPS)
47
+ # total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
48
+ # frame_interval = int(frame_rate * intervalle)
49
+
50
+ # resultats = []
51
+ # sequences_surfing = []
52
+ # frame_index = 0
53
+ # in_surf_sequence = False
54
+ # start_timestamp = None
55
+
56
+ # with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
57
+ # success, frame = cap.read()
58
+ # while success:
59
+ # if frame_index % frame_interval == 0:
60
+ # timestamp = round(frame_index / frame_rate, 2) # Maintain precision to the centisecond level
61
+ # classe = classifier_image(frame)
62
+ # resultats.append({"Timestamp": timestamp, "Classe": classe})
63
+
64
+ # if classe == "Surfing" and not in_surf_sequence:
65
+ # in_surf_sequence = True
66
+ # start_timestamp = timestamp
67
+
68
+ # elif classe != "Surfing" and in_surf_sequence:
69
+ # # Vérifier l'image suivante pour confirmer si c'était une erreur ponctuelle
70
+ # success_next, frame_next = cap.read()
71
+ # next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
72
+ # classe_next = None
73
+
74
+ # if success_next:
75
+ # classe_next = classifier_image(frame_next)
76
+ # resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})
77
+
78
+ # # Si l'image suivante est "Surfing", on ignore l'erreur ponctuelle
79
+ # if classe_next == "Surfing":
80
+ # success = success_next
81
+ # frame = frame_next
82
+ # frame_index += frame_interval
83
+ # pbar.update(frame_interval)
84
+ # continue
85
+ # else:
86
+ # # Sinon, terminer la séquence "Surfing"
87
+ # in_surf_sequence = False
88
+ # end_timestamp = timestamp
89
+ # sequences_surfing.append((start_timestamp, end_timestamp))
90
+
91
+ # success, frame = cap.read()
92
+ # frame_index += 1
93
+ # pbar.update(1)
94
+
95
+ # # Si on est toujours dans une séquence "Surfing" à la fin de la vidéo
96
+ # if in_surf_sequence:
97
+ # sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))
98
+
99
+ # cap.release()
100
+ # dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
101
+ # return dataframe_sequences
102
+
103
+ # # Fonction pour convertir les séquences en format JSON
104
+ # def convertir_sequences_en_json(dataframe):
105
+ # events = []
106
+ # blocks = []
107
+ # for idx, row in dataframe.iterrows():
108
+ # block = {
109
+ # "id": f"Surfing{idx + 1}",
110
+ # "start": round(row["Début"], 2),
111
+ # "end": round(row["Fin"], 2)
112
+ # }
113
+ # blocks.append(block)
114
+ # event = {
115
+ # "event": "Surfing",
116
+ # "blocks": blocks
117
+ # }
118
+ # events.append(event)
119
+ # return events
120
+
121
+ # @app.post("/analyze_video/")
122
+ # async def analyze_video(file: UploadFile = File(...)):
123
+ # with open("uploaded_video.mp4", "wb") as buffer:
124
+ # shutil.copyfileobj(file.file, buffer)
125
+
126
+ # dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
127
+ # json_result = convertir_sequences_en_json(dataframe_sequences)
128
+ # return json_result
129
+
130
+ # @app.get("/", response_class=HTMLResponse)
131
+ # async def index():
132
+ # return (
133
+ # """
134
+ # <html>
135
+ # <body>
136
+ # <h1>Hello world!</h1>
137
+ # <p>This `/` is the most simple and default endpoint.</p>
138
+ # <p>If you want to learn more, check out the documentation of the API at
139
+ # <a href='/docs'>/docs</a> or
140
+ # <a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
141
+ # </p>
142
+ # </body>
143
+ # </html>
144
+ # """
145
+ # )
146
+
147
+
148
+ # # Lancer l'application avec uvicorn (command line)
149
+ # # uvicorn main:app --reload
150
+ # # http://localhost:8000/docs#/
151
+ # # (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
152
+
153
+
154
  from fastapi import FastAPI, UploadFile, File
155
  import cv2
156
  import torch
 
271
  events.append(event)
272
  return events
273
 
274
+
275
+ import os
276
+ import tempfile
277
+
278
  @app.post("/analyze_video/")
279
  async def analyze_video(file: UploadFile = File(...)):
280
+ # Utiliser tempfile pour créer un fichier temporaire
281
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", dir="/tmp") as tmp:
282
+ shutil.copyfileobj(file.file, tmp)
283
+ tmp_path = tmp.name
284
 
285
+ # Analyser la vidéo
286
+ dataframe_sequences = identifier_sequences_surfing(tmp_path, intervalle=1)
287
  json_result = convertir_sequences_en_json(dataframe_sequences)
288
+
289
+ # Supprimer le fichier temporaire après utilisation
290
+ os.remove(tmp_path)
291
+
292
+ return {"filename": file.filename, "result": json_result}
293
 
294
  @app.get("/", response_class=HTMLResponse)
295
  async def index():