Solar-Iz commited on
Commit
c7490aa
·
1 Parent(s): 68ef689

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +15 -30
app/main.py CHANGED
@@ -1,50 +1,35 @@
1
  import PIL
2
  from fastapi import FastAPI, File, UploadFile
3
  from pydantic import BaseModel
4
- from fastapi.responses import JSONResponse
5
  from utils.model_func import class_id_to_label, load_model, transform_image
6
 
7
- model = None
8
  app = FastAPI()
9
 
10
 
 
11
  class ImageClass(BaseModel):
12
  prediction: str
13
 
14
- class TextClass(BaseModel):
15
- text: str
16
-
17
-
18
  @app.on_event("startup")
19
- async def startup_event():
20
  global model
21
- # Здесь используйте функцию из utils.model_func для загрузки модели
22
  model = load_model()
23
 
 
 
 
24
 
25
- # @app.post('/classify')
26
- # async def classify_image(file: UploadFile = File(...)):
27
- # # Здесь используйте функцию из utils.model_func для классификации изображения
28
- # image_bytes = await file.read()
29
- # prediction = transform_image(image_bytes, model)
30
- # return {"prediction": prediction}
31
 
32
  @app.post('/classify')
33
- async def classify_image(file: UploadFile = File(...)):
34
- # Use the function from utils.model_func to classify the image
35
  image = PIL.Image.open(file.file)
36
  adapted_image = transform_image(image)
37
- pred_index = model(adapted_image.unsqeeze(0).detach().cpu().numpy().argmax())
38
- imagenet_class = class_id_to_label(pred_index)
39
- response = ImageClass(prediction=imagenet_class)
40
- return response
41
-
42
-
43
-
44
-
45
- @app.post('/clf_text')
46
- async def classify_text(text_data: TextClass):
47
- # Здесь используйте функцию из utils.model_func для классификации текста
48
- text = text_data.text
49
- prediction = class_id_to_label(text, model)
50
- return {"prediction": prediction}
 
1
  import PIL
2
  from fastapi import FastAPI, File, UploadFile
3
  from pydantic import BaseModel
 
4
  from utils.model_func import class_id_to_label, load_model, transform_image
5
 
6
+ model = None
7
  app = FastAPI()
8
 
9
 
10
+ # Create class of answer: only class name
11
  class ImageClass(BaseModel):
12
  prediction: str
13
 
14
+ # Load model at startup
 
 
 
15
  @app.on_event("startup")
16
+ def startup_event():
17
  global model
 
18
  model = load_model()
19
 
20
+ @app.get('/')
21
+ def return_info():
22
+ return 'Hello FastAPI'
23
 
 
 
 
 
 
 
24
 
25
  @app.post('/classify')
26
+ def classify(file: UploadFile = File(...)):
 
27
  image = PIL.Image.open(file.file)
28
  adapted_image = transform_image(image)
29
+ pred_index = model(adapted_image.unsqueeze(0)).detach().cpu().numpy().argmax()
30
+ imagenet_class = class_id_to_label(pred_index)
31
+ response = ImageClass(
32
+ prediction=imagenet_class
33
+ )
34
+
35
+ return response