Spaces:
Running
on
Zero
Running
on
Zero
#######################################################################
# Name: multi_robot_worker.py
#
# - Runs the trained policy for robot(s) in the environment for up to
#   NUM_EPS_STEPS steps (evaluation / test mode)
# - Records robot trajectories and evaluation metrics (perf_metrics)
# - NOTE: Applicable for multiple robots
#######################################################################
| from pathlib import Path | |
| from test_parameter import * | |
| from time import time | |
| import imageio | |
| import csv | |
| import os | |
| import copy | |
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import json | |
| from env import Env | |
| from model import PolicyNet | |
| from robot import Robot | |
| from Taxabind.TaxaBind.SatBind.watershed_segmentation import WatershedBinomial | |
| from Taxabind.TaxaBind.SatBind.kmeans_clustering import CombinedSilhouetteInertiaClusterer | |
| from Taxabind.TaxaBind.SatBind.clip_seg_tta import ClipSegTTA | |
# Fail loudly on invalid / divide-by-zero float operations instead of silently
# producing nan/inf; TestWorker.norm_distance catches the FloatingPointError.
np.seterr(invalid='raise', divide='raise')


class TestWorker:
    """Roll out a trained policy for one evaluation episode and record metrics.

    Robots move over the environment graph for up to ``NUM_EPS_STEPS`` steps.
    When ``TAXABIND_TTA`` is enabled, the caller-provided ``clip_seg_tta``
    (constructed once by the driver) is reset for this sample and supplies the
    score heatmap; with ``USE_CLIP_PREDS`` and ``EXECUTE_TTA`` that heatmap is
    refined online by :meth:`poisson_tta_update`. Results are stored in
    ``self.perf_metrics``.
    """

    # Every key run_episode writes into perf_metrics, in output order.
    _METRIC_KEYS = (
        'tax', 'tax_first', 'travel_dist', 'travel_steps',
        'steps_to_first_tgt', 'steps_to_mid_tgt', 'steps_to_last_tgt',
        'explored_rate', 'targets_found', 'targets_total',
        'sim_0perc', 'sim_25perc', 'sim_50perc', 'sim_75perc', 'sim_100perc',
        'kmeans_k', 'tgts_gt_score', 'clip_inference_time', 'tta_time',
        'info_gain', 'total_info', 'success_rate',
    )

    def __init__(self, meta_agent_id, n_agent, policy_net, global_step, device='cuda', greedy=False, save_image=False, clip_seg_tta=None):
        """Build the env and robots, and (if TTA is enabled) prime clip_seg_tta.

        Args:
            meta_agent_id: identifier stored as ``self.metaAgentID``.
            n_agent: number of robots.
            policy_net: trained policy used by :meth:`select_node`.
            global_step: map index for ``Env`` and sample index for ``clip_seg_tta.reset``.
            device: torch device for all observation tensors.
            greedy: take the argmax action instead of sampling.
            save_image: render per-step frames and a gif.
            clip_seg_tta: shared ClipSegTTA instance; required when TAXABIND_TTA is set.

        Raises:
            ValueError: TAXABIND_TTA is enabled but ``clip_seg_tta`` is None.
        """
        self.device = device
        self.greedy = greedy
        self.n_agent = n_agent
        self.metaAgentID = meta_agent_id
        self.global_step = global_step
        self.k_size = K_SIZE
        self.save_image = save_image
        self.tta = TAXABIND_TTA
        self.clip_seg_tta = clip_seg_tta
        self.execute_tta = EXECUTE_TTA  # Added to interface with app.py
        self.env = Env(map_index=self.global_step, n_agent=n_agent, k_size=self.k_size, plot=save_image, test=True)
        self.local_policy_net = policy_net

        self.robot_list = []
        self.all_robot_positions = []
        for i in range(self.n_agent):
            robot_position = self.env.start_positions[i]
            self.robot_list.append(Robot(robot_id=i, position=robot_position, plot=save_image))
            self.all_robot_positions.append(robot_position)

        self.perf_metrics = dict()
        self.bad_mask_init = False
        # NOTE: Option to override gifs_path to interface with app.py
        self.gifs_path = gifs_path

        # Defaults for state that only the TTA path fills in, so later checks
        # never hit an AttributeError when TTA / mask overrides are disabled.
        self.tta_gt_seg_path = None
        self.gt_mask_loaded = False
        self.kmeans_clusterer = None
        self.kmeans_sat_embeds_clusters = None

        if self.tta:
            self._init_tta()

        # Poisson TTA scheduling state and per-episode metric trackers.
        self.step_since_tta = 0
        self.steps_to_first_tgt = None
        self.steps_to_mid_tgt = None
        self.steps_to_last_tgt = None
        self.sim_0perc = None
        self.sim_25perc = None
        self.sim_50perc = None
        self.sim_75perc = None
        self.sim_100perc = None

    def _init_tta(self):
        """Reset clip_seg_tta for this sample and apply its targets / masks / clusters to the env."""
        if self.clip_seg_tta is None:
            raise ValueError("TAXABIND_TTA is enabled but no clip_seg_tta instance was provided")

        # NOTE: reset on every worker because app.py (HF) does not allow the heatmap to persist.
        time_start = time()
        heatmap, heatmap_unnormalized, heatmap_unnormalized_initial, patch_embeds = self.clip_seg_tta.reset(sample_idx=self.global_step)
        self.clip_seg_tta.heatmap = heatmap
        self.clip_seg_tta.heatmap_unnormalized = heatmap_unnormalized
        self.clip_seg_tta.heatmap_unnormalized_initial = heatmap_unnormalized_initial
        self.clip_seg_tta.patch_embeds = patch_embeds
        print("Time taken to reset: ", time() - time_start)

        # Override target positions in env; must transpose each pose to (y, x) format.
        self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions]

        # Override the GT seg mask used by the info_gain / similarity metrics.
        if OVERRIDE_GT_MASK_DIR != "":
            self.tta_gt_seg_path = os.path.join(OVERRIDE_GT_MASK_DIR, self.clip_seg_tta.gt_mask_name)
            print("self.clip_seg_tta.gt_mask_name: ", self.clip_seg_tta.gt_mask_name)
            if os.path.exists(self.tta_gt_seg_path):
                self.env.gt_segmentation_mask = self.env.import_segmentation_mask(self.tta_gt_seg_path)
                self.gt_mask_loaded = True
            else:
                print("\n\n!!!!!! WARNING: GT mask not found at path: ", self.tta_gt_seg_path)

        # Without CLIP predictions, optionally load a precomputed score mask instead.
        if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "":
            score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name)
            print("score_mask_path: ", score_mask_path)
            if os.path.exists(score_mask_path):
                self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path)
                self.env.begin(self.env.map_start_position)
            else:
                print(f"\n\n{RED}!!!!!! ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path)
                self.bad_mask_init = True

        # Cluster the satellite patch embeddings once (same satellite map throughout the episode).
        if USE_CLIP_PREDS:
            time_start = time()
            self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer(
                k_min=1,
                k_max=8,
                k_avg_max=4,
                silhouette_threshold=0.15,
                relative_threshold=0.15,
                random_state=0,
                min_patch_size=5,  # smoothing parameter
                n_smooth_iter=2,  # smoothing parameter
                ignore_label=-1,
                plot=False,  # NOTE: Set to false since using app.py
                gifs_dir=self.gifs_path,  # NOTE: Set to self.gifs_path since using app.py
            )
            # Patch embeddings are assumed to form a square grid of side sqrt(num_patches).
            side = int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0]))
            self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict(
                patch_embeds=self.clip_seg_tta.patch_embeds,
                map_shape=(side, side),
            )
            print("Chosen k:", self.kmeans_clusterer.final_k)
            print("Time taken to cluster: ", time() - time_start)

    def _sync_env_heatmap(self):
        """Copy clip_seg_tta's heatmaps into the env as (N, 1) per-node columns (transposed, then flattened)."""
        self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1)
        self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1)

    def run_episode(self, curr_episode):
        """Run up to NUM_EPS_STEPS policy steps and fill ``self.perf_metrics``.

        Stops early when the env reports ``done``. If mask initialisation failed,
        every metric is set to None and nothing is simulated. When ``save_image``
        is set, a frame is rendered per step and a gif is written at the end.
        """
        if self.bad_mask_init:
            self.perf_metrics.update(dict.fromkeys(self._METRIC_KEYS))
            return

        eps_start = time()
        done = False
        for robot in self.robot_list:
            robot.observations = self.get_observations(robot.robot_position)
        if self.tta and USE_CLIP_PREDS:
            self._sync_env_heatmap()

        for step in range(NUM_EPS_STEPS):
            next_position_list = []
            dist_list = []
            travel_dist_list = []
            for robot_id, robot in enumerate(self.robot_list):
                # Forward pass through the policy to get the next position.
                next_position, _ = self.select_node(robot.observations)
                dist_list.append(np.linalg.norm(next_position - robot.robot_position))
                travel_dist_list.append(robot.travel_dist)
                next_position_list.append(next_position)
                self.all_robot_positions[robot_id] = next_position

            # Process robots in order of increasing step distance.
            arriving_sequence = np.argsort(dist_list)
            next_position_list = np.array(next_position_list)[arriving_sequence]
            dist_list = np.array(dist_list)[arriving_sequence]
            travel_dist_list = np.array(travel_dist_list)[arriving_sequence]

            next_position_list, dist_list = self.solve_conflict(arriving_sequence, next_position_list, dist_list)
            reward_list, done = self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list)

            # Update trajectories, (optionally) the TTA heatmap, observations and rewards.
            for reward, robot_id in zip(reward_list, arriving_sequence):
                robot = self.robot_list[robot_id]
                robot.save_trajectory_coords(self.env.find_index_from_coords(robot.robot_position), self.env.num_new_targets_found)
                if self.tta and USE_CLIP_PREDS and self.execute_tta:
                    self.poisson_tta_update(robot, self.global_step, step)
                robot.observations = self.get_observations(robot.robot_position)
                robot.save_reward_done(reward, done)

            # NOTE: metrics are tracked for 1 robot for now.
            self.log_metrics(step=step)

            if self.save_image:
                self._save_frame(step, travel_dist_list)

            if done:
                break

        self._record_episode_metrics(step, travel_dist_list, done)

        if self.save_image:
            self.make_gif(self.gifs_path, curr_episode)  # NOTE: self.gifs_path since using app.py

        print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)

    def _save_frame(self, step, travel_dist_list):
        """Render this step's heatmap and robot routes into gifs_path (only with TTA + CLIP preds)."""
        robots_route = [[robot.xPoints, robot.yPoints] for robot in self.robot_list]
        os.makedirs(self.gifs_path, exist_ok=True)
        if self.tta and USE_CLIP_PREDS:
            self.env.plot_heatmap(self.gifs_path, step, max(travel_dist_list), robots_route)

    def _record_episode_metrics(self, step, travel_dist_list, done):
        """Write the end-of-episode values into ``self.perf_metrics`` (keys as in ``_METRIC_KEYS``)."""
        if self.tta:
            tax_parts = Path(self.clip_seg_tta.gt_mask_name).stem.split("_")
            self.perf_metrics['tax'] = " ".join(tax_parts[1:])
            # Mask names without an underscore have no taxonomy part.
            self.perf_metrics['tax_first'] = tax_parts[1] if len(tax_parts) > 1 else None
        else:
            self.perf_metrics['tax'] = None
            self.perf_metrics['tax_first'] = None
        # NOTE(review): travel_dist_list holds each robot's travel_dist captured
        # *before* the final move (solve_conflict adds that leg afterwards), so
        # this excludes the last step's distance. Kept as-is for comparability.
        self.perf_metrics['travel_dist'] = max(travel_dist_list)
        self.perf_metrics['travel_steps'] = step + 1
        self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt
        self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt
        self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt
        self.perf_metrics['explored_rate'] = self.env.explored_rate
        self.perf_metrics['targets_found'] = self.env.targets_found_rate
        self.perf_metrics['targets_total'] = len(self.env.target_positions)
        self.perf_metrics['sim_0perc'] = self.sim_0perc
        self.perf_metrics['sim_25perc'] = self.sim_25perc
        self.perf_metrics['sim_50perc'] = self.sim_50perc
        self.perf_metrics['sim_75perc'] = self.sim_75perc
        self.perf_metrics['sim_100perc'] = self.sim_100perc

        # The clusterer only exists when TTA and CLIP predictions are both enabled.
        if self.kmeans_clusterer is not None:
            self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k
            self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score
            self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time
            self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time
        else:
            self.perf_metrics['kmeans_k'] = None
            self.perf_metrics['tgts_gt_score'] = None
            self.perf_metrics['clip_inference_time'] = None
            self.perf_metrics['tta_time'] = None

        if self.gt_mask_loaded:
            self.perf_metrics['info_gain'] = self.env.info_gain
            self.perf_metrics['total_info'] = self.env.total_info
        else:
            self.perf_metrics['info_gain'] = None
            self.perf_metrics['total_info'] = None

        if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0:
            self.perf_metrics['success_rate'] = True
        else:
            self.perf_metrics['success_rate'] = done

    def get_observations(self, robot_position):
        """Build the policy inputs for a robot located at ``robot_position``.

        Returns:
            Tuple ``(node_inputs, edge_inputs, current_index, node_padding_mask,
            edge_padding_mask, edge_mask)`` as consumed by :meth:`select_node`.
        """
        current_node_index = self.env.find_index_from_coords(robot_position)
        current_index = torch.tensor([current_node_index]).unsqueeze(0).unsqueeze(0).to(self.device)  # (1,1,1)

        # Node features: normalised coords, filtered segmentation score, guidepost.
        # np.concatenate copies, so env arrays are never mutated here.
        # NOTE(review): 640 is presumably the map size in pixels — confirm.
        node_coords = self.env.node_coords / 640
        node_inputs = np.concatenate((node_coords, self.env.filtered_seg_info_mask, self.env.guidepost), axis=1)
        node_inputs = torch.FloatTensor(node_inputs).unsqueeze(0).to(self.device)  # (1, n_nodes, feat)
        node_padding_mask = None  # the node set is not padded at test time

        edge_inputs = [list(map(int, neighbours)) for neighbours in self.env.graph.values()]
        bias_matrix = self.calculate_edge_mask(edge_inputs)
        edge_mask = torch.from_numpy(bias_matrix).float().unsqueeze(0).to(self.device)

        # Pad each neighbour list to k_size with index 0. Every 0 entry (padding,
        # and any genuine edge to node 0) is flagged in edge_padding_mask.
        for edges in edge_inputs:
            edges.extend([0] * (self.k_size - len(edges)))
        edge_inputs = torch.tensor(edge_inputs).unsqueeze(0).to(self.device)  # (1, n_nodes, k_size)
        edge_padding_mask = (edge_inputs == 0).to(torch.int64)

        return node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask

    def select_node(self, observations):
        """Run the policy and return ``(next_position, action_index)``.

        The action is the argmax of the policy's log-probabilities when
        ``self.greedy``, otherwise it is sampled from ``exp(logp)``.
        """
        node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask = observations
        with torch.no_grad():
            logp_list = self.local_policy_net(node_inputs, edge_inputs, current_index, node_padding_mask, edge_padding_mask, edge_mask)

        if self.greedy:
            action_index = torch.argmax(logp_list, dim=1).long()
        else:
            action_index = torch.multinomial(logp_list.exp(), 1).long().squeeze(1)

        next_node_index = edge_inputs[:, current_index.item(), action_index.item()]
        next_position = self.env.node_coords[next_node_index]
        return next_position, action_index

    def solve_conflict(self, arriving_sequence, next_position_list, dist_list):
        """Move every robot to its chosen position, in ``arriving_sequence`` order.

        NOTE: deconfliction of robots choosing the same node is currently
        disabled; this only recomputes each step distance, adds it to the
        robot's ``travel_dist`` and updates ``robot_position``.
        """
        for j, (robot_id, next_position) in enumerate(zip(arriving_sequence, next_position_list)):
            moving_robot = self.robot_list[robot_id]
            dist = np.linalg.norm(next_position - moving_robot.robot_position)
            dist_list[j] = dist
            moving_robot.travel_dist += dist
            moving_robot.robot_position = next_position
        return next_position_list, dist_list

    def work(self, currEpisode):
        """Run a single evaluation episode (metrics are stored in ``self.perf_metrics``)."""
        self.run_episode(currEpisode)

    def calculate_edge_mask(self, edge_inputs):
        """Return an ``(n, n)`` float mask: 0 where ``j`` is a neighbour of ``i``, else 1.

        ``n`` is ``len(edge_inputs)``; neighbour indices outside ``[0, n)`` are ignored.
        """
        size = len(edge_inputs)
        bias_matrix = np.ones((size, size))
        for i, neighbours in enumerate(edge_inputs):
            cols = [j for j in neighbours if 0 <= j < size]
            bias_matrix[i, cols] = 0
        return bias_matrix

    @staticmethod
    def _write_gif(gif_path, frame_files):
        """Write ``frame_files`` to a 5-fps gif, then delete every frame except the last."""
        with imageio.get_writer(gif_path, mode='I', fps=5) as writer:
            for frame in frame_files:
                writer.append_data(imageio.imread(frame))
        for filename in frame_files[:-1]:
            os.remove(filename)

    def make_gif(self, path, n):
        """Assemble the episode's frames (and k-means stats frames, if any) into gifs under ``path``."""
        self._write_gif('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), self.env.frame_files)
        print('gif complete\n')

        # K-means region-statistics gif during TTA (only when the clusterer produced frames).
        if self.kmeans_clusterer is not None and self.kmeans_clusterer.kmeans_frame_files:
            self._write_gif('{}/{}_kmeans_stats.gif'.format(path, n), self.kmeans_clusterer.kmeans_frame_files)
            print('Kmeans Clusterer gif complete\n')

    def log_metrics(self, step):
        """Update target-discovery and heatmap-similarity trackers after ``step``.

        Step counts are stored 1-based (``step + 1``). Milestones use ``>=`` so a
        step that finds several targets at once still records every milestone it passes.
        """
        num_found = self.env.num_targets_found
        num_targets = len(self.env.target_positions)
        # Half the targets, but never 0 when targets exist (int(1 / 2) == 0 would
        # record "mid" at step 1 before anything is found).
        mid_threshold = max(1, num_targets // 2) if num_targets > 0 else 0
        if self.steps_to_first_tgt is None and num_found >= 1:
            self.steps_to_first_tgt = step + 1
        if self.steps_to_mid_tgt is None and num_found >= mid_threshold:
            self.steps_to_mid_tgt = step + 1
        if self.steps_to_last_tgt is None and num_found >= num_targets:
            self.steps_to_last_tgt = step + 1

        # Similarity between predicted and GT masks, sampled at fixed fractions of the episode.
        if not self.gt_mask_loaded:
            return
        checkpoints = (
            ('sim_0perc', 0),
            ('sim_25perc', int(NUM_EPS_STEPS * 0.25)),
            ('sim_50perc', int(NUM_EPS_STEPS * 0.5)),
            ('sim_75perc', int(NUM_EPS_STEPS * 0.75)),
            ('sim_100perc', NUM_EPS_STEPS - 1),
        )
        # Independent checks: several checkpoints can share a step when NUM_EPS_STEPS is small.
        due = [name for name, at_step in checkpoints if at_step == step]
        if not due:
            return
        side_dim = int(np.sqrt(self.env.segmentation_info_mask.shape[0]))
        pred_mask = self.env.segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
        gt_mask = self.env.gt_segmentation_info_mask.squeeze().reshape((side_dim, side_dim)).T
        similarity = self.norm_distance(pred_mask, gt_mask)
        for name in due:
            setattr(self, name, similarity)

    def norm_distance(self, P, Q, norm_type="L2"):
        """Return ``1 - distance`` between min-max normalised grids ``P`` and ``Q``.

        ``"L2"`` uses RMSE and ``"L1"`` uses mean absolute difference. A constant grid
        is left unnormalised, so in that case the result may fall outside [0, 1].
        Returns None if normalisation raises FloatingPointError.

        Raises:
            ValueError: ``norm_type`` is neither ``"L1"`` nor ``"L2"``.
        """
        def _minmax(grid):
            lo, hi = grid.min(), grid.max()
            return (grid - lo) / (hi - lo) if hi != lo else grid

        try:
            P_norm = _minmax(P)
            Q_norm = _minmax(Q)
        except FloatingPointError as e:
            print(f"{RED}Caught floating point error:{NC} {e}")
            print("P min/max:", P.min(), P.max())
            print("Q min/max:", Q.min(), Q.max())
            print("Q: ", Q)
            return None

        if norm_type == "L1":
            # Mean absolute difference; at most 1 when both grids lie in [0, 1].
            return 1 - np.mean(np.abs(P_norm - Q_norm))
        if norm_type == "L2":
            # Root mean squared error; at most 1 when both grids lie in [0, 1].
            return 1 - np.sqrt(np.mean((P_norm - Q_norm) ** 2))
        raise ValueError("norm_type must be either 'L1' or 'L2'")

    def transpose_flat_idx(self, idx, H=None, W=None):
        """Map a flat row-major index in an ``H x W`` grid to its index after transposing.

        ``H`` / ``W`` default to NUM_COORDS_HEIGHT / NUM_COORDS_WIDTH, resolved at call time.
        """
        if H is None:
            H = NUM_COORDS_HEIGHT
        if W is None:
            W = NUM_COORDS_WIDTH
        row, col = divmod(idx, W)
        return col * H + row

    def poisson_tta_update(self, robot, episode, step):
        """Refresh k-means region stats and, when scheduled, run one TTA update of the heatmap.

        TTA runs when the robot's latest trajectory entry found a target, or every
        STEPS_PER_TTA calls (this also allows it before any positive sample).
        """
        # TODO: Move into TTA loop to save computation
        visited_indices = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords]
        region_stats_dict = self.kmeans_clusterer.compute_region_statistics(
            self.kmeans_sat_embeds_clusters,
            self.clip_seg_tta.heatmap_unnormalized,
            list(visited_indices),  # copy: visited_indices is reused as TTA samples below
            episode_num=episode,
            step_num=step,
        )

        self.step_since_tta += 1
        if not (robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0):
            return

        # Per-sample weights. Negatives are down-weighted until a region has been
        # visited enough (focal-loss-like exponent), so low-probability regions are
        # not decreased confidently from only a few samples.
        pos_weight = 4.0
        pos_sample_weight_scale, neg_sample_weight_scale = [], []
        for sample_loc in visited_indices:
            stats = region_stats_dict[self.kmeans_clusterer.get_label_id(sample_loc)]
            neg_weight = min(1.0, (stats['patches_visited'] / (3 * stats['num_patches'])) ** GAMMA_EXPONENT)  # (0,1)
            pos_sample_weight_scale.append(pos_weight)
            neg_sample_weight_scale.append(neg_weight)

        # Adaptive LR: grows from MIN_LR towards MAX_LR as samples accumulate
        # (clamped so it never exceeds MAX_LR).
        num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1]
        adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * min(1.0, step / num_cells)

        # NOTE: heatmap is reassigned explicitly because app.py (HF) does not allow it to persist.
        heatmap = self.clip_seg_tta.execute_tta(
            visited_indices,
            robot.targets_found_on_path,
            tta_steps=NUM_TTA_STEPS,
            lr=adaptive_lr,
            pos_sample_weight=pos_sample_weight_scale,
            neg_sample_weight=neg_sample_weight_scale,
            modality=MODALITY,
            query_variety=QUERY_VARIETY,
            target_found_idxs=self.env.target_found_idxs,
            reset_weights=RESET_WEIGHTS,
        )
        self.clip_seg_tta.heatmap = heatmap
        self._sync_env_heatmap()
        self.step_since_tta = 0
