2nzi commited on
Commit
967b935
·
verified ·
1 Parent(s): d9d6d2c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +136 -120
main.py CHANGED
@@ -1,135 +1,151 @@
1
- import uvicorn
 
 
2
  import pandas as pd
3
- from typing import Union
4
- from fastapi import FastAPI, Query
5
- import joblib
6
- from enum import Enum
 
 
7
  from fastapi.responses import HTMLResponse
8
 
9
- description = """
10
- Welcome to the GetAround Car Value Prediction API. This app provides an endpoint to predict car values based on various features! Try it out 🕹️
11
 
12
- ## Machine Learning
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- This section includes a Machine Learning endpoint that predicts car values based on various features. Here is the endpoint:
 
 
15
 
16
- * `/predict`: **POST** request that accepts a list of car features and returns a predicted car value.
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- Check out the documentation below 👇 for more information on each endpoint.
19
- """
 
20
 
21
- tags_metadata = [
22
- {
23
- "name": "Machine Learning",
24
- "description": "Endpoint for predicting car values based on provided features."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  }
26
- ]
27
-
28
- app = FastAPI(
29
- title="🚗 GetAround Car Value Prediction API",
30
- description=description,
31
- version="0.1",
32
- contact={
33
- "name": "Antoine VERDON",
34
- "email": "[email protected]",
35
- },
36
- openapi_tags=tags_metadata
37
- )
38
 
39
- class CarBrand(str, Enum):
40
- citroen = "Citroën"
41
- peugeot = "Peugeot"
42
- pgo = "PGO"
43
- renault = "Renault"
44
- audi = "Audi"
45
- bmw = "BMW"
46
- other = "other"
47
- mercedes = "Mercedes"
48
- opel = "Opel"
49
- volkswagen = "Volkswagen"
50
- ferrari = "Ferrari"
51
- maserati = "Maserati"
52
- mitsubishi = "Mitsubishi"
53
- nissan = "Nissan"
54
- seat = "SEAT"
55
- subaru = "Subaru"
56
- toyota = "Toyota"
57
-
58
- class FuelType(str, Enum):
59
- diesel = "diesel"
60
- petrol = "petrol"
61
- hybrid_petrol = "hybrid_petrol"
62
- electro = "electro"
63
-
64
- class PaintColor(str, Enum):
65
- black = "black"
66
- grey = "grey"
67
- white = "white"
68
- red = "red"
69
- silver = "silver"
70
- blue = "blue"
71
- orange = "orange"
72
- beige = "beige"
73
- brown = "brown"
74
- green = "green"
75
-
76
- class CarType(str, Enum):
77
- convertible = "convertible"
78
- coupe = "coupe"
79
- estate = "estate"
80
- hatchback = "hatchback"
81
- sedan = "sedan"
82
- subcompact = "subcompact"
83
- suv = "suv"
84
- van = "van"
85
-
86
- @app.get("/", response_class=HTMLResponse, tags=["Introduction Endpoints"])
87
  async def index():
88
  return (
89
- "Hello world! This `/` is the most simple and default endpoint. "
90
- "If you want to learn more, check out documentation of the API at "
91
- "<a href='/docs'>/docs</a> or "
92
- "<a href='https://2nzi-getaroundapi.hf.space/docs' target='_blank'>external docs</a>."
 
 
 
 
 
 
 
 
93
  )
94
 
95
- @app.post("/predict", tags=["Machine Learning"])
96
- async def predict(
97
- brand: CarBrand,
98
- mileage: int = Query(...),
99
- engine_power: int = Query(...),
100
- fuel: FuelType = Query(...),
101
- paint_color: PaintColor = Query(...),
102
- car_type: CarType = Query(...),
103
- private_parking_available: bool = Query(...),
104
- has_gps: bool = Query(...),
105
- has_air_conditioning: bool = Query(...),
106
- automatic_car: bool = Query(...),
107
- has_getaround_connect: bool = Query(...),
108
- has_speed_regulator: bool = Query(...),
109
- winter_tires: bool = Query(...)
110
- ):
111
-
112
- car_data_dict = {
113
- 'model_key': [brand],
114
- 'mileage': [mileage],
115
- 'engine_power': [engine_power],
116
- 'fuel': [fuel],
117
- 'paint_color': [paint_color],
118
- 'car_type': [car_type],
119
- 'private_parking_available': [private_parking_available],
120
- 'has_gps': [has_gps],
121
- 'has_air_conditioning': [has_air_conditioning],
122
- 'automatic_car': [automatic_car],
123
- 'has_getaround_connect': [has_getaround_connect],
124
- 'has_speed_regulator': [has_speed_regulator],
125
- 'winter_tires': [winter_tires]
126
- }
127
- car_data = pd.DataFrame(car_data_dict)
128
-
129
- model = joblib.load('best_model_XGBoost.pkl')
130
- prediction = model.predict(car_data)
131
- response = {"prediction": prediction.tolist()[0]}
132
- return response
133
 
134
- if __name__ == "__main__":
135
- uvicorn.run(app, host="0.0.0.0", port=4000)
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ import cv2
3
+ import torch
4
  import pandas as pd
