Spaces:
Sleeping
Sleeping
| """ | |
| Validation script | |
| """ | |
| import math | |
| import os | |
| import pandas as pd | |
| import csv | |
| import shutil | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as F | |
| from torch.utils.data import DataLoader | |
| import torch.backends.cudnn as cudnn | |
| import numpy as np | |
| import time | |
| import matplotlib.pyplot as plt | |
| from models.ProtoSAM import ProtoSAM, ALPNetWrapper, SamWrapperWrapper, InputFactory, ModelWrapper, TYPE_ALPNET, TYPE_SAM | |
| from models.ProtoMedSAM import ProtoMedSAM | |
| from models.grid_proto_fewshot import FewShotSeg | |
| from models.segment_anything.utils.transforms import ResizeLongestSide | |
| from models.SamWrapper import SamWrapper | |
| # from dataloaders.PolypDataset import get_polyp_dataset, get_vps_easy_unseen_dataset, get_vps_hard_unseen_dataset, PolypDataset, KVASIR, CVC300, COLON_DB, ETIS_DB, CLINIC_DB | |
| from dataloaders.PolypDataset import get_polyp_dataset, PolypDataset | |
| from dataloaders.PolypTransforms import get_polyp_transform | |
| from dataloaders.SimpleDataset import SimpleDataset | |
| from dataloaders.ManualAnnoDatasetv2 import get_nii_dataset | |
| from dataloaders.common import ValidationDataset | |
| from config_ssl_upload import ex | |
| import tqdm | |
| from tqdm.auto import tqdm | |
| import cv2 | |
| from collections import defaultdict | |
# Pre-trained model caching path, exported through the TORCH_HOME environment variable.
os.environ['TORCH_HOME'] = "./pretrained_model"
# Supported dataset identifiers; they are compared against the lower-cased
# value of the "dataset" config entry.
CHAOS = "chaos"
SABS = "sabs"
POLYPS = "polyps"
# Datasets that take the ALPNet-style support-set path (matched by substring).
ALP_DS = [CHAOS, SABS]
# Rotation in degrees; not referenced in the visible part of this file.
ROT_DEG = 0
def get_bounding_box(segmentation_map):
    """Return a single bounding box enclosing the extreme points of a segmentation map.

    Torch tensors are first moved to the CPU and converted to NumPy. The map
    is then cast to ``uint8`` and passed to ``cv2.boundingRect``, whose result
    is returned unchanged.
    """
    mask = segmentation_map
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().numpy()
    return cv2.boundingRect(mask.astype(np.uint8))
def calc_iou(boxA, boxB):
    """Compute the intersection-over-union of two axis-aligned boxes.

    Args:
        boxA: box as ``[x, y, w, h]``.
        boxB: box as ``[x, y, w, h]``.

    Returns:
        IoU as a float in ``[0, 1]``. When the union has no area (for example
        both boxes are empty), ``0.0`` is returned instead of dividing by zero.
    """
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

    # Clamp at zero so disjoint boxes contribute no intersection.
    interArea = max(0, xB - xA) * max(0, yB - yA)
    unionArea = boxA[2] * boxA[3] + boxB[2] * boxB[3] - interArea

    # Two degenerate (zero-area) boxes have an empty union.
    if unionArea <= 0:
        return 0.0
    return interArea / float(unionArea)
