aleenarayamajhi commited on
Commit
5455101
·
verified ·
1 Parent(s): 1330cab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -75
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
 
3
  # ---- Bind to Spaces port/address & disable usage stats ----
@@ -5,13 +6,18 @@ os.environ["STREAMLIT_SERVER_PORT"] = os.getenv("PORT", "7860")
5
  os.environ["STREAMLIT_SERVER_ADDRESS"] = "0.0.0.0"
6
  os.environ.setdefault("STREAMLIT_BROWSER_GATHERUSAGESTATS", "false")
7
 
8
- # ---- Use /tmp (always writable) for HOME and caches ----
9
  BASE_DIR = os.getenv("APP_WRITE_DIR", "/tmp")
10
  os.environ.setdefault("HOME", BASE_DIR)
11
  os.environ.setdefault("XDG_CACHE_HOME", f"{BASE_DIR}/.cache")
12
  os.environ.setdefault("XDG_CONFIG_HOME", f"{BASE_DIR}/.config")
13
  os.environ.setdefault("HF_HOME", f"{BASE_DIR}/.cache/huggingface")
14
- for p in [f"{BASE_DIR}/.cache", f"{BASE_DIR}/.config", f"{BASE_DIR}/.cache/huggingface", f"{BASE_DIR}/.streamlit"]:
 
 
 
 
 
15
  os.makedirs(p, exist_ok=True)
16
 
17
  import streamlit as st
@@ -20,7 +26,7 @@ st.set_page_config(
20
  layout="wide",
21
  )
22
 
23
- import sys, gc
24
  import torch
25
  import numpy as np
26
  from PIL import Image, ImageFilter
@@ -31,7 +37,7 @@ from transformers import (
31
  )
32
  from huggingface_hub import hf_hub_download
33
 
34
- # ---- Make CPU predictable/light ----
35
  try:
36
  torch.set_num_threads(1)
37
  torch.set_num_interop_threads(1)
@@ -87,17 +93,17 @@ if os.path.exists(logo_path):
87
 
88
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
89
 
90
- # Preferred class IDs (fallback to auto-detect if no segments match)
91
- PREF_CANOPY_ID = os.getenv("CANOPY_CLASS_ID")
92
- PREF_DISEASE_ID = os.getenv("DISEASE_CLASS_ID")
93
- PREF_CANOPY_ID = int(PREF_CANOPY_ID) if PREF_CANOPY_ID is not None else None
94
- PREF_DISEASE_ID = int(PREF_DISEASE_ID) if PREF_DISEASE_ID is not None else None
95
 
96
  # ---- Model loader helpers ----
97
  def _load_state(weights_path: str):
98
  if weights_path.endswith(".safetensors"):
99
  from safetensors.torch import load_file as safe_load_file
100
  return safe_load_file(weights_path)
 
101
  return torch.load(weights_path, map_location="cpu", weights_only=True)
102
 
103
  def _load_maskformer(config_path: str, weights_path: str):
@@ -122,7 +128,7 @@ def load_segmentation_models():
122
  canopy_model, disease_model, processor, SEG_LOAD_ERR = load_segmentation_models()
123
  SEG_READY = (canopy_model is not None) and (disease_model is not None)
124
 
125
- # ---- Lazy GPT-2 generator (optional) ----
126
  @st.cache_resource
127
  def get_text_generator():
128
  try:
@@ -131,7 +137,7 @@ def get_text_generator():
131
  except Exception as e:
132
  return None, f"{type(e).__name__}: {e}"
133
 
134
- # ---- Inference helpers ----
135
  @torch.inference_mode()
136
  def run_inference(model, image):
137
  inputs = processor(image, return_tensors="pt", do_resize=False)
@@ -144,18 +150,17 @@ def run_inference(model, image):
144
  mask_threshold=0.5,
145
  overlap_mask_area_threshold=0.5,
146
  )[0]
147
- return results # dict with 'segmentation' and 'segments_info'
148
 
149
  def _segmap_to_numpy(seg_map):
150
- """Ensure segmentation map is NumPy."""
151
  if isinstance(seg_map, torch.Tensor):
152
  return seg_map.detach().cpu().numpy()
153
  return np.array(seg_map)
