Spaces:
Running
on
Zero
Running
on
Zero
derektan
commited on
Commit
Β·
3cbeaeb
1
Parent(s):
f118874
[NEW] Able to launch both gifs concurrently
Browse files
app.py
CHANGED
|
@@ -21,6 +21,7 @@ import os, glob, threading, time
|
|
| 21 |
import torch
|
| 22 |
from PIL import Image
|
| 23 |
import json
|
|
|
|
| 24 |
import spaces # integration with ZeroGPU on hf
|
| 25 |
|
| 26 |
# Import configuration & RL / TTA utilities -------------------------------------------------
|
|
@@ -66,8 +67,8 @@ print('Model loaded!')
|
|
| 66 |
|
| 67 |
# Init Taxabind here (only need to init once)
|
| 68 |
if TAXABIND_TTA:
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
img_dir=TAXABIND_IMG_DIR,
|
| 72 |
imo_dir=TAXABIND_IMO_DIR,
|
| 73 |
json_path=TAXABIND_INAT_JSON_PATH,
|
|
@@ -85,9 +86,28 @@ if TAXABIND_TTA:
|
|
| 85 |
sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
|
| 86 |
# sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
|
| 87 |
)
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
else:
|
| 90 |
-
clip_seg_tta = None
|
| 91 |
|
| 92 |
# Load metadata json
|
| 93 |
tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
|
|
@@ -102,132 +122,120 @@ tgts_metadata = json.load(open(tgts_metadata_json_path))
|
|
| 102 |
# object. By defining explicit generator functions (with `yield from`) we ensure
|
| 103 |
# `inspect.isgeneratorfunction` evaluates to True and Gradio streams correctly.
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
taxonomy: str | None = None,
|
| 109 |
-
):
|
| 110 |
-
"""Stream search episode **with** TTA enabled while disabling buttons."""
|
| 111 |
-
# Disable buttons initially (image reset)
|
| 112 |
-
yield gr.update(interactive=False), gr.update(interactive=False), gr.update(value=None), gr.update(value="Initializing modelβ¦")
|
| 113 |
-
|
| 114 |
-
last_img = None
|
| 115 |
-
for img in process(sat_path, ground_path, taxonomy, True):
|
| 116 |
-
last_img = img
|
| 117 |
-
yield gr.update(interactive=False), gr.update(interactive=False), img, gr.update(value="Runningβ¦")
|
| 118 |
-
|
| 119 |
-
# Re-enable buttons at the end
|
| 120 |
-
yield gr.update(interactive=True), gr.update(interactive=True), last_img, gr.update(value="Done.")
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def process_no_tta(
|
| 125 |
sat_path: str | None,
|
| 126 |
ground_path: str | None,
|
| 127 |
taxonomy: str | None = None,
|
| 128 |
):
|
| 129 |
-
"""
|
| 130 |
-
yield gr.update(interactive=False), gr.update(interactive=False), gr.update(value=None), gr.update(value="Initializing modelβ¦")
|
| 131 |
-
|
| 132 |
-
last_img = None
|
| 133 |
-
for img in process(sat_path, ground_path, taxonomy, False):
|
| 134 |
-
last_img = img
|
| 135 |
-
yield gr.update(interactive=False), gr.update(interactive=False), img, gr.update(value="Runningβ¦")
|
| 136 |
-
|
| 137 |
-
yield gr.update(interactive=True), gr.update(interactive=True), last_img, gr.update(value="Done.")
|
| 138 |
|
|
|
|
|
|
|
| 139 |
|
| 140 |
-
#
|
| 141 |
-
# @spaces.GPU
|
| 142 |
-
def process(
|
| 143 |
-
sat_path: str | None,
|
| 144 |
-
ground_path: str | None,
|
| 145 |
-
taxonomy: str | None = None,
|
| 146 |
-
with_tta: bool = True,
|
| 147 |
-
):
|
| 148 |
-
"""Callback executed when the user presses **Run** in the UI.
|
| 149 |
-
|
| 150 |
-
At test-time we simply trigger the RL search episode via
|
| 151 |
-
``planner.run_episode`` and return its performance metrics.
|
| 152 |
-
The image inputs are currently *not* used directly here but are
|
| 153 |
-
retained to conform to the requested interface.
|
| 154 |
-
"""
|
| 155 |
-
# If no satellite image is provided we bail out early.
|
| 156 |
if sat_path is None:
|
| 157 |
-
|
|
|
|
| 158 |
|
| 159 |
-
#
|
| 160 |
-
# Load images from paths and configure ClipSegTTA inputs
|
| 161 |
sat_img = Image.open(sat_path).convert("RGB")
|
| 162 |
ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None
|
| 163 |
-
tgts = [tuple(tgt) for tgt in tgts_metadata[taxonomy]["target_positions"]]
|
| 164 |
-
|
| 165 |
-
clip_seg_tta.img_paths = [ground_path] if ground_path else []
|
| 166 |
-
clip_seg_tta.imo_path = sat_path
|
| 167 |
-
clip_seg_tta.imgs = ([clip_seg_tta.dataset.img_transform(ground_img_pil).to(device)]
|
| 168 |
-
if ground_img_pil else [])
|
| 169 |
-
clip_seg_tta.imo = clip_seg_tta.dataset.imo_transform(sat_img).to(device)
|
| 170 |
-
clip_seg_tta.sounds = []
|
| 171 |
-
clip_seg_tta.sound_ids = [] # None
|
| 172 |
-
clip_seg_tta.species_name = taxonomy or ""
|
| 173 |
-
clip_seg_tta.gt_mask_name = taxonomy.replace(" ", "_") # None
|
| 174 |
-
clip_seg_tta.target_positions = tgts if tgts != [] else [(0,0)]
|
| 175 |
-
|
| 176 |
-
# Define TestWorker
|
| 177 |
-
planner = TestWorker(
|
| 178 |
-
meta_agent_id=0,
|
| 179 |
-
n_agent=1,
|
| 180 |
-
policy_net=policy_net,
|
| 181 |
-
global_step=-1,
|
| 182 |
-
device=device,
|
| 183 |
-
greedy=True,
|
| 184 |
-
save_image=SAVE_GIFS,
|
| 185 |
-
clip_seg_tta=clip_seg_tta
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
# ------------------------------------------------------------------
|
| 189 |
|
| 190 |
-
#
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
#
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
-
sent: set[str] = set()
|
| 208 |
-
last_img: str | None = None
|
| 209 |
try:
|
| 210 |
-
while
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
|
| 214 |
for fp in pngs:
|
| 215 |
-
if fp not in
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
time.sleep(POLL_INTERVAL)
|
| 220 |
finally:
|
| 221 |
-
#
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
# If the episode finished naturally, send the last frame once more
|
| 227 |
-
if last_img is not None:
|
| 228 |
-
yield last_img
|
| 229 |
|
| 230 |
-
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
# ββββββββββββββββββββββββββ Gradio UI βββββββββββββββββββββββββββββββββ
|
|
@@ -280,15 +288,14 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
|
|
| 280 |
type="filepath",
|
| 281 |
height=320,
|
| 282 |
)
|
| 283 |
-
|
| 284 |
-
run_no_tta_btn = gr.Button("Run (without TTA)", variant="secondary")
|
| 285 |
|
| 286 |
with gr.Column():
|
| 287 |
gr.Markdown("### Live Heatmap (with TTA)")
|
| 288 |
-
display_img_tta = gr.Image(label="Heatmap (TTA)", type="filepath", height=512
|
| 289 |
status_tta = gr.Markdown("")
|
| 290 |
gr.Markdown("### Live Heatmap (without TTA)")
|
| 291 |
-
display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=512
|
| 292 |
status_no_tta = gr.Markdown("")
|
| 293 |
|
| 294 |
# Bind callback
|
|
@@ -321,21 +328,16 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
|
|
| 321 |
],
|
| 322 |
],
|
| 323 |
inputs=[sat_input, ground_input, taxonomy_input],
|
| 324 |
-
outputs=[
|
| 325 |
-
fn=
|
| 326 |
cache_examples=False,
|
| 327 |
)
|
| 328 |
|
| 329 |
|
| 330 |
-
|
| 331 |
-
fn=
|
| 332 |
-
inputs=[sat_input, ground_input, taxonomy_input],
|
| 333 |
-
outputs=[run_tta_btn, run_no_tta_btn, display_img_tta, status_tta],
|
| 334 |
-
)
|
| 335 |
-
run_no_tta_btn.click(
|
| 336 |
-
fn=process_no_tta,
|
| 337 |
inputs=[sat_input, ground_input, taxonomy_input],
|
| 338 |
-
outputs=[
|
| 339 |
)
|
| 340 |
|
| 341 |
# Footer to point out to model and data from app page.
|
|
|
|
| 21 |
import torch
|
| 22 |
from PIL import Image
|
| 23 |
import json
|
| 24 |
+
import copy
|
| 25 |
import spaces # integration with ZeroGPU on hf
|
| 26 |
|
| 27 |
# Import configuration & RL / TTA utilities -------------------------------------------------
|
|
|
|
| 67 |
|
| 68 |
# Init Taxabind here (only need to init once)
|
| 69 |
if TAXABIND_TTA:
|
| 70 |
+
# Instantiate TWO independent ClipSegTTA objects (one per concurrent run)
|
| 71 |
+
clip_seg_tta_1 = ClipSegTTA(
|
| 72 |
img_dir=TAXABIND_IMG_DIR,
|
| 73 |
imo_dir=TAXABIND_IMO_DIR,
|
| 74 |
json_path=TAXABIND_INAT_JSON_PATH,
|
|
|
|
| 86 |
sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
|
| 87 |
# sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
|
| 88 |
)
|
| 89 |
+
clip_seg_tta_2 = ClipSegTTA(
|
| 90 |
+
img_dir=TAXABIND_IMG_DIR,
|
| 91 |
+
imo_dir=TAXABIND_IMO_DIR,
|
| 92 |
+
json_path=TAXABIND_INAT_JSON_PATH,
|
| 93 |
+
sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
|
| 94 |
+
patch_size=TAXABIND_PATCH_SIZE,
|
| 95 |
+
sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
|
| 96 |
+
sample_index = -1, # Set using 'reset' in worker
|
| 97 |
+
blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
|
| 98 |
+
device=device,
|
| 99 |
+
sat_to_img_ids_json_is_train_dict=False,
|
| 100 |
+
tax_to_filter_val=QUERY_TAX,
|
| 101 |
+
load_model=USE_CLIP_PREDS,
|
| 102 |
+
initial_modality=INITIAL_MODALITY,
|
| 103 |
+
sound_data_path=TAXABIND_SOUND_DATA_PATH,
|
| 104 |
+
sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
|
| 105 |
+
)
|
| 106 |
+
print("ClipSegTTA instances loaded!")
|
| 107 |
+
# Keep original name for single-run mode compatibility
|
| 108 |
+
clip_seg_tta = clip_seg_tta_1
|
| 109 |
else:
|
| 110 |
+
clip_seg_tta_1 = clip_seg_tta_2 = clip_seg_tta = None
|
| 111 |
|
| 112 |
# Load metadata json
|
| 113 |
tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
|
|
|
|
| 122 |
# object. By defining explicit generator functions (with `yield from`) we ensure
|
| 123 |
# `inspect.isgeneratorfunction` evaluates to True and Gradio streams correctly.
|
| 124 |
|
| 125 |
+
# # # integration with ZeroGPU on hf
|
| 126 |
+
# @spaces.GPU
|
| 127 |
+
def process_search_tta(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
sat_path: str | None,
|
| 129 |
ground_path: str | None,
|
| 130 |
taxonomy: str | None = None,
|
| 131 |
):
|
| 132 |
+
"""Run both TTA and non-TTA search episodes concurrently and stream both heat-maps."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
+
# Disable Run button and clear image/status outputs
|
| 135 |
+
yield gr.update(interactive=False), gr.update(value=None), gr.update(value=None), gr.update(value="Initializing modelβ¦"), gr.update(value="Initializing modelβ¦")
|
| 136 |
|
| 137 |
+
# Bail early if satellite image missing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
if sat_path is None:
|
| 139 |
+
yield gr.update(interactive=True), gr.update(value=None), gr.update(value=None), gr.update(value="No satellite image provided."), gr.update(value="")
|
| 140 |
+
return
|
| 141 |
|
| 142 |
+
# Prepare PIL images
|
|
|
|
| 143 |
sat_img = Image.open(sat_path).convert("RGB")
|
| 144 |
ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
# Lookup target positions metadata (may be empty)
|
| 147 |
+
tgt_positions = []
|
| 148 |
+
if taxonomy and taxonomy in tgts_metadata:
|
| 149 |
+
tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]]
|
| 150 |
+
|
| 151 |
+
# Helper to build a TestWorker with/without TTA
|
| 152 |
+
def build_planner(enable_tta: bool, save_dir: str, clip_obj):
|
| 153 |
+
local_clip = clip_obj # re-use the pre-instantiated ClipSegTTA
|
| 154 |
+
if local_clip is not None:
|
| 155 |
+
# Feed inputs to ClipSegTTA copy
|
| 156 |
+
local_clip.img_paths = [ground_path] if ground_path else []
|
| 157 |
+
local_clip.imo_path = sat_path
|
| 158 |
+
local_clip.imgs = ([local_clip.dataset.img_transform(ground_img_pil).to(device)] if ground_img_pil else [])
|
| 159 |
+
local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device)
|
| 160 |
+
local_clip.sounds = []
|
| 161 |
+
local_clip.sound_ids = []
|
| 162 |
+
local_clip.species_name = taxonomy or ""
|
| 163 |
+
local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else ""
|
| 164 |
+
local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)]
|
| 165 |
+
|
| 166 |
+
planner = TestWorker(
|
| 167 |
+
meta_agent_id=0,
|
| 168 |
+
n_agent=1,
|
| 169 |
+
policy_net=policy_net,
|
| 170 |
+
global_step=-1,
|
| 171 |
+
device=device,
|
| 172 |
+
greedy=True,
|
| 173 |
+
save_image=SAVE_GIFS,
|
| 174 |
+
clip_seg_tta=local_clip,
|
| 175 |
+
)
|
| 176 |
+
planner.execute_tta = enable_tta
|
| 177 |
+
planner.gifs_path = save_dir
|
| 178 |
+
return planner
|
| 179 |
+
|
| 180 |
+
# Prepare gif directories
|
| 181 |
+
gifs_dir_tta = os.path.join(gifs_path, "with_tta")
|
| 182 |
+
gifs_dir_no = os.path.join(gifs_path, "no_tta")
|
| 183 |
+
for d in (gifs_dir_tta, gifs_dir_no):
|
| 184 |
+
os.makedirs(d, exist_ok=True)
|
| 185 |
+
# Clean previous pngs
|
| 186 |
+
for f in os.listdir(d):
|
| 187 |
+
# if f.endswith(".png"):
|
| 188 |
+
os.remove(os.path.join(d, f))
|
| 189 |
+
|
| 190 |
+
planner_tta = build_planner(True, gifs_dir_tta, clip_seg_tta_1)
|
| 191 |
+
planner_no = build_planner(False, gifs_dir_no, clip_seg_tta_2)
|
| 192 |
+
|
| 193 |
+
# Launch both planners in background threads
|
| 194 |
+
thread_tta = threading.Thread(target=planner_tta.run_episode, args=(0,), daemon=True)
|
| 195 |
+
thread_no = threading.Thread(target=planner_no.run_episode, args=(0,), daemon=True)
|
| 196 |
+
thread_tta.start()
|
| 197 |
+
thread_no.start()
|
| 198 |
+
|
| 199 |
+
sent_tta: set[str] = set()
|
| 200 |
+
sent_no: set[str] = set()
|
| 201 |
+
last_tta = None
|
| 202 |
+
last_no = None
|
| 203 |
|
|
|
|
|
|
|
| 204 |
try:
|
| 205 |
+
while thread_tta.is_alive() or thread_no.is_alive():
|
| 206 |
+
updated = False
|
| 207 |
+
# Collect new frames from TTA dir
|
| 208 |
+
pngs = glob.glob(os.path.join(gifs_dir_tta, "*.png"))
|
| 209 |
+
pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
|
| 210 |
+
for fp in pngs:
|
| 211 |
+
if fp not in sent_tta:
|
| 212 |
+
sent_tta.add(fp)
|
| 213 |
+
last_tta = fp
|
| 214 |
+
updated = True
|
| 215 |
+
# Collect new frames from no-TTA dir
|
| 216 |
+
pngs = glob.glob(os.path.join(gifs_dir_no, "*.png"))
|
| 217 |
pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
|
| 218 |
for fp in pngs:
|
| 219 |
+
if fp not in sent_no:
|
| 220 |
+
sent_no.add(fp)
|
| 221 |
+
last_no = fp
|
| 222 |
+
updated = True
|
| 223 |
+
|
| 224 |
+
if updated:
|
| 225 |
+
status_tta = "Runningβ¦" if thread_tta.is_alive() else "Done."
|
| 226 |
+
status_no = "Runningβ¦" if thread_no.is_alive() else "Done."
|
| 227 |
+
yield gr.update(interactive=False), last_tta, last_no, gr.update(value=status_tta), gr.update(value=status_no)
|
| 228 |
+
|
| 229 |
time.sleep(POLL_INTERVAL)
|
| 230 |
finally:
|
| 231 |
+
# Ensure background threads are stopped on cancel
|
| 232 |
+
for th in (thread_tta, thread_no):
|
| 233 |
+
if th.is_alive():
|
| 234 |
+
_stop_thread(th)
|
| 235 |
+
th.join(timeout=1)
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
+
# Final emit after both finish
|
| 238 |
+
yield gr.update(interactive=True), last_tta, last_no, gr.update(value="Done."), gr.update(value="Done.")
|
| 239 |
|
| 240 |
|
| 241 |
# ββββββββββββββββββββββββββ Gradio UI βββββββββββββββββββββββββββββββββ
|
|
|
|
| 288 |
type="filepath",
|
| 289 |
height=320,
|
| 290 |
)
|
| 291 |
+
run_btn = gr.Button("Run Search-TTA", variant="primary")
|
|
|
|
| 292 |
|
| 293 |
with gr.Column():
|
| 294 |
gr.Markdown("### Live Heatmap (with TTA)")
|
| 295 |
+
display_img_tta = gr.Image(label="Heatmap (TTA)", type="filepath", height=400) # 512
|
| 296 |
status_tta = gr.Markdown("")
|
| 297 |
gr.Markdown("### Live Heatmap (without TTA)")
|
| 298 |
+
display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400) # 512
|
| 299 |
status_no_tta = gr.Markdown("")
|
| 300 |
|
| 301 |
# Bind callback
|
|
|
|
| 328 |
],
|
| 329 |
],
|
| 330 |
inputs=[sat_input, ground_input, taxonomy_input],
|
| 331 |
+
outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta],
|
| 332 |
+
fn=process_search_tta,
|
| 333 |
cache_examples=False,
|
| 334 |
)
|
| 335 |
|
| 336 |
|
| 337 |
+
run_btn.click(
|
| 338 |
+
fn=process_search_tta,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
inputs=[sat_input, ground_input, taxonomy_input],
|
| 340 |
+
outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta],
|
| 341 |
)
|
| 342 |
|
| 343 |
# Footer to point out to model and data from app page.
|
env.py
CHANGED
|
@@ -816,8 +816,9 @@ class Env():
|
|
| 816 |
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 817 |
cbar.set_label("Normalized Probs")
|
| 818 |
|
| 819 |
-
|
| 820 |
-
|
|
|
|
| 821 |
|
| 822 |
os.makedirs(save_dir, exist_ok=True)
|
| 823 |
out_path = os.path.join(save_dir, f"{step}.png")
|
|
|
|
| 816 |
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 817 |
cbar.set_label("Normalized Probs")
|
| 818 |
|
| 819 |
+
# Change coverage to 1dp
|
| 820 |
+
plt.suptitle('Targets Found: {}/{} Coverage: {:.1f}% Steps: {}/{}'.format(self.num_targets_found, \
|
| 821 |
+
len(self.target_positions), self.explored_rate*100, step, NUM_EPS_STEPS))
|
| 822 |
|
| 823 |
os.makedirs(save_dir, exist_ok=True)
|
| 824 |
out_path = os.path.join(save_dir, f"{step}.png")
|