def eval_detection(pred_list):
    """Compute detection statistics at IoU thresholds 0.50, 0.55, ..., 0.95.

    Each entry of ``pred_list`` pairs exactly one predicted box with one
    ground-truth box. A pair counts as a true positive when its IoU reaches
    the threshold and as a false positive otherwise. Because every entry is
    either a TP or an FP, precision and recall coincide here.

    Args:
        pred_list: list of dicts with keys ``'pred_bbox'`` and ``'gt_bbox'``
            (boxes as ``[x, y, w, h]``). Other keys such as ``'score'`` are
            ignored.

    Returns:
        ``pandas.DataFrame`` with one row per IoU threshold and the columns
        ``iou_threshold, tp, fp, n_gt, f1, precision, recall``. Ratios whose
        denominator is zero (empty ``pred_list`` or no true positives) are
        reported as ``0.0``.
    """
    iou_thresholds = np.round(np.arange(0.5, 1.0, 0.05), 2)
    n_gt = len(pred_list)

    # The IoU of a pair does not depend on the threshold: compute it once.
    ious = [calc_iou(pred['pred_bbox'], pred['gt_bbox']) for pred in pred_list]

    results = []
    for iou_threshold in iou_thresholds:
        tp = sum(1 for iou in ious if iou >= iou_threshold)
        fp = n_gt - tp

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / n_gt if n_gt > 0 else 0.0
        pr_sum = precision + recall
        f1 = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0.0

        results.append({
            'iou_threshold': iou_threshold,
            'tp': tp,
            'fp': fp,
            'n_gt': n_gt,
            'f1': f1,
            'precision': precision,
            'recall': recall,
        })

    return pd.DataFrame(results)
def _mask_to_rgba(mask, cmap, alpha=0.7):
    """Convert a mask tensor into an RGBA overlay that is transparent on background."""
    mask_np = (mask.cpu().float().numpy() > 0).astype(np.float32)
    # Binary values map to the two ends of the colormap; foreground (1.0)
    # always gets the same colour.
    rgba = cmap(mask_np)
    rgba[..., 3] = mask_np * alpha  # last channel is alpha
    return rgba


def _save_figure(image, is_gray, out_file, overlay=None):
    """Save ``image`` (optionally with an RGBA ``overlay``) without axes or padding."""
    plt.figure(figsize=(10, 10))
    try:
        if is_gray:
            plt.imshow(image, cmap='gray')
        else:
            plt.imshow(image)
        if overlay is not None:
            plt.imshow(overlay)
        plt.axis('off')
        # Remove padding/whitespace
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
        plt.savefig(out_file, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close()


def plot_pred_gt_support(query_image, pred, gt, support_images, support_masks, score=None, save_path="debug/pred_vs_gt"):
    """
    Save 5 key images: support image, support mask, query, ground truth and prediction.
    Handles both grayscale and RGB images consistently with the same mask color.

    The files ``query.png``, ``pred.png``, ``gt.png`` and, when support data
    is given, ``support_1.png`` and ``support_mask.png`` are written inside
    the directory ``save_path``, which is created if missing.

    Args:
        query_image: Query image tensor (grayscale HxW, or CHW with C <= 3)
        pred: 2d tensor where values > 0 represent foreground
        gt: 2d tensor where values > 0 represent foreground
        support_images: Support image tensor or list of tensors (concatenated
            on dim 0); ``None`` skips the support outputs
        support_masks: Support mask tensor or list of tensors
        score: Currently unused; kept for interface compatibility
        save_path: Directory that receives the images
    """
    # Images are written *inside* save_path, so that directory itself (not
    # only its parent) has to exist.
    os.makedirs(save_path, exist_ok=True)

    # Process query image - ensure HxWxC format for visualization
    query_image = query_image.clone().detach().cpu()
    if query_image.dim() == 3 and query_image.shape[0] <= 3:  # CHW format
        query_image = query_image.permute(1, 2, 0)

    # Handle grayscale vs RGB consistently
    if query_image.dim() == 2 or (query_image.dim() == 3 and query_image.shape[2] == 1):
        is_grayscale = True
        if query_image.dim() == 3:
            query_image = query_image.squeeze(2)  # drop the channel dimension
    else:
        is_grayscale = False

    # Normalize image to [0, 1] for visualization
    query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min() + 1e-8)

    # pyplot.get_cmap is used because cm.get_cmap is deprecated/removed in
    # recent Matplotlib releases.
    mask_cmap = plt.get_cmap('YlOrRd')
    pred_rgba = _mask_to_rgba(pred, mask_cmap)
    gt_rgba = _mask_to_rgba(gt, mask_cmap)

    # 1. Query image, 2. query + prediction overlay, 3. query + ground truth overlay
    _save_figure(query_image, is_grayscale, f"{save_path}/query.png")
    _save_figure(query_image, is_grayscale, f"{save_path}/pred.png", overlay=pred_rgba)
    _save_figure(query_image, is_grayscale, f"{save_path}/gt.png", overlay=gt_rgba)

    # Process and save support images and masks (just the first one for brevity)
    if support_images is None:
        return

    if isinstance(support_images, list):
        support_images = torch.cat(support_images, dim=0).clone().detach()
    if isinstance(support_masks, list):
        support_masks = torch.cat(support_masks, dim=0).clone().detach()

    # Move to CPU for processing
    support_images = support_images.cpu()
    support_masks = support_masks.cpu()

    if support_images.dim() == 4:  # NCHW format
        # Convert to NHWC for visualization
        support_images = support_images.permute(0, 2, 3, 1)

    if support_images.shape[0] == 0:
        return

    support_img = support_images[0].clone()
    support_mask = support_masks[0].clone()

    # Check if grayscale or RGB (last dimension is treated as channels)
    if support_img.shape[-1] == 1:
        support_img = support_img.squeeze(-1)  # drop the channel dimension
        support_is_gray = True
    elif support_img.shape[-1] == 3:
        support_is_gray = False
    else:  # Assume it's grayscale if not 1 or 3 channels
        support_is_gray = True

    # Normalize support image
    support_img = (support_img - support_img.min()) / (support_img.max() - support_img.min() + 1e-8)

    # 4. Support image only
    _save_figure(support_img, support_is_gray, f"{save_path}/support_1.png")

    # 5. Support image with its mask overlay. The colour mode follows the
    # support image itself, not the query image.
    support_mask_rgba = _mask_to_rgba(support_mask, mask_cmap)
    _save_figure(support_img, support_is_gray, f"{save_path}/support_mask.png",
                 overlay=support_mask_rgba)
