from fastapi import FastAPI, UploadFile, File
import cv2
import torch
import pandas as pd
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from tqdm import tqdm
import json
import shutil
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8080"],  # Replace with the URL of your Vue.js app
    allow_credentials=True,
    allow_methods=["*"],   # Allows all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],   # Allows all headers (such as Content-Type, Authorization, etc.)
)
# Load the image processor (base ViT) and the fine-tuned model from the local path
local_model_path = r'.\vit-finetuned-ucf101'
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained(local_model_path)
model.eval()
# Classify a single frame (BGR ndarray from OpenCV) and return the predicted label
def classifier_image(image):
    image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    inputs = processor(images=image_pil, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = model.config.id2label[predicted_class_idx]
    return predicted_class
# Process the video and identify the "Surfing" sequences
def identifier_sequences_surfing(video_path, intervalle=0.5):
    cap = cv2.VideoCapture(video_path)
    frame_rate = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_interval = int(frame_rate * intervalle)

    resultats = []
    sequences_surfing = []
    frame_index = 0
    in_surf_sequence = False
    start_timestamp = None

    with tqdm(total=total_frames, desc="Processing video frames", unit="frame") as pbar:
        success, frame = cap.read()
        while success:
            if frame_index % frame_interval == 0:
                timestamp = round(frame_index / frame_rate, 2)  # Keep centisecond precision
                classe = classifier_image(frame)
                resultats.append({"Timestamp": timestamp, "Classe": classe})

                if classe == "Surfing" and not in_surf_sequence:
                    in_surf_sequence = True
                    start_timestamp = timestamp
                elif classe != "Surfing" and in_surf_sequence:
                    # Check the next frame to confirm whether this was a one-off misclassification
                    success_next, frame_next = cap.read()
                    next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
                    classe_next = None
                    if success_next:
                        classe_next = classifier_image(frame_next)
                        resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})

                    # If the next frame is "Surfing", ignore the one-off misclassification
                    if classe_next == "Surfing":
                        success = success_next
                        frame = frame_next
                        frame_index += frame_interval
                        pbar.update(frame_interval)
                        continue
                    else:
                        # Otherwise, close the "Surfing" sequence
                        in_surf_sequence = False
                        end_timestamp = timestamp
                        sequences_surfing.append((start_timestamp, end_timestamp))

            success, frame = cap.read()
            frame_index += 1
            pbar.update(1)

    # If we are still inside a "Surfing" sequence when the video ends
    if in_surf_sequence:
        sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))

    cap.release()
    dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
    return dataframe_sequences
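# Example standalone use (a sketch; "local_test.mp4" is a hypothetical file path):
#   df = identifier_sequences_surfing("local_test.mp4", intervalle=0.5)
#   print(df)  # DataFrame with "Début"/"Fin" columns, one row per detected sequence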
# Convert the detected sequences to a JSON-serializable structure
def convertir_sequences_en_json(dataframe):
    events = []
    blocks = []
    for idx, row in dataframe.iterrows():
        block = {
            "id": f"Surfing{idx + 1}",
            "start": round(row["Début"], 2),
            "end": round(row["Fin"], 2)
        }
        blocks.append(block)
    # Build a single "Surfing" event grouping all detected blocks
    event = {
        "event": "Surfing",
        "blocks": blocks
    }
    events.append(event)
    return events
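# Illustrative output shape (timestamps below are hypothetical):
# [
#   {
#     "event": "Surfing",
#     "blocks": [
#       {"id": "Surfing1", "start": 3.0, "end": 12.5},
#       {"id": "Surfing2", "start": 40.0, "end": 58.0}
#     ]
#   }
# ]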
# Endpoint that receives an uploaded video and returns the detected "Surfing" sequences
# (route path assumed here; adjust it to match your frontend)
@app.post("/analyze_video/")
async def analyze_video(file: UploadFile = File(...)):
    with open("uploaded_video.mp4", "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
    json_result = convertir_sequences_en_json(dataframe_sequences)
    return json_result
# Run the application with uvicorn (command line):
# uvicorn main:app --reload
# http://localhost:8000/docs#/
# (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
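# Example client call (a minimal sketch; assumes the server runs on localhost:8000,
# the "/analyze_video/" route above, and that the `requests` package is installed):
#
#   import requests
#
#   with open("surf_session.mp4", "rb") as f:                      # hypothetical local video file
#       resp = requests.post(
#           "http://localhost:8000/analyze_video/",
#           files={"file": ("surf_session.mp4", f, "video/mp4")},  # field name must match `file`
#       )
#   print(resp.json())  # list of "Surfing" events with their time blocks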