Spaces:
Sleeping
Sleeping
# from fastapi import FastAPI, UploadFile, File | |
# import cv2 | |
# import torch | |
# import pandas as pd | |
# from PIL import Image | |
# from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# from tqdm import tqdm | |
# import json | |
# import shutil | |
# from fastapi.middleware.cors import CORSMiddleware | |
# from fastapi.responses import HTMLResponse | |
# app = FastAPI() | |
# # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify) | |
# app.add_middleware( | |
# CORSMiddleware, | |
# # allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app | |
# allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app | |
# allow_credentials=True, | |
# allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.) | |
# allow_headers=["*"], # Allows all headers (such as Content-Type, Authorization, etc.) | |
# ) | |
# # Charger le processor et le modèle fine-tuné depuis le chemin local | |
# local_model_path = r'./vit-finetuned-ucf101' | |
# processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") | |
# model = AutoModelForImageClassification.from_pretrained(local_model_path) | |
# # model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101") | |
# model.eval() | |
# # Fonction pour classifier une image | |
# def classifier_image(image): | |
# image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) | |
# inputs = processor(images=image_pil, return_tensors="pt") | |
# with torch.no_grad(): | |
# outputs = model(**inputs) | |
# logits = outputs.logits | |
# predicted_class_idx = logits.argmax(-1).item() | |
# predicted_class = model.config.id2label[predicted_class_idx] | |
# return predicted_class | |
# # Fonction pour traiter la vidéo et identifier les séquences de "Surfing" | |
# def identifier_sequences_surfing(video_path, intervalle=0.5): | |
# cap = cv2.VideoCapture(video_path) | |
# frame_rate = cap.get(cv2.CAP_PROP_FPS) | |
# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
# frame_interval = int(frame_rate * intervalle) | |
# resultats = [] | |
# sequences_surfing = [] | |
# frame_index = 0 | |
# in_surf_sequence = False | |
# start_timestamp = None | |
# with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar: | |
# success, frame = cap.read() | |
# while success: | |
# if frame_index % frame_interval == 0: | |
# timestamp = round(frame_index / frame_rate, 2) # Maintain precision to the centisecond level | |
# classe = classifier_image(frame) | |
# resultats.append({"Timestamp": timestamp, "Classe": classe}) | |
# if classe == "Surfing" and not in_surf_sequence: | |
# in_surf_sequence = True | |
# start_timestamp = timestamp | |
# elif classe != "Surfing" and in_surf_sequence: | |
# # Vérifier l'image suivante pour confirmer si c'était une erreur ponctuelle | |
# success_next, frame_next = cap.read() | |
# next_timestamp = round((frame_index + frame_interval) / frame_rate, 2) | |
# classe_next = None | |
# if success_next: | |
# classe_next = classifier_image(frame_next) | |
# resultats.append({"Timestamp": next_timestamp, "Classe": classe_next}) | |
# # Si l'image suivante est "Surfing", on ignore l'erreur ponctuelle | |
# if classe_next == "Surfing": | |
# success = success_next | |
# frame = frame_next | |
# frame_index += frame_interval | |
# pbar.update(frame_interval) | |
# continue | |
# else: | |
# # Sinon, terminer la séquence "Surfing" | |
# in_surf_sequence = False | |
# end_timestamp = timestamp | |
# sequences_surfing.append((start_timestamp, end_timestamp)) | |
# success, frame = cap.read() | |
# frame_index += 1 | |
# pbar.update(1) | |
# # Si on est toujours dans une séquence "Surfing" à la fin de la vidéo | |
# if in_surf_sequence: | |
# sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2))) | |
# cap.release() | |
# dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"]) | |
# return dataframe_sequences | |
# # Fonction pour convertir les séquences en format JSON | |
# def convertir_sequences_en_json(dataframe): | |
# events = [] | |
# blocks = [] | |
# for idx, row in dataframe.iterrows(): | |
# block = { | |
# "id": f"Surfing{idx + 1}", | |
# "start": round(row["Début"], 2), | |
# "end": round(row["Fin"], 2) | |
# } | |
# blocks.append(block) | |
# event = { | |
# "event": "Surfing", | |
# "blocks": blocks | |
# } | |
# events.append(event) | |
# return events | |
# @app.post("/analyze_video/") | |
# async def analyze_video(file: UploadFile = File(...)): | |
# with open("uploaded_video.mp4", "wb") as buffer: | |
# shutil.copyfileobj(file.file, buffer) | |
# dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1) | |
# json_result = convertir_sequences_en_json(dataframe_sequences) | |
# return json_result | |
# @app.get("/", response_class=HTMLResponse) | |
# async def index(): | |
# return ( | |
# """ | |
# <html> | |
# <body> | |
# <h1>Hello world!</h1> | |
# <p>This `/` is the most simple and default endpoint.</p> | |
# <p>If you want to learn more, check out the documentation of the API at | |
# <a href='/docs'>/docs</a> or | |
# <a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>. | |
# </p> | |
# </body> | |
# </html> | |
# """ | |
# ) | |
# # Lancer l'application avec uvicorn (command line) | |
# # uvicorn main:app --reload | |
# # http://localhost:8000/docs#/ | |
# # (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1 | |
from fastapi import FastAPI, UploadFile, File | |
import cv2 | |
import torch | |
import pandas as pd | |
from PIL import Image | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
from tqdm import tqdm | |
import json | |
import shutil | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import HTMLResponse | |
app = FastAPI()

