mouadenna commited on
Commit
8d082c2
·
verified ·
1 Parent(s): 4aa135b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -1
app.py CHANGED
@@ -9,7 +9,78 @@ import io
9
  import zipfile
10
 
11
  # Assuming you have these functions defined elsewhere
12
- from your_module import preprocess, best_model, DEVICE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
15
  tiles = []
 
9
  import zipfile
10
 
11
  # Assuming you have these functions defined elsewhere
12
+ import torch
13
+ import numpy as np
14
+ from PIL import Image
15
+ import albumentations as albu
16
+ import segmentation_models_pytorch as smp
17
+ from albumentations.pytorch.transforms import ToTensorV2
18
+
19
+
20
+
21
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ ENCODER = 'se_resnext50_32x4d'
23
+ ENCODER_WEIGHTS = 'imagenet'
24
+
25
+ # Load and prepare the model
26
+ best_model = torch.load('deeplabv3+ v15.pth', map_location=DEVICE)
27
+ best_model.eval().float()
28
+
29
+ def to_tensor(x, **kwargs):
30
+ return x.astype('float32')#.transpose(2, 0, 1)
31
+
32
+ # Preprocessing
33
+ preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
34
+
35
+ def get_preprocessing():
36
+ _transform = [
37
+ albu.Resize(512, 512),
38
+ albu.Lambda(image=preprocessing_fn),
39
+ albu.Lambda(image=to_tensor, mask=to_tensor),
40
+ ToTensorV2(),
41
+ #albu.Normalize(mean=MEAN,std=STD)
42
+
43
+ ]
44
+ return albu.Compose(_transform)
45
+
46
+
47
+ preprocess = get_preprocessing()
48
+
49
+ @torch.no_grad()
50
+ def process_and_predict(image, model):
51
+ # Convert PIL Image to numpy array if necessary
52
+ if isinstance(image, Image.Image):
53
+ image = np.array(image)
54
+
55
+ # Ensure image is 3-channel
56
+ if image.ndim == 2:
57
+ image = np.stack([image] * 3, axis=-1)
58
+ elif image.shape[2] == 4:
59
+ image = image[:, :, :3]
60
+
61
+ # Apply preprocessing
62
+ preprocessed = preprocess(image=image)['image']
63
+ #preprocessed=torch.tensor(preprocessed)
64
+ # Add batch dimension and move to device
65
+ input_tensor = preprocessed.unsqueeze(0).to(DEVICE)
66
+
67
+ print(input_tensor.shape)
68
+ # Predict
69
+ mask = model(input_tensor)
70
+ mask = torch.sigmoid(mask)
71
+ mask = (mask > 0.6).float()
72
+
73
+ # Convert to PIL Image
74
+ mask_image = Image.fromarray((mask.squeeze().cpu().numpy() * 255).astype(np.uint8))
75
+
76
+ return mask_image
77
+
78
+ #example
79
+ def main(image_path):
80
+ image = Image.open(image_path)
81
+ mask = process_and_predict(image, best_model)
82
+ return mask
83
+
84
 
85
  def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
86
  tiles = []