154
 
155
- def union_mask_with_fallback(results, preferred_label_id=None):
156
  """
157
- Build a boolean mask. Try preferred_label_id first; if it yields no pixels,
158
- choose the label_id with the largest total area.
159
  Returns (mask_bool, chosen_label_id).
160
  """
161
  seg_map = _segmap_to_numpy(results["segmentation"])
@@ -172,13 +177,13 @@ def union_mask_with_fallback(results, preferred_label_id=None):
172
  hit = True
173
  return m, hit
174
 
175
- # 1) Try preferred id (if provided)
176
  if preferred_label_id is not None:
177
  m, hit = union_for(preferred_label_id)
178
  if hit:
179
  return m, int(preferred_label_id)
180
 
181
- # 2) Auto-pick: label with largest area
182
  area_per_label = {}
183
  for info in infos:
184
  sid = int(info["id"])
@@ -186,80 +191,111 @@ def union_mask_with_fallback(results, preferred_label_id=None):
186
  area_per_label[lid] = area_per_label.get(lid, 0) + int((seg_map == sid).sum())
187
  if not area_per_label:
188
  return np.zeros(seg_map.shape, dtype=bool), None
189
- best_lid = max(area_per_label, key=area_per_label.get)
 
 
190
  m, _ = union_for(best_lid)
191
  return m, int(best_lid)
192
 
193
- # ---------- Soft, true-shape overlay (no square boxes) ----------
194
- def soft_overlay(image_np: np.ndarray,
195
- canopy_mask: np.ndarray,
196
- disease_mask: np.ndarray,
197
- alpha_canopy: float = 0.30,
198
- alpha_disease: float = 0.50,
199
- sigma: float = 1.8) -> Image.Image:
200
- """
201
- Feathered compositing with no outline, so shapes follow the mask exactly.
202
- """
203
- out = image_np.astype(np.float32)
204
-
205
- def _mask_to_soft_alpha(mask_bool: np.ndarray, sigma_val: float) -> np.ndarray:
206
- if not mask_bool.any():
207
- return np.zeros_like(mask_bool, dtype=np.float32)
208
- m = Image.fromarray((mask_bool.astype(np.uint8) * 255))
209
- m = m.filter(ImageFilter.GaussianBlur(radius=max(0.1, float(sigma_val))))
210
- return (np.asarray(m, dtype=np.float32) / 255.0)
211
-
212
- # Canopy first (orange), then disease (blue) on top
213
- ca = _mask_to_soft_alpha(canopy_mask, sigma) * alpha_canopy
214
- if ca.any():
215
- for ch, col in enumerate([255, 165, 0]):
216
- out[..., ch] = out[..., ch] * (1.0 - ca) + col * ca
217
-
218
- da = _mask_to_soft_alpha(disease_mask, sigma) * alpha_disease
219
- if da.any():
220
- for ch, col in enumerate([0, 0, 255]):
221
- out[..., ch] = out[..., ch] * (1.0 - da) + col * da
222
-
223
- out = np.clip(out, 0, 255).astype(np.uint8)
224
- return Image.fromarray(out)
225
- # ---------------------------------------------------------------
 
 
 
 
 
 
 
226
 
227
  def process_image(image: Image.Image):
228
- # Slightly larger default improves mask smoothness
229
- MAX_SIDE = int(os.getenv("MAX_SIDE", "640")) # was 512
230
  im = image.copy()
231
  im.thumbnail((MAX_SIDE, MAX_SIDE), Image.LANCZOS)
232
-
233
  image_np = np.array(im)
234
 
235
  # Run BOTH models
236
  canopy_results = run_inference(canopy_model, im)
237
  disease_results = run_inference(disease_model, im)
238
 
239
- # Build masks, using preferred label ids, with fallbacks to largest-area label
240
- canopy_mask_bool, chosen_canopy_id = union_mask_with_fallback(canopy_results, PREF_CANOPY_ID)
241
- disease_mask_bool, chosen_disease_id = union_mask_with_fallback(disease_results, PREF_DISEASE_ID)
 
 
 
 
 
 
 
 
242
 
243
  # If canopy not found, bail gracefully