# CORS: let the Vue.js front end served from localhost:8080 call this API.
# Extend the origin list if the front end is served from somewhere else.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8080"],
    allow_methods=["*"],  # every HTTP method (GET, POST, ...)
    allow_headers=["*"],  # every request header (Content-Type, Authorization, ...)
    allow_credentials=True,
)
# Fine-tuned image classifier, loaded from a local checkout of the weights.
# (The same checkpoint is presumably also published on the Hub as
# "2nzi/vit-finetuned-ucf101" and could be loaded from there instead.)
local_model_path = r'./vit-finetuned-ucf101'
model = AutoModelForImageClassification.from_pretrained(local_model_path)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# Pre-processing comes from the base ViT checkpoint, presumably the one the
# local model was fine-tuned from — TODO confirm the configs match.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
def classifier_image(image):
    """Classify one video frame and return the predicted class label.

    Args:
        image: frame as returned by ``cv2.VideoCapture.read`` (BGR channel
            order; it is converted to RGB before pre-processing).

    Returns:
        The label string from ``model.config.id2label`` for the top logit.
    """
    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    model_inputs = processor(images=Image.fromarray(rgb_frame), return_tensors="pt")
    # No gradients are needed for inference.
    with torch.no_grad():
        scores = model(**model_inputs).logits
    best_index = scores.argmax(-1).item()
    return model.config.id2label[best_index]
def _regrouper_sequences(echantillons, fin_video, label="Surfing"):
    """Merge chronological per-sample labels into ``(start, end)`` runs of ``label``.

    Args:
        echantillons: ``(timestamp, classe)`` pairs in chronological order.
        fin_video: timestamp used to close a run that is still open after the
            last sample.
        label: class name whose runs are collected.

    Returns:
        List of ``(start, end)`` timestamp tuples.

    A single non-``label`` sample inside a run is treated as a one-off
    misclassification and ignored. Two consecutive ones close the run at the
    timestamp of the first. If the samples run out while such a lone sample is
    pending, the run also closes at that sample's timestamp.
    """
    sequences = []
    debut = None  # start of the open run; None when not inside a run
    fin_candidate = None  # timestamp of a pending lone non-label sample
    for timestamp, classe in echantillons:
        if debut is None:
            if classe == label:
                debut = timestamp
        elif classe == label:
            fin_candidate = None  # the earlier miss was a glitch: keep the run
        elif fin_candidate is None:
            fin_candidate = timestamp  # wait for the next sample to confirm
        else:
            sequences.append((debut, fin_candidate))
            debut = fin_candidate = None
    if debut is not None:
        sequences.append((debut, fin_video if fin_candidate is None else fin_candidate))
    return sequences


