derektan commited on
Commit
ab93c8a
·
1 Parent(s): 212e8c3

Moved model loading into thread, for better concurrency and speed

Browse files
Files changed (1) hide show
  1. app.py +28 -19
app.py CHANGED
@@ -22,6 +22,7 @@ import torch
22
  from PIL import Image
23
  import json
24
  import copy
 
25
  import spaces # integration with ZeroGPU on hf
26
 
27
  # Import configuration & RL / TTA utilities -------------------------------------------------
@@ -205,28 +206,36 @@ def process_search_tta(
205
  planner.gifs_path = save_dir
206
  return planner
207
 
208
- # Prepare gif directories
209
  gifs_dir_tta = os.path.join(gifs_path, "with_tta")
210
- gifs_dir_no = os.path.join(gifs_path, "no_tta")
211
- for d in (gifs_dir_tta, gifs_dir_no):
212
- os.makedirs(d, exist_ok=True)
213
- # Clean previous pngs
214
- for f in os.listdir(d):
215
- # if f.endswith(".png"):
216
- os.remove(os.path.join(d, f))
217
-
218
- planner_tta = build_planner(True, gifs_dir_tta, clip_seg_tta_1)
219
- planner_no = build_planner(False, gifs_dir_no, clip_seg_tta_2)
220
-
221
- # Launch both planners in background threads
222
- thread_tta = threading.Thread(target=planner_tta.run_episode, args=(0,), daemon=True)
223
- thread_no = threading.Thread(target=planner_no.run_episode, args=(0,), daemon=True)
224
- thread_tta.start()
225
- thread_no.start()
226
-
227
- # Register threads so they can be cancelled when user switches tabs
 
 
 
 
 
 
228
  _register_thread(thread_tta)
229
  _register_thread(thread_no)
 
 
230
 
231
 
232
  sent_tta: set[str] = set()
 
22
  from PIL import Image
23
  import json
24
  import copy
25
+ import shutil
26
  import spaces # integration with ZeroGPU on hf
27
 
28
  # Import configuration & RL / TTA utilities -------------------------------------------------
 
206
  planner.gifs_path = save_dir
207
  return planner
208
 
209
+ # Directory paths for generated frames
210
  gifs_dir_tta = os.path.join(gifs_path, "with_tta")
211
+ gifs_dir_no = os.path.join(gifs_path, "no_tta")
212
+
213
+ # Clean previous run's PNG frames before launching new planners
214
+ for _dir in (gifs_dir_tta, gifs_dir_no):
215
+ shutil.rmtree(_dir, ignore_errors=True)
216
+ os.makedirs(_dir, exist_ok=True)
217
+
218
+ def _planner_thread(enable_tta: bool, save_dir: str, clip_obj):
219
+ """Prepare directory, build planner, and run an episode (in one thread)."""
220
+ # Directory prepared by caller; just build planner and run episode
221
+ planner = build_planner(enable_tta, save_dir, clip_obj)
222
+ planner.run_episode(0)
223
+
224
+ # Launch both planners in background threads – preparation included
225
+ thread_tta = threading.Thread(
226
+ target=_planner_thread,
227
+ args=(True, gifs_dir_tta, clip_seg_tta_1),
228
+ daemon=True,
229
+ )
230
+ thread_no = threading.Thread(
231
+ target=_planner_thread,
232
+ args=(False, gifs_dir_no, clip_seg_tta_2),
233
+ daemon=True,
234
+ )
235
  _register_thread(thread_tta)
236
  _register_thread(thread_no)
237
+ thread_tta.start()
238
+ thread_no.start()
239
 
240
 
241
  sent_tta: set[str] = set()