5
+ from PIL import Image
6
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
7
+ from tqdm import tqdm
8
+ import json
9
+ import shutil
10
+ from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.responses import HTMLResponse
12
 
13
+ app = FastAPI()
 
14
 
15
+ # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ # allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
19
+ allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
20
+ allow_credentials=True,
21
+ allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.)
22
+ allow_headers=["*"], # Allows all headers (such as Content-Type, Authorization, etc.)
23
+ )
24
+
25
+ # Charger le processor et le modèle fine-tuné depuis le chemin local
26
+ local_model_path = r'.\vit-finetuned-ucf101'
27
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
28
+ model = AutoModelForImageClassification.from_pretrained(local_model_path)
29
+ # model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101")
30
+ model.eval()
31
+
32
+ # Fonction pour classifier une image
33
+ def classifier_image(image):
34
+ image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
35
+ inputs = processor(images=image_pil, return_tensors="pt")
36
+ with torch.no_grad():
37
+ outputs = model(**inputs)
38
+ logits = outputs.logits
39
+ predicted_class_idx = logits.argmax(-1).item()
40
+ predicted_class = model.config.id2label[predicted_class_idx]
41
+ return predicted_class
42
+
43
+ # Fonction pour traiter la vidéo et identifier les séquences de "Surfing"
44
+ def identifier_sequences_surfing(video_path, intervalle=0.5):
45
+ cap = cv2.VideoCapture(video_path)
46
+ frame_rate = cap.get(cv2.CAP_PROP_FPS)
47
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
48
+ frame_interval = int(frame_rate * intervalle)
49
+
50
+ resultats = []
51
+ sequences_surfing = []
52
+ frame_index = 0
53
+ in_surf_sequence = False
54
+ start_timestamp = None
55
+
56
+ with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
57
+ success, frame = cap.read()
58
+ while success:
59
+ if frame_index % frame_interval == 0:
60
+ timestamp = round(frame_index / frame_rate, 2) # Maintain precision to the centisecond level
61
+ classe = classifier_image(frame)
62
+ resultats.append({"Timestamp": timestamp, "Classe": classe})
63
+
64
+ if classe == "Surfing" and not in_surf_sequence:
65
+ in_surf_sequence = True
66
+ start_timestamp = timestamp
67
+
68
+ elif classe != "Surfing" and in_surf_sequence:
69
+ # Vérifier l'image suivante pour confirmer si c'était une erreur ponctuelle
70
+ success_next, frame_next = cap.read()
71
+ next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
72
+ classe_next = None
73
 
74
+ if success_next:
75
+ classe_next = classifier_image(frame_next)
76
+ resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})
77
 
78
+ # Si l'image suivante est "Surfing", on ignore l'erreur ponctuelle
79
+ if classe_next == "Surfing":
80
+ success = success_next
81
+ frame = frame_next
82
+ frame_index += frame_interval
83
+ pbar.update(frame_interval)
84
+ continue
85
+ else:
86
+ # Sinon, terminer la séquence "Surfing"
87
+ in_surf_sequence = False
88
+ end_timestamp = timestamp
89
+ sequences_surfing.append((start_timestamp, end_timestamp))
90
 
91
+ success, frame = cap.read()
92
+ frame_index += 1
93
+ pbar.update(1)
94
 
95
+ # Si on est toujours dans une séquence "Surfing" à la fin de la vidéo
96
+ if in_surf_sequence:
97
+ sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))
98
+
99
+ cap.release()
100
+ dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
101
+ return dataframe_sequences
102
+
103
+ # Fonction pour convertir les séquences en format JSON
104
+ def convertir_sequences_en_json(dataframe):
105
+ events = []
106
+ blocks = []
107
+ for idx, row in dataframe.iterrows():
108
+ block = {
109
+ "id": f"Surfing{idx + 1}",
110
+ "start": round(row["Début"], 2),
111
+ "end": round(row["Fin"], 2)
112
+ }
113
+ blocks.append(block)
114
+ event = {
115
+ "event": "Surfing",
116
+ "blocks": blocks
117
  }
118
+ events.append(event)
119
+ return events
120
+
121
+ @app.post("/analyze_video/")
122
+ async def analyze_video(file: UploadFile = File(...)):
123
+ with open("uploaded_video.mp4", "wb") as buffer:
124
+ shutil.copyfileobj(file.file, buffer)
 
 
 
 
 
125
 
126
+ dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
127
+ json_result = convertir_sequences_en_json(dataframe_sequences)
128
+ return json_result
129
+
130
+ @app.get("/", response_class=HTMLResponse)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  async def index():
132
  return (
133
+ """
134
+ <html>
135
+ <body>
136
+ <h1>Hello world!</h1>
137
+ <p>This `/` is the most simple and default endpoint.</p>
138
+ <p>If you want to learn more, check out the documentation of the API at
139
+ <a href='/docs'>/docs</a> or
140
+ <a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
141
+ </p>
142
+ </body>
143
+ </html>
144
+ """
145
  )
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
+ # Lancer l'application avec uvicorn (command line)
149
+ # uvicorn main:app --reload
150
+ # http://localhost:8000/docs#/
151
+ # (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1