if __name__ == "__main__":
    # CHANGE ME!
    currEpisode = 0

    # Prepare the model on the configured device.
    device = torch.device('cuda') if USE_GPU else torch.device('cpu')
    policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
    script_dir = Path(__file__).resolve().parent
    print("real_script_dir: ", script_dir)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(f'{model_path}/{MODEL_NAME}', map_location=device)
    policy_net.load_state_dict(checkpoint['policy_model'])
    print('Model loaded!')

    # Init Taxabind here (only need to init once); TestWorker resets it per sample.
    if TAXABIND_TTA:
        clip_seg_tta = ClipSegTTA(
            img_dir=TAXABIND_IMG_DIR,
            imo_dir=TAXABIND_IMO_DIR,
            json_path=TAXABIND_INAT_JSON_PATH,
            sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
            patch_size=TAXABIND_PATCH_SIZE,
            sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
            sample_index=0,  # Set using 'reset' in worker
            blur_kernel=TAXABIND_GAUSSIAN_BLUR_KERNEL,
            device=device,
            sat_to_img_ids_json_is_train_dict=False,  # for search ds val
            tax_to_filter_val=QUERY_TAX,
            load_model=USE_CLIP_PREDS,
            initial_modality=INITIAL_MODALITY,
            sound_data_path=TAXABIND_SOUND_DATA_PATH,
            sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
        )
        print("ClipSegTTA Loaded!")
    else:
        clip_seg_tta = None

    # Define TestWorker. Pass the same device as the model: a hard-coded 'cuda'
    # put observation tensors on the GPU while the model stayed on the CPU.
    planner = TestWorker(
        meta_agent_id=0,
        n_agent=1,
        policy_net=policy_net,
        global_step=3,
        device=device,
        greedy=True,
        save_image=SAVE_GIFS,
        clip_seg_tta=clip_seg_tta,
    )
    planner.run_episode(currEpisode)
    print("planner.perf_metrics: ", planner.perf_metrics)