def identifier_sequences_surfing(video_path, intervalle=0.5):
    """Detect the "Surfing" segments of a video file.

    One frame every ``intervalle`` seconds (but at least every frame) is
    classified with ``classifier_image``. Consecutive "Surfing" samples are
    merged into segments, and a single misclassified sample inside a segment
    does not split it.

    Args:
        video_path: path of a video file readable by OpenCV.
        intervalle: sampling period in seconds.

    Returns:
        DataFrame with columns ``"Début"`` and ``"Fin"`` (segment start/end in
        seconds, rounded to 0.01 s), one row per segment.

    Raises:
        ValueError: if the video cannot be opened or reports no frame rate.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        if frame_rate <= 0:
            raise ValueError(f"Video reports an invalid frame rate ({frame_rate}): {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp to 1 so a short interval or low fps cannot cause a modulo by zero.
        frame_interval = max(1, int(frame_rate * intervalle))
        echantillons = []
        frame_index = 0
        with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
            # Decode strictly sequentially so frame_index always equals the
            # decoder position and every timestamp is exact.
            success, frame = cap.read()
            while success:
                if frame_index % frame_interval == 0:
                    timestamp = round(frame_index / frame_rate, 2)  # centisecond precision
                    echantillons.append((timestamp, classifier_image(frame)))
                frame_index += 1
                pbar.update(1)
                success, frame = cap.read()
    finally:
        cap.release()
    fin_video = round(frame_index / frame_rate, 2)
    sequences_surfing = _regrouper_sequences(echantillons, fin_video)
    return pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
def convertir_sequences_en_json(dataframe):
    """Wrap detected segments in the JSON-ready event structure.

    Args:
        dataframe: DataFrame with ``"Début"`` and ``"Fin"`` columns, one row
            per segment. Block ids are built from the row index + 1.

    Returns:
        A one-element list ``[{"event": "Surfing", "blocks": [...]}]``. Each
        block is ``{"id", "start", "end"}`` with times rounded to 2 decimals.
        ``blocks`` is empty when the DataFrame has no rows.
    """
    surf_blocks = [
        {
            "id": f"Surfing{idx + 1}",
            "start": round(row["Début"], 2),
            "end": round(row["Fin"], 2),
        }
        for idx, row in dataframe.iterrows()
    ]
    return [{"event": "Surfing", "blocks": surf_blocks}]
import os | |
import tempfile | |
@app.post("/analyze_video/")
async def analyze_video(file: UploadFile = File(...)):
    """Detect the "Surfing" segments of an uploaded video.

    The upload is written to a named temporary file because OpenCV needs a
    filesystem path. The video is then analysed with one sample per second.
    The temporary file is removed afterwards, even if the analysis fails.

    Returns:
        ``{"filename": <upload's file name>, "result": <events list>}``.
    """
    from fastapi.concurrency import run_in_threadpool

    # Use the platform's default temp directory. A hard-coded "/tmp" does not
    # exist on Windows.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    try:
        with tmp:
            shutil.copyfileobj(file.file, tmp)
        # Decoding and model inference are slow, synchronous and CPU-bound.
        # Running them in the thread pool keeps the event loop free to serve
        # other requests.
        dataframe_sequences = await run_in_threadpool(
            identifier_sequences_surfing, tmp.name, intervalle=1
        )
    finally:
        # The file is closed by now, so removal also works on Windows.
        os.remove(tmp.name)
    json_result = convertir_sequences_en_json(dataframe_sequences)
    return {"filename": file.filename, "result": json_result}
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve a minimal HTML landing page that links to the API docs.

    The route decorator registers this handler on ``/``. ``HTMLResponse``
    makes the string be sent as HTML rather than as a JSON-encoded string.
    """
    return (
        """
<html>
<body>
<h1>Hello world!</h1>
<p>This `/` is the most simple and default endpoint.</p>
<p>If you want to learn more, check out the documentation of the API at
<a href='/docs'>/docs</a> or
<a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
</p>
</body>
</html>
"""
    )
# Lancer l'application avec uvicorn (command line) | |
# uvicorn main:app --reload | |
# http://localhost:8000/docs#/ | |
# (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1 |