244
- if canopy_mask_bool.sum() == 0:
245
  return im, Image.fromarray(image_np), 0.0
246
 
247
- # Only count disease inside canopy
248
- disease_in_canopy = disease_mask_bool & canopy_mask_bool
249
- canopy_pixels = int(canopy_mask_bool.sum())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  disease_pixels = int(disease_in_canopy.sum())
251
  disease_pct = round((disease_pixels / canopy_pixels) * 100, 2) if canopy_pixels > 0 else 0.0
252
 
253
- # Soft, more transparent overlay
254
- overlay = soft_overlay(
255
- image_np,
256
- canopy_mask=canopy_mask_bool,
257
- disease_mask=disease_in_canopy,
258
- alpha_canopy=0.30,
259
- alpha_disease=0.50,
260
- sigma=1.8
261
- )
262
-
263
  return im, overlay, disease_pct
264
 
265
  # ---- UI ----
@@ -284,7 +320,6 @@ user_description = st.sidebar.text_area(
284
  )
285
 
286
  if uploaded_file:
287
- from PIL import Image # local import to keep startup snappy
288
  image = Image.open(uploaded_file).convert("RGB")
289
  st.write("**Running Segmentation..**")
290
 
@@ -297,9 +332,9 @@ if uploaded_file:
297
  small, overlay_image, disease_percentage = process_image(image)
298
  col1, col2 = st.columns([1, 1])
299
  with col1:
300
- st.image(small, caption="Original Image", width=350)
301
  with col2:
302
- st.image(overlay_image, caption="Segmented Image (canopy=orange, disease=blue)", width=350)
303
  except Exception as e:
304
  st.warning(f"Segmentation failed gracefully: {type(e).__name__}: {e}")
305
  st.image(image, caption="Original Image", width=350)
 
1
+ # app.py
2
  import os
3
 
4
  # ---- Bind to Spaces port/address & disable usage stats ----
 
6
  os.environ["STREAMLIT_SERVER_ADDRESS"] = "0.0.0.0"
7
  os.environ.setdefault("STREAMLIT_BROWSER_GATHERUSAGESTATS", "false")
8
 
9
+ # ---- Writable dirs on Spaces ----
10
  BASE_DIR = os.getenv("APP_WRITE_DIR", "/tmp")
11
  os.environ.setdefault("HOME", BASE_DIR)
12
  os.environ.setdefault("XDG_CACHE_HOME", f"{BASE_DIR}/.cache")
13
  os.environ.setdefault("XDG_CONFIG_HOME", f"{BASE_DIR}/.config")
14
  os.environ.setdefault("HF_HOME", f"{BASE_DIR}/.cache/huggingface")
15
+ for p in [
16
+ f"{BASE_DIR}/.cache",
17
+ f"{BASE_DIR}/.config",
18
+ f"{BASE_DIR}/.cache/huggingface",
19
+ f"{BASE_DIR}/.streamlit",
20
+ ]:
21
  os.makedirs(p, exist_ok=True)
22
 
23
  import streamlit as st
 
26
  layout="wide",
27
  )
28
 
29
+ import gc
30
  import torch
31
  import numpy as np
32
  from PIL import Image, ImageFilter
 
37
  )
38
  from huggingface_hub import hf_hub_download
39
 
40
+ # ---- Keep CPU predictable/light ----
41
  try:
42
  torch.set_num_threads(1)
43
  torch.set_num_interop_threads(1)
 
93
 
94
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
95
 
96
+ PREF_CANOPY_ID = os.getenv("CANOPY_CLASS_ID")
97
+ PREF_DISEASE_ID = os.getenv("DISEASE_CLASS_ID")
98
+ PREF_CANOPY_ID = int(PREF_CANOPY_ID) if PREF_CANOPY_ID is not None else None
99
+ PREF_DISEASE_ID = int(PREF_DISEASE_ID) if PREF_DISEASE_ID is not None else None
 
100
 
101
  # ---- Model loader helpers ----
102
  def _load_state(weights_path: str):
103
  if weights_path.endswith(".safetensors"):
104
  from safetensors.torch import load_file as safe_load_file
105
  return safe_load_file(weights_path)
106
+ # .bin
107
  return torch.load(weights_path, map_location="cpu", weights_only=True)
