2nzi commited on
Commit
b0dd811
·
verified ·
1 Parent(s): 967b935

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +120 -136
main.py CHANGED
@@ -1,151 +1,135 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- import cv2
3
- import torch
4
  import pandas as pd
5
- from PIL import Image
6
- from transformers import AutoImageProcessor, AutoModelForImageClassification
7
- from tqdm import tqdm
8
- import json
9
- import shutil
10
- from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.responses import HTMLResponse
12
 
13
- app = FastAPI()
 
14
 
15
- # Add CORS middleware to allow requests from localhost:8080 (or any origin you specify)
16
- app.add_middleware(
17
- CORSMiddleware,
18
- # allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
19
- allow_origins=["http://localhost:8080"], # Replace with the URL of your Vue.js app
20
- allow_credentials=True,
21
- allow_methods=["*"], # Allows all HTTP methods (GET, POST, etc.)
22
- allow_headers=["*"], # Allows all headers (such as Content-Type, Authorization, etc.)
23
- )
24
-
25
- # Charger le processor et le modèle fine-tuné depuis le chemin local
26
- local_model_path = r'.\vit-finetuned-ucf101'
27
- processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
28
- model = AutoModelForImageClassification.from_pretrained(local_model_path)
29
- # model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101")
30
- model.eval()
31
-
32
- # Fonction pour classifier une image
33
- def classifier_image(image):
34
- image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
35
- inputs = processor(images=image_pil, return_tensors="pt")
36
- with torch.no_grad():
37
- outputs = model(**inputs)
38
- logits = outputs.logits
39
- predicted_class_idx = logits.argmax(-1).item()
40
- predicted_class = model.config.id2label[predicted_class_idx]
41
- return predicted_class
42
-
43
- # Fonction pour traiter la vidéo et identifier les séquences de "Surfing"
44
- def identifier_sequences_surfing(video_path, intervalle=0.5):
45
- cap = cv2.VideoCapture(video_path)
46
- frame_rate = cap.get(cv2.CAP_PROP_FPS)
47
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
48
- frame_interval = int(frame_rate * intervalle)
49
-
50
- resultats = []
51
- sequences_surfing = []
52
- frame_index = 0
53
- in_surf_sequence = False
54
- start_timestamp = None
55
-
56
- with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
57
- success, frame = cap.read()
58
- while success:
59
- if frame_index % frame_interval == 0:
60
- timestamp = round(frame_index / frame_rate, 2) # Maintain precision to the centisecond level
61
- classe = classifier_image(frame)
62
- resultats.append({"Timestamp": timestamp, "Classe": classe})
63
-
64
- if classe == "Surfing" and not in_surf_sequence:
65
- in_surf_sequence = True
66
- start_timestamp = timestamp
67
-
68
- elif classe != "Surfing" and in_surf_sequence:
69
- # Vérifier l'image suivante pour confirmer si c'était une erreur ponctuelle
70
- success_next, frame_next = cap.read()
71
- next_timestamp = round((frame_index + frame_interval) / frame_rate, 2)
72
- classe_next = None
73
 
74
- if success_next:
75
- classe_next = classifier_image(frame_next)
76
- resultats.append({"Timestamp": next_timestamp, "Classe": classe_next})
77
 
78
- # Si l'image suivante est "Surfing", on ignore l'erreur ponctuelle
79
- if classe_next == "Surfing":
80
- success = success_next
81
- frame = frame_next
82
- frame_index += frame_interval
83
- pbar.update(frame_interval)
84
- continue
85
- else:
86
- # Sinon, terminer la séquence "Surfing"
87
- in_surf_sequence = False
88
- end_timestamp = timestamp
89
- sequences_surfing.append((start_timestamp, end_timestamp))
90
 
91
- success, frame = cap.read()
92
- frame_index += 1
93
- pbar.update(1)
94
 
95
- # Si on est toujours dans une séquence "Surfing" à la fin de la vidéo
96
- if in_surf_sequence:
97
- sequences_surfing.append((start_timestamp, round(frame_index / frame_rate, 2)))
98
-
99
- cap.release()
100
- dataframe_sequences = pd.DataFrame(sequences_surfing, columns=["Début", "Fin"])
101
- return dataframe_sequences
102
-
103
- # Fonction pour convertir les séquences en format JSON
104
- def convertir_sequences_en_json(dataframe):
105
- events = []
106
- blocks = []
107
- for idx, row in dataframe.iterrows():
108
- block = {
109
- "id": f"Surfing{idx + 1}",
110
- "start": round(row["Début"], 2),
111
- "end": round(row["Fin"], 2)
112
- }
113
- blocks.append(block)
114
- event = {
115
- "event": "Surfing",
116
- "blocks": blocks
117
  }
118
- events.append(event)
119
- return events
120
-
121
- @app.post("/analyze_video/")
122
- async def analyze_video(file: UploadFile = File(...)):
123
- with open("uploaded_video.mp4", "wb") as buffer:
124
- shutil.copyfileobj(file.file, buffer)
125
-
126
- dataframe_sequences = identifier_sequences_surfing("uploaded_video.mp4", intervalle=1)
127
- json_result = convertir_sequences_en_json(dataframe_sequences)
128
- return json_result
 
129
 