def get_dice_iou_precision_recall(pred: torch.Tensor, gt: torch.Tensor):
    """
    Compute Dice, IoU, precision and recall of a binary prediction.

    pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background
    gt: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background

    Returns a dict with the keys "dice", "iou", "precision" and "recall".
    When gt contains no foreground all four values are 0. If pred and gt
    differ in shape, pred is resized to gt's shape with nearest-neighbour
    interpolation before the metrics are computed.
    """
    if gt.sum() == 0:
        print("gt is all background")
        # Same key set as the regular return so callers can always read "iou".
        return {"dice": 0, "iou": 0, "precision": 0, "recall": 0}

    # Resize pred to match gt dimensions if they're different
    if pred.shape != gt.shape:
        print(f"Resizing prediction from {pred.shape} to match ground truth {gt.shape}")
        # interpolate expects a (N, C, H, W) float tensor
        pred = torch.nn.functional.interpolate(
            pred.unsqueeze(0).unsqueeze(0).float(),
            size=gt.shape,
            mode='nearest'
        ).squeeze(0).squeeze(0)

    tp = (pred * gt).sum()
    fp = (pred * (1 - gt)).sum()
    fn = ((1 - pred) * gt).sum()

    # The small epsilon keeps every ratio finite when a denominator is zero.
    dice = 2 * tp / (2 * tp + fp + fn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    iou = tp / (tp + fp + fn + 1e-8)
    return {"dice": dice, "iou": iou, "precision": precision, "recall": recall}
def get_alpnet_model(_config) -> ModelWrapper:
    """Build the FewShotSeg (ALPNet) model on the GPU and return it wrapped.

    Reads ``input_size`` (first entry), ``reload_model_path`` and ``model``
    from ``_config``.
    """
    input_side = _config["input_size"][0]
    few_shot_net = FewShotSeg(input_side, _config["reload_model_path"], _config["model"])
    few_shot_net.cuda()
    return ALPNetWrapper(few_shot_net)
def get_sam_model(_config) -> ModelWrapper:
    """Build a ViT-H SamWrapper on the GPU and return it wrapped.

    ``_config`` is accepted for signature parity with the other model
    builders but is not read.
    """
    sam = SamWrapper(
        sam_args=dict(
            model_type="vit_h",
            sam_checkpoint="pretrained_model/sam_vit_h.pth",
        )
    ).cuda()
    return SamWrapperWrapper(sam)
def get_model(_config) -> ProtoSAM:
    """Assemble the full model: a coarse segmentation model plus a SAM variant.

    Only the ALPNet base model is supported. ``protosam_sam_ver`` selects
    ProtoSAM with a ViT-H/ViT-B checkpoint ("sam_h"/"sam_b") or ProtoMedSAM
    ("medsam").

    Raises:
        NotImplementedError: for an unsupported base model or SAM version.
    """
    # Initial (coarse) segmentation model
    if _config["base_model"] != TYPE_ALPNET:
        raise NotImplementedError(f"base model {_config['base_model']} not implemented")
    coarse_model = get_alpnet_model(_config)

    sam_version = _config["protosam_sam_ver"]

    if sam_version in ("sam_h", "sam_b"):
        checkpoint_by_version = {
            "sam_h": "pretrained_model/sam_vit_h.pth",
            "sam_b": "pretrained_model/sam_vit_b.pth",
        }
        return ProtoSAM(
            image_size=(1024, 1024),
            coarse_segmentation_model=coarse_model,
            use_bbox=_config["use_bbox"],
            use_points=_config["use_points"],
            use_mask=_config["use_mask"],
            debug=_config["debug"],
            num_points_for_sam=1,
            use_cca=_config["do_cca"],
            point_mode=_config["point_mode"],
            use_sam_trans=True,
            coarse_pred_only=_config["coarse_pred_only"],
            sam_pretrained_path=checkpoint_by_version[sam_version],
            use_neg_points=_config["use_neg_points"],
        )

    if sam_version == "medsam":
        return ProtoMedSAM(
            image_size=(1024, 1024),
            coarse_segmentation_model=coarse_model,
            debug=_config["debug"],
            use_cca=_config["do_cca"],
        )

    raise NotImplementedError(f"protosam_sam_ver {_config['protosam_sam_ver']} not implemented")
def get_support_set_polyps(_config, dataset:PolypDataset):
    """Fetch ``_config["n_support"]`` support samples from a polyp dataset.

    Returns the ``(support_images, support_labels, case)`` triple produced by
    ``dataset.get_support``.
    """
    images, labels, case = dataset.get_support(n_support=_config["n_support"])
    return images, labels, case
def get_support_set_alpds(config, dataset:ValidationDataset):
    """Fetch the support set of an ALPNet-style validation dataset.

    Returns ``(support_images, support_labels, support_scan_id)`` taken from
    the dict returned by ``dataset.get_support_set(config)``.
    """
    support = dataset.get_support_set(config)
    fg_masks, images, scan_id = (
        support[key]
        for key in ("support_labels", "support_images", "support_scan_id")
    )
    return images, fg_masks, scan_id
def get_support_set(_config, dataset):
    """Return the support set for the dataset named in ``_config["dataset"]``.

    Returns:
        A 3-tuple ``(support_images, support_fg_masks, identifier)``. For
        polyps the identifier is the case returned by the dataset; for the
        CHAOS/SABS datasets it is the support scan id.

    Raises:
        NotImplementedError: if the dataset name is not supported.
    """
    dataset_name = _config["dataset"].lower()

    if dataset_name == POLYPS:
        # Return the helper's own triple; the polyp helper yields a case id,
        # there is no scan id for this dataset.
        return get_support_set_polyps(_config, dataset)

    if any(item in dataset_name for item in ALP_DS):
        return get_support_set_alpds(_config, dataset)

    raise NotImplementedError(f"dataset {_config['dataset']} not implemented")
def update_support_set_by_scan_part(support_images, support_labels, qpart):
    """Select the support image and label of scan part ``qpart``.

    Each selected item is returned wrapped in its own single-element list.
    """
    return [support_images[qpart]], [support_labels[qpart]]
def manage_support_sets(sample_batched, all_support_images, all_support_fg_mask, support_images, support_fg_mask, qpart=None):
    """Switch the active support set when the query moves to another scan part.

    The part of the current sample is read from
    ``sample_batched['part_assign'][0]``. If it differs from ``qpart``, the
    support images and masks for that part are selected from the full support
    set; otherwise the ones passed in are returned unchanged.

    Returns:
        ``(support_images, support_fg_mask, qpart)`` reflecting the current part.
    """
    current_part = sample_batched['part_assign'][0]
    if current_part != qpart:
        support_images, support_fg_mask = update_support_set_by_scan_part(
            all_support_images, all_support_fg_mask, current_part)
        qpart = current_part
    return support_images, support_fg_mask, qpart
def main(_run, _config, _log):
    """Validate the few-shot segmentation model on the configured test set.

    Builds the model and the test dataset, runs every query slice through the
    model and accumulates Dice / IoU / precision / recall, overall and per
    case. The means are logged through ``_run.log_scalar`` and ``_log``.

    Args:
        _run: experiment run object; ``observers``, ``experiment_info`` and
            ``log_scalar`` are used.
        _config: configuration mapping (dataset, model and debug options).
        _log: logger providing ``info``.

    Returns:
        Always ``1``.

    Raises:
        NotImplementedError: if the configured dataset is not supported.
    """
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        print(f"####### created dir:{_run.observers[0].dir} #######")
        # NOTE(review): kept inside the observers guard because it
        # dereferences observers[0] - confirm against the original layout.
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    print(f"config do_cca: {_config['do_cca']}, use_bbox: {_config['use_bbox']}")
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
    print(f'###### Reload model {_config["reload_model_path"]} ######')
    model = get_model(_config)
    model = model.to(torch.device("cuda"))
    model.eval()

    dataset_name = _config["dataset"].lower()
    is_alp_ds = any(item in dataset_name for item in ALP_DS)
    is_polyp_ds = dataset_name == POLYPS

    sam_trans = ResizeLongestSide(1024)
    if is_polyp_ds:
        tr_dataset, te_dataset = get_polyp_dataset(sam_trans=sam_trans, image_size=(1024, 1024))
    elif is_alp_ds:
        tr_dataset, te_dataset = get_nii_dataset(_config, _config["input_size"][0])
    else:
        raise NotImplementedError(
            f"dataset {_config['dataset']} not implemented")

    # dataloaders
    testloader = DataLoader(
        te_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=False,
        drop_last=False
    )

    _log.info('###### Starting validation ######')
    mean_dice = []
    mean_prec = []
    mean_rec = []
    mean_iou = []
    mean_dice_cases = {}
    mean_iou_cases = {}
    # NOTE(review): collected per slice but not consumed anywhere in this
    # function (eval_detection is never called on it).
    bboxes_w_scores = []

    qpart = None
    support_images = support_fg_mask = None
    all_support_images, all_support_fg_mask, support_scan_id = None, None, None

    if is_alp_ds:
        all_support_images, all_support_fg_mask, support_scan_id = get_support_set(_config, te_dataset)
    elif is_polyp_ds:
        support_images, support_fg_mask, case = get_support_set_polyps(_config, tr_dataset)

    # Iterate over the very bar that receives the postfix, so one bar both
    # advances and shows the running metrics.
    with tqdm(testloader) as pbar:
        for idx, sample_batched in enumerate(pbar):
            case = sample_batched['case'][0]

            if is_alp_ds:
                support_images, support_fg_mask, qpart = manage_support_sets(
                    sample_batched,
                    all_support_images,
                    all_support_fg_mask,
                    support_images,
                    support_fg_mask,
                    qpart,
                )

            # Never evaluate on the scan the support set was taken from.
            if is_alp_ds and sample_batched["scan_id"][0] in support_scan_id:
                continue

            query_images = sample_batched['image'].cuda()
            query_labels = torch.cat([sample_batched['label']], dim=0)
            if 1 not in query_labels and _config["skip_no_organ_slices"]:
                continue

            n_try = 1
            with torch.no_grad():
                coarse_model_input = InputFactory.create_input(
                    input_type=_config["base_model"],
                    query_image=query_images,
                    support_images=support_images,
                    support_labels=support_fg_mask,
                    isval=True,
                    val_wsize=_config["val_wsize"],
                    original_sz=query_images.shape[-2:],
                    img_sz=query_images.shape[-2:],
                    gts=query_labels,
                )
                coarse_model_input.to(torch.device("cuda"))
                query_pred, scores = model(
                    query_images, coarse_model_input, degrees_rotate=0)
            query_pred = query_pred.cpu().detach()

            if _config["debug"]:
                # The dataset is either ALP-style or polyps; anything else
                # was rejected above.
                if is_alp_ds:
                    save_path = f'debug/preds/{case}_{sample_batched["z_id"].item()}_{idx}_{n_try}'
                else:
                    save_path = f'debug/preds/{case}_{idx}_{n_try}'
                os.makedirs(save_path, exist_ok=True)
                plot_pred_gt_support(query_images[0, 0].cpu(), query_pred.cpu(), query_labels[0].cpu(),
                                     support_images, support_fg_mask, save_path=save_path, score=scores[0])

            metrics = get_dice_iou_precision_recall(
                query_pred, query_labels[0].to(query_pred.device))
            # Store plain floats so the aggregation below does not depend on
            # tensor types.
            metrics = {name: float(value) for name, value in metrics.items()}

            mean_dice.append(metrics["dice"])
            mean_prec.append(metrics["precision"])
            mean_rec.append(metrics["recall"])
            mean_iou.append(metrics["iou"])

            bboxes_w_scores.append({"pred_bbox": get_bounding_box(query_pred.cpu()),
                                    "gt_bbox": get_bounding_box(query_labels[0].cpu()),
                                    "score": np.mean(scores)})

            mean_dice_cases.setdefault(case, []).append(metrics["dice"])
            mean_iou_cases.setdefault(case, []).append(metrics["iou"])

            if metrics["dice"] < 0.6 and _config["debug"]:
                path = f'debug/bad_preds/case_{case}_idx_{idx}_dice_{metrics["dice"]:.4f}'
                os.makedirs(path, exist_ok=True)
                print(f"saving bad prediction to {path}")
                plot_pred_gt_support(query_images[0, 0].cpu(), query_pred.cpu(), query_labels[0].cpu(),
                                     support_images, support_fg_mask, save_path=path, score=scores[0])

            pbar.set_postfix({
                "mdice": f"{np.mean(mean_dice):.4f}",
                "miou": f"{np.mean(mean_iou):.4f}",
                "n_try": n_try,
            })

    for k in mean_dice_cases:
        _run.log_scalar(f'mar_val_batches_meanDice_{k}', np.mean(mean_dice_cases[k]))
        _run.log_scalar(f'mar_val_batches_meanIOU_{k}', np.mean(mean_iou_cases[k]))
        _log.info(f'mar_val batches meanDice_{k}: {np.mean(mean_dice_cases[k])}')
        _log.info(f'mar_val batches meanIOU_{k}: {np.mean(mean_iou_cases[k])}')

    # write validation result to log file
    m_meanDice = np.mean(mean_dice)
    m_meanPrec = np.mean(mean_prec)
    m_meanRec = np.mean(mean_rec)
    m_meanIOU = np.mean(mean_iou)

    # Scalar names are kept exactly as they were (including the "_al_"
    # variants) so existing logs and dashboards keep matching.
    _run.log_scalar('mar_val_batches_meanDice', m_meanDice)
    _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec)
    _run.log_scalar('mar_val_al_batches_meanRec', m_meanRec)
    _run.log_scalar('mar_val_al_batches_meanIOU', m_meanIOU)
    _log.info(f'mar_val batches meanDice: {m_meanDice}')
    _log.info(f'mar_val batches meanPrec: {m_meanPrec}')
    _log.info(f'mar_val batches meanRec: {m_meanRec}')
    _log.info(f'mar_val batches meanIOU: {m_meanIOU}')
    print("============ ============")
    _log.info(f'End of validation')
    return 1