108
 
109
  def _load_maskformer(config_path: str, weights_path: str):
 
128
  canopy_model, disease_model, processor, SEG_LOAD_ERR = load_segmentation_models()
129
  SEG_READY = (canopy_model is not None) and (disease_model is not None)
130
 
131
+ # ---- Optional text generator (kept lazy) ----
132
  @st.cache_resource
133
  def get_text_generator():
134
  try:
 
137
  except Exception as e:
138
  return None, f"{type(e).__name__}: {e}"
139
 
140
+ # ---- Inference & mask helpers ----
141
  @torch.inference_mode()
142
  def run_inference(model, image):
143
  inputs = processor(image, return_tensors="pt", do_resize=False)
 
150
  mask_threshold=0.5,
151
  overlap_mask_area_threshold=0.5,
152
  )[0]
153
+ return results
154
 
155
  def _segmap_to_numpy(seg_map):
 
156
  if isinstance(seg_map, torch.Tensor):
157
  return seg_map.detach().cpu().numpy()
158
  return np.array(seg_map)
159
 
160
+ def union_mask_with_fallback(results, preferred_label_id=None, prefer_smallest=False):
161
  """
162
+ Build a boolean mask. If preferred_label_id is provided and present, use it.
163
+ Otherwise choose by area: largest by default, or smallest if prefer_smallest=True.
164
  Returns (mask_bool, chosen_label_id).
165
  """
166
  seg_map = _segmap_to_numpy(results["segmentation"])
 
177
  hit = True
178
  return m, hit
179
 
180
+ # Preferred first
181
  if preferred_label_id is not None:
182
  m, hit = union_for(preferred_label_id)
183
  if hit:
184
  return m, int(preferred_label_id)
185
 
186
+ # Compute area per label_id
187
  area_per_label = {}
188
  for info in infos:
189
  sid = int(info["id"])
 
191
  area_per_label[lid] = area_per_label.get(lid, 0) + int((seg_map == sid).sum())
192
  if not area_per_label:
193
  return np.zeros(seg_map.shape, dtype=bool), None
194
+
195
+ # Choose by area
196
+ best_lid = (min if prefer_smallest else max)(area_per_label, key=area_per_label.get)
197
  m, _ = union_for(best_lid)
198
  return m, int(best_lid)
199
 
200
+ def smooth_bool_mask(mask_bool, radius=1.1):
201
+ """Slight Gaussian blur + threshold to de-jag mask edges."""
202
+ if mask_bool.dtype != np.bool_:
203
+ mask_bool = mask_bool.astype(bool)
204
+ im = Image.fromarray((mask_bool * 255).astype(np.uint8))
205
+ im = im.filter(ImageFilter.GaussianBlur(radius=radius))
206
+ arr = np.array(im) > 127
207
+ return arr
208
+
209
+ def erode_4n(mask):
210
+ """Simple 4-neighbour erosion (approx) without SciPy."""
211
+ up = np.roll(mask, -1, axis=0)
212
+ down = np.roll(mask, 1, axis=0)
213
+ left = np.roll(mask, -1, axis=1)
214
+ right = np.roll(mask, 1, axis=1)
215
+ return mask & up & down & left & right
216
+
217
+ def mask_outline(mask):
218
+ """1px outline from a boolean mask."""
219
+ er = erode_4n(mask)
220
+ return mask & (~er)
221
+
222
+ def to_color(mask_bool, image_np, color):
223
+ rgb = np.zeros_like(image_np)
224
+ rgb[mask_bool] = color
225
+ return rgb
226
+
227
+ def create_overlay(image_np, canopy_bool, disease_bool, alpha_canopy=0.35, alpha_disease=0.45):
228
+ canopy_rgb = to_color(canopy_bool, image_np, [255, 165, 0]) # orange
229
+ disease_rgb = to_color(disease_bool, image_np, [ 0, 0, 255]) # blue
230
+
231
+ # Blend canopy first, then disease (so disease remains visible over canopy)
232
+ overlay1 = (image_np * (1 - alpha_canopy) + canopy_rgb * alpha_canopy).astype(np.uint8)
233
+ overlay2 = (overlay1 * (1 - alpha_disease) + disease_rgb * alpha_disease).astype(np.uint8)
234
+
235
+ # Thin outline around disease to improve visibility
236
+ edges = mask_outline(disease_bool)
237
+ overlay2[edges] = [ 10, 74, 255]
238
+
239
+ return Image.fromarray(overlay2)
240
 