130
- @app.get("/", response_class=HTMLResponse)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  async def index():
132
  return (
133
- """
134
- <html>
135
- <body>
136
- <h1>Hello world!</h1>
137
- <p>This `/` is the most simple and default endpoint.</p>
138
- <p>If you want to learn more, check out the documentation of the API at
139
- <a href='/docs'>/docs</a> or
140
- <a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
141
- </p>
142
- </body>
143
- </html>
144
- """
145
  )
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # Lancer l'application avec uvicorn (command line)
149
- # uvicorn main:app --reload
150
- # http://localhost:8000/docs#/
151
- # (.venv) PS C:\Users\antoi\Documents\Work_Learn\Labeling-Deploy\FastAPI> uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
 
1
+ import uvicorn
 
 
2
  import pandas as pd
3
+ from typing import Union
4
+ from fastapi import FastAPI, Query
5
+ import joblib
6
+ from enum import Enum
 
 
7
  from fastapi.responses import HTMLResponse
8
 
9
+ description = """
10
+ Welcome to the GetAround Car Value Prediction API. This app provides an endpoint to predict car values based on various features! Try it out 🕹️
11
 
12
+ ## Machine Learning
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ This section includes a Machine Learning endpoint that predicts car values based on various features. Here is the endpoint:
 
 
15
 
16
+ * `/predict`: **POST** request that accepts a list of car features and returns a predicted car value.
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ Check out the documentation below 👇 for more information on each endpoint.
19
+ """
 
20
 
21
+ tags_metadata = [
22
+ {
23
+ "name": "Machine Learning",
24
+ "description": "Endpoint for predicting car values based on provided features."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  }
26
+ ]
27
+
28
+ app = FastAPI(
29
+ title="🚗 GetAround Car Value Prediction API",
30
+ description=description,
31
+ version="0.1",
32
+ contact={
33
+ "name": "Antoine VERDON",
34
+ "email": "antoineverdon.[email protected]",
35
+ },
36
+ openapi_tags=tags_metadata
37
+ )
38
 
39
+ class CarBrand(str, Enum):
40
+ citroen = "Citroën"
41
+ peugeot = "Peugeot"
42
+ pgo = "PGO"
43
+ renault = "Renault"
44
+ audi = "Audi"
45
+ bmw = "BMW"
46
+ other = "other"
47
+ mercedes = "Mercedes"
48
+ opel = "Opel"
49
+ volkswagen = "Volkswagen"
50
+ ferrari = "Ferrari"
51
+ maserati = "Maserati"
52
+ mitsubishi = "Mitsubishi"
53
+ nissan = "Nissan"
54
+ seat = "SEAT"
55
+ subaru = "Subaru"
56
+ toyota = "Toyota"
57
+
58
+ class FuelType(str, Enum):
59
+ diesel = "diesel"
60
+ petrol = "petrol"
61
+ hybrid_petrol = "hybrid_petrol"
62
+ electro = "electro"
63
+
64
+ class PaintColor(str, Enum):
65
+ black = "black"
66
+ grey = "grey"
67
+ white = "white"
68
+ red = "red"
69
+ silver = "silver"
70
+ blue = "blue"
71
+ orange = "orange"
72
+ beige = "beige"
73
+ brown = "brown"
74
+ green = "green"
75
+
76
+ class CarType(str, Enum):
77
+ convertible = "convertible"
78
+ coupe = "coupe"
79
+ estate = "estate"
80
+ hatchback = "hatchback"
81
+ sedan = "sedan"
82
+ subcompact = "subcompact"
83
+ suv = "suv"
84
+ van = "van"
85
+
86
+ @app.get("/", response_class=HTMLResponse, tags=["Introduction Endpoints"])
87
  async def index():
88
  return (
89
+ "Hello world! This `/` is the most simple and default endpoint. "
90
+ "If you want to learn more, check out documentation of the API at "
91
+ "<a href='/docs'>/docs</a> or "
92
+ "<a href='https://2nzi-getaroundapi.hf.space/docs' target='_blank'>external docs</a>."
 
 
 
 
 
 
 
 
93
  )
94
 
95
+ @app.post("/predict", tags=["Machine Learning"])
96
+ async def predict(
97
+ brand: CarBrand,
98
+ mileage: int = Query(...),
99
+ engine_power: int = Query(...),
100
+ fuel: FuelType = Query(...),
101
+ paint_color: PaintColor = Query(...),
102
+ car_type: CarType = Query(...),
103
+ private_parking_available: bool = Query(...),
104
+ has_gps: bool = Query(...),
105
+ has_air_conditioning: bool = Query(...),
106
+ automatic_car: bool = Query(...),
107
+ has_getaround_connect: bool = Query(...),
108
+ has_speed_regulator: bool = Query(...),
109
+ winter_tires: bool = Query(...)
110
+ ):
111
+
112
+ car_data_dict = {
113
+ 'model_key': [brand],
114
+ 'mileage': [mileage],
115
+ 'engine_power': [engine_power],
116
+ 'fuel': [fuel],
117
+ 'paint_color': [paint_color],
118
+ 'car_type': [car_type],
119
+ 'private_parking_available': [private_parking_available],
120
+ 'has_gps': [has_gps],
121
+ 'has_air_conditioning': [has_air_conditioning],
122
+ 'automatic_car': [automatic_car],
123
+ 'has_getaround_connect': [has_getaround_connect],
124
+ 'has_speed_regulator': [has_speed_regulator],
125
+ 'winter_tires': [winter_tires]
126
+ }
127
+ car_data = pd.DataFrame(car_data_dict)
128
+
129
+ model = joblib.load('best_model_XGBoost.pkl')
130
+ prediction = model.predict(car_data)
131
+ response = {"prediction": prediction.tolist()[0]}
132
+ return response
133
 
134
+ if __name__ == "__main__":
135
+ uvicorn.run(app, host="0.0.0.0", port=4000)