241
  def process_image(image: Image.Image):
242
+ MAX_SIDE = int(os.getenv("MAX_SIDE", "768"))
 
243
  im = image.copy()
244
  im.thumbnail((MAX_SIDE, MAX_SIDE), Image.LANCZOS)
 
245
  image_np = np.array(im)
246
 
247
  # Run BOTH models
248
  canopy_results = run_inference(canopy_model, im)
249
  disease_results = run_inference(disease_model, im)
250
 
251
+ # Canopy: prefer pinned id or (fallback) largest area
252
+ canopy_mask, chosen_canopy_id = union_mask_with_fallback(
253
+ canopy_results, PREF_CANOPY_ID, prefer_smallest=False
254
+ )
255
+ canopy_mask = smooth_bool_mask(canopy_mask, radius=1.0)
256
+
257
+ # Disease: prefer pinned id or (fallback) smallest area (spots)
258
+ disease_mask, chosen_disease_id = union_mask_with_fallback(
259
+ disease_results, PREF_DISEASE_ID, prefer_smallest=True
260
+ )
261
+ disease_mask = smooth_bool_mask(disease_mask, radius=0.8)
262
 
263
  # If canopy not found, bail gracefully
264
+ if canopy_mask.sum() == 0:
265
  return im, Image.fromarray(image_np), 0.0
266
 
267
+ # Only count disease INSIDE canopy
268
+ disease_in_canopy = disease_mask & canopy_mask
269
+
270
+ frac = disease_in_canopy.sum() / max(1, canopy_mask.sum())
271
+ if frac > 0.6:
272
+ # Re-pick a smaller label (2nd smallest by area)
273
+ seg_map = _segmap_to_numpy(disease_results["segmentation"])
274
+ infos = disease_results.get("segments_info", [])
275
+ area_per_label = {}
276
+ for info in infos:
277
+ sid = int(info["id"]); lid = int(info["label_id"])
278
+ area_per_label.setdefault(lid, 0)
279
+ area_per_label[lid] += int((seg_map == sid).sum())
280
+ if len(area_per_label) > 1:
281
+ small_to_large = sorted(area_per_label.items(), key=lambda kv: kv[1])
282
+ for lid, _ in small_to_large:
283
+ if PREF_DISEASE_ID is not None and lid == PREF_DISEASE_ID:
284
+ continue
285
+ m = np.zeros_like(seg_map, dtype=bool)
286
+ for info in infos:
287
+ if int(info["label_id"]) == lid:
288
+ m |= (seg_map == int(info["id"]))
289
+ if m.sum() > 0:
290
+ disease_mask = smooth_bool_mask(m, radius=0.8)
291
+ disease_in_canopy = disease_mask & canopy_mask
292
+ break
293
+
294
+ canopy_pixels = int(canopy_mask.sum())
295
  disease_pixels = int(disease_in_canopy.sum())
296
  disease_pct = round((disease_pixels / canopy_pixels) * 100, 2) if canopy_pixels > 0 else 0.0
297
 
298
+ overlay = create_overlay(image_np, canopy_mask, disease_in_canopy)
 
 
 
 
 
 
 
 
 
299
  return im, overlay, disease_pct
300
 
301
  # ---- UI ----
 
320
  )
321
 
322
  if uploaded_file:
 
323
  image = Image.open(uploaded_file).convert("RGB")
324
  st.write("**Running Segmentation..**")
325
 
 
332
  small, overlay_image, disease_percentage = process_image(image)
333
  col1, col2 = st.columns([1, 1])
334
  with col1:
335
+ st.image(small, caption="Original Image", width=450)
336
  with col2:
337
+ st.image(overlay_image, caption="Segmented Image (canopy=orange, disease=blue)", width=450)
338
  except Exception as e:
339
  st.warning(f"Segmentation failed gracefully: {type(e).__name__}: {e}")
340
  st.image(image, caption="Original Image", width=350)