Spaces:
Sleeping
Sleeping
import warnings | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from models.grid_proto_fewshot import FewShotSeg | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
from models.SamWrapper import SamWrapper | |
from util.utils import cca, get_connected_components, rotate_tensor_no_crop, reverse_tensor, get_confidence_from_logits | |
from util.lora import inject_trainable_lora | |
from models.segment_anything.utils.transforms import ResizeLongestSide | |
import cv2 | |
import time | |
from abc import ABC, abstractmethod | |
# Strategies ProtoSAM uses to pick SAM point prompts per connected component:
# most confident pixels, the component centroid, or both.
CONF_MODE, CENTROID_MODE, BOTH_MODE = "conf", "centroid", "both"
POINT_MODES = (CONF_MODE, CENTROID_MODE, BOTH_MODE)

# Coarse-model identifiers understood by InputFactory.create_input.
TYPE_ALPNET, TYPE_SAM = "alpnet", "sam"
def plot_connected_components(cca_output, original_image, confidences: dict = None, title="debug/connected_components.png"):
    """Save a side-by-side figure of an image and its connected components.

    Args:
        cca_output: ``(num_labels, labels, stats, centroids)`` tuple; label 0 is
            treated as background and left black, every other label gets a
            random colour.
        original_image: image drawn in the left panel.
        confidences: optional mapping; its values colour a scatter of *all*
            rows of ``centroids``, so it presumably needs one value per
            centroid row (background included) — TODO confirm with callers.
        title: path the figure is written to.
    """
    num_labels, labels, _stats, centroids = cca_output

    colored = np.zeros((*labels.shape[:2], 3), dtype=np.uint8)
    for component in range(1, num_labels):
        colored[labels == component] = np.random.randint(0, 255, size=3)

    fig, (ax_orig, ax_cc) = plt.subplots(1, 2, figsize=(10, 5))
    ax_orig.imshow(original_image)
    ax_orig.set_title('Original Image')
    ax_cc.imshow(cv2.cvtColor(colored, cv2.COLOR_BGR2RGB))
    ax_cc.set_title('Connected Components')

    if confidences is not None:
        # Colour each centroid by its confidence on the components panel.
        scatter = ax_cc.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet')
        fig.colorbar(scatter, ax=ax_cc)

    fig.savefig(title)
    plt.close(fig)
class SegmentationInput(ABC):
    """Base class for the model-specific inputs produced by ``InputFactory``.

    Both hooks are no-ops that return None; subclasses override them.
    """

    def set_query_images(self, query_images):
        """Replace the query image(s) carried by this input."""

    def to(self, device):
        """Move any tensors carried by this input to *device*."""
class SegmentationOutput(ABC):
    """Base class for model-specific outputs; subclasses expose a prediction."""

    def get_prediction(self):
        """Return the prediction (the base implementation returns None)."""
class ALPNetInput(SegmentationInput):  # for alpnet
    """Keyword arguments for an ALPNet (``FewShotSeg``) forward pass.

    ``ALPNetWrapper`` passes ``self.__dict__`` to the model as ``**kwargs``,
    so the attribute names below must match the model's parameters and no
    other instance attributes should be added.
    """

    def __init__(self, support_images: list, support_labels: list, query_images: torch.Tensor, isval, val_wsize, show_viz=False, supp_fts=None):
        """
        Args:
            support_images: support image tensors for a single way.
            support_labels: matching foreground masks (presumably binary,
                since the background mask is computed as ``1 - label``).
            query_images: query image tensor.
            isval, val_wsize, show_viz, supp_fts: forwarded to the model unchanged.
        """
        # Support data is nested as a list of ways, each a list of tensors;
        # an instance built here always holds exactly one way.
        self.supp_imgs = [support_images]
        self.fore_mask = [support_labels]
        self.back_mask = [[1 - sup_labels for sup_labels in support_labels]]
        self.qry_imgs = [query_images]
        self.isval = isval
        self.val_wsize = val_wsize
        self.show_viz = show_viz
        self.supp_fts = supp_fts

    def set_query_images(self, query_images):
        """Replace the query with *query_images* (wrapped in a one-element list)."""
        self.qry_imgs = [query_images]

    def to(self, device):
        """Move every held tensor to *device*, in place.

        The per-way nesting of the support lists is preserved (previously all
        ways were flattened into a single way).
        """
        self.supp_imgs = [[img.to(device) for img in way] for way in self.supp_imgs]
        self.fore_mask = [[mask.to(device) for mask in way] for way in self.fore_mask]
        self.back_mask = [[mask.to(device) for mask in way] for way in self.back_mask]
        self.qry_imgs = [qry_img.to(device) for qry_img in self.qry_imgs]
        if self.supp_fts is not None:
            self.supp_fts = self.supp_fts.to(device)
class ALPNetOutput(SegmentationOutput):
    """Named view of the 7-tuple returned by an ALPNet forward pass."""

    def __init__(self, pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts):
        (self.pred, self.align_loss, self.sim_maps, self.assign_maps,
         self.proto_grid, self.supp_fts, self.qry_fts) = (
            pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts)

    def get_prediction(self):
        """Return the ``pred`` element of the model output."""
        return self.pred
class SAMWrapperInput(SegmentationInput):
    """Input for ``SamWrapperWrapper``: an HWC uint8 image plus its labels.

    Attributes are forwarded by name to the SAM wrapper, so only ``image``
    and ``image_labels`` are stored.
    """

    def __init__(self, image, image_labels):
        self.image = image
        self.image_labels = image_labels

    def set_query_images(self, query_images):
        """Store a (1, C, H, W) tensor/array as a min-max scaled (H, W, C) uint8 image.

        Raises:
            ValueError: if *query_images* is not 4-D (from the shape unpacking).
            AssertionError: if the batch size is not 1.
        """
        batch, _, _, _ = query_images.shape
        if isinstance(query_images, torch.Tensor):
            query_images = query_images.cpu().detach().numpy()
        assert batch == 1, "batch size must be 1"
        lo, hi = query_images.min(), query_images.max()
        if hi == lo:
            # Constant image: (x - lo) / 0 would be NaN before the uint8 cast.
            scaled = np.zeros_like(query_images, dtype=np.uint8)
        else:
            scaled = ((query_images - lo) / (hi - lo) * 255).astype(np.uint8)
        self.image = np.transpose(scaled[0], (1, 2, 0))

    def to(self, device):
        """No-op: the image is kept as a NumPy array."""
class InputFactory(ABC):
    """Build the ``SegmentationInput`` matching a coarse-model type."""

    @staticmethod
    def create_input(input_type, query_image, support_images=None, support_labels=None, isval=False, val_wsize=None, show_viz=False, supp_fts=None, original_sz=None, img_sz=None, gts=None):
        """Create an input object for *input_type*.

        Args:
            input_type: ``TYPE_ALPNET`` or ``TYPE_SAM``.
            query_image: (1, C, H, W) tensor.
            support_images, support_labels, isval, val_wsize, show_viz,
                supp_fts: used for ``TYPE_ALPNET`` only.
            original_sz, img_sz: accepted for compatibility; unused here.
            gts: tensor with H*W elements, used for ``TYPE_SAM``; any non-zero
                value is treated as foreground.

        Returns:
            ``ALPNetInput`` or ``SAMWrapperInput``.

        Raises:
            AssertionError: SAM input with batch size != 1 or with more than
                two distinct label values.
            ValueError: unknown *input_type*.
        """
        if input_type == TYPE_ALPNET:
            return ALPNetInput(support_images, support_labels, query_image, isval, val_wsize, show_viz, supp_fts)
        if input_type == TYPE_SAM:
            qimg = np.array(query_image.detach().cpu())
            B, C, H, W = qimg.shape
            assert B == 1, "batch size must be 1"
            gts = np.array(gts.detach().cpu()).astype(np.uint8).reshape(H, W)
            assert np.unique(gts).shape[0] <= 2, "support labels must be binary"
            gts[gts > 0] = 1
            # (1, C, H, W) -> (H, W, C). A plain reshape would interleave the
            # channel planes whenever C > 1.
            qimg = np.transpose(qimg[0], (1, 2, 0))
            lo, hi = qimg.min(), qimg.max()
            if hi == lo:
                # Constant image: avoid 0/0 (NaN) before the uint8 cast.
                qimg = np.zeros_like(qimg, dtype=np.uint8)
            else:
                qimg = ((qimg - lo) / (hi - lo) * 255).astype(np.uint8)
            return SAMWrapperInput(qimg, gts)
        raise ValueError(f"input_type not supported: {input_type!r}")
class ModelWrapper(ABC):
    """Uniform facade over a segmentation model stored in ``self.model``.

    Subclasses implement ``__call__`` and ``parameters``; the base versions of
    those two return None. The remaining methods delegate to the wrapped model.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, input_data: SegmentationInput) -> SegmentationOutput:
        """Run the wrapped model on *input_data* (base implementation returns None)."""

    def parameters(self):
        """Return the parameters to expose (base implementation returns None)."""

    def state_dict(self):
        """Return the wrapped model's ``state_dict()``."""
        return self.model.state_dict()

    def load_state_dict(self, state_dict):
        """Load *state_dict* into the wrapped model; returns None."""
        self.model.load_state_dict(state_dict)

    def train(self):
        """Switch the wrapped model to training mode; returns None."""
        self.model.train()

    def eval(self):
        """Switch the wrapped model to evaluation mode; returns None."""
        self.model.eval()
class ALPNetWrapper(ModelWrapper):
    """``ModelWrapper`` around an ALPNet ``FewShotSeg`` model."""

    def __init__(self, model: FewShotSeg):
        super().__init__(model)

    def __call__(self, input_data: ALPNetInput):
        """Call the model with the attributes of *input_data* as keyword
        arguments and return only the ``pred`` element of its 7-tuple output."""
        raw_outputs = self.model(**vars(input_data))
        return ALPNetOutput(*raw_outputs).pred

    def parameters(self):
        """Only the encoder's parameters are exposed."""
        encoder = self.model.encoder
        return encoder.parameters()

    def train(self):
        """Switch only the encoder to training mode (``eval`` still applies to the whole model)."""
        encoder = self.model.encoder
        encoder.train()
class SamWrapperWrapper(ModelWrapper):
    """``ModelWrapper`` around ``SamWrapper`` that shapes its output like logits."""

    def __init__(self, model: SamWrapper):
        super().__init__(model)

    def __call__(self, input_data: SAMWrapperInput):
        """Return ``cat([1 - p, p], dim=1)`` where p is the wrapper output as a
        float tensor with two leading singleton dims added."""
        fg = torch.tensor(self.model(**vars(input_data))).float()
        fg = fg.unsqueeze(0).unsqueeze(0)
        return torch.cat((1 - fg, fg), dim=1)

    def to(self, device):
        """Move the underlying SAM model to *device*."""
        self.model.sam.to(device)
class ProtoSAM(nn.Module):
    """Refine a coarse few-shot segmentation with SAM prompts.

    A coarse model (a ``ModelWrapper``) predicts two-channel logits for the
    query image. Unless ``coarse_pred_only`` is set, the foreground prediction
    is split into connected components, each component is converted into SAM
    prompts (points, boxes and/or a low-resolution mask), and the SAM masks are
    summed into the final prediction.
    """

    def __init__(
        self,
        image_size,
        coarse_segmentation_model: ModelWrapper,
        sam_pretrained_path="pretrained_model/sam_default.pth",
        num_points_for_sam=1,
        use_points=True,
        use_bbox=False,
        use_mask=False,
        debug=False,
        use_cca=False,
        point_mode=CONF_MODE,
        use_sam_trans=True,
        coarse_pred_only=False,
        alpnet_image_size=None,
        use_neg_points=False,
    ):
        """
        Args:
            image_size: int or (H, W); resolution at which the coarse logits
                are split into components and SAM prompts are computed.
            coarse_segmentation_model: wrapper whose output is indexed as
                channel 0 = background, 1 = foreground (presumably
                (1, 2, H, W) logits).
            sam_pretrained_path: SAM checkpoint; "vit_h" in the path selects
                ViT-H, anything else ViT-B.
            num_points_for_sam: positive points per component (CONF/BOTH modes).
            use_points, use_bbox, use_mask: prompt types; at least one required.
            debug: write diagnostic figures under ``debug/``.
            use_cca: use ``cca`` for components and single-mask SAM output.
            point_mode: one of ``POINT_MODES``.
            use_sam_trans: resize the image with ``ResizeLongestSide`` for SAM.
            coarse_pred_only: skip SAM and return the coarse prediction.
            alpnet_image_size: accepted for compatibility; unused.
            use_neg_points: add negative (background) point prompts.

        Raises:
            AssertionError: no prompt type enabled.
            ValueError: ``point_mode`` not in ``POINT_MODES``.
        """
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        # A tuple compares equal to ``tensor.shape[-2:]`` (a list never does).
        self.image_size = tuple(image_size)
        self.coarse_segmentation_model = coarse_segmentation_model
        self.get_sam(sam_pretrained_path, use_sam_trans)
        self.num_points_for_sam = num_points_for_sam
        self.use_points = use_points
        self.use_bbox = use_bbox  # if False then uses points
        self.use_mask = use_mask
        self.use_neg_points = use_neg_points
        assert self.use_bbox or self.use_points or self.use_mask, "must use at least one of bbox, points, or mask"
        self.use_cca = use_cca
        self.point_mode = point_mode
        if self.point_mode not in POINT_MODES:
            raise ValueError(f"point mode must be one of {POINT_MODES}")
        self.debug = debug
        self.coarse_pred_only = coarse_pred_only

    @staticmethod
    def _minmax_normalize(x):
        """Rescale *x* (ndarray or tensor) to [0, 1].

        A constant input maps to zeros instead of the NaNs that
        ``(x - min) / (max - min)`` would produce.
        """
        lo, hi = x.min(), x.max()
        if hi == lo:
            return x * 0
        return (x - lo) / (hi - lo)

    def get_sam(self, checkpoint_path, use_sam_trans):
        """Load SAM and set ``self.sam``, ``self.predictor`` and ``self.sam_trans``.

        SAM is put in eval mode and its gradients are disabled.
        """
        model_type = "vit_h" if 'vit_h' in checkpoint_path else "vit_b"  # TODO make generic?
        self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
        self.predictor = SamPredictor(self.sam)
        self.sam.requires_grad_(False)
        if use_sam_trans:
            sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size)
            # Zero mean / unit std. Presumably this makes the project-local
            # ``preprocess`` leave intensities unnormalised — TODO confirm.
            sam_trans.pixel_mean = torch.tensor([0, 0, 0]).view(3, 1, 1)
            sam_trans.pixel_std = torch.tensor([1, 1, 1]).view(3, 1, 1)
        else:
            sam_trans = None
        self.sam_trans = sam_trans

    def get_bbox(self, pred):
        """Return the four corners of the box around the non-zero pixels of *pred*.

        Args:
            pred: (H, W) tensor or array; 1 = foreground, 0 = background.

        Returns:
            ``[row, col]`` corners (0-d tensors) ordered top-left, top-right,
            bottom-right, bottom-left. Raises (empty reduction) when *pred*
            has no foreground pixels.
        """
        if isinstance(pred, np.ndarray):
            pred = torch.tensor(pred)
        indices = torch.nonzero(pred)
        min_x = indices[:, 1].min()
        max_x = indices[:, 1].max()
        min_y = indices[:, 0].min()
        max_y = indices[:, 0].max()
        return [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]]

    def get_bbox_per_cc(self, conn_components):
        """Return one XYXY box per component label ``1 .. num_labels - 1``.

        Args:
            conn_components: output of the cca function, indexed as
                ``(num_labels, labels, ...)``.

        Returns:
            Array of ``[min_x, min_y, max_x, max_y]`` rows. Raises (empty
            reduction) when a label in that range has no pixels.
        """
        num_labels, labels = conn_components[0], conn_components[1]
        bboxes = []
        for label in range(1, num_labels):
            rows_cols = torch.nonzero(torch.tensor(labels == label))
            rows, cols = rows_cols[:, 0], rows_cols[:, 1]
            bboxes.append([cols.min().item(), rows.min().item(), cols.max().item(), rows.max().item()])
        return np.array(bboxes)

    def get_most_conf_points(self, output_p_fg, pred, k):
        """Pick the (up to) *k* most confident pixels inside a mask.

        Args:
            output_p_fg: (H, W) probability map on the CPU.
            pred: (H, W) tensor, non-zero where pixels may be chosen.
            k: number of points requested; clamped to the number of candidates.

        Returns:
            ``(locations, confidences)``: an (m, 2) array of (x, y) pixel
            coordinates, highest confidence first, and the matching floats,
            with ``m = min(k, #candidates)``; ``(None, None)`` for an empty mask.
        """
        mask = pred.bool()
        candidates = output_p_fg[mask]
        if candidates.numel() == 0:
            return None, None
        # torch.topk raises when k exceeds the number of candidates (tiny components).
        confidences, indices = torch.topk(candidates, min(k, candidates.numel()))
        # Boolean indexing and torch.nonzero both enumerate pixels in row-major
        # order, so ``indices`` also index the (row, col) list.
        locations = torch.nonzero(mask)[indices][:, [1, 0]]  # (row, col) -> (x, y)
        return locations.numpy(), [float(conf.item()) for conf in confidences]

    def plot_most_conf_points(self, points, confidences, pred, image, bboxes=None, title=None):
        """Save *image* overlaid with *pred*, one point per group, and boxes.

        Args:
            points: sequence of point groups indexed ``group[0] -> (x, y)``
                (e.g. an (N, k, 2) array); only each group's first point is drawn.
            confidences: optional per-group values printed next to the points.
            pred: (H, W) mask (tensor or array) drawn semi-transparently.
            image: (H, W) or (3, H, W) image (tensor or array).
            bboxes: optional iterable of XYXY boxes; None entries are skipped.
            title: output path, default ``debug/most_conf_points.png``.
        """
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().detach().numpy()
        if len(image.shape) == 3 and image.shape[0] == 3:
            image = image.permute(1, 2, 0) if isinstance(image, torch.Tensor) else np.transpose(image, (1, 2, 0))
        if title is None:
            title = "debug/most_conf_points.png"
        # Scoped suppression: a process-wide filterwarnings() here would
        # silence every later UserWarning in the program.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=UserWarning)
            fig = plt.figure()
            plt.imshow(self._minmax_normalize(image))
            plt.imshow(pred, alpha=0.5)
            for i, point in enumerate(points):
                x, y = point[0][0], point[0][1]
                plt.scatter(x, y, marker='*', c='red')
                if confidences is not None:
                    plt.text(x, y, f"{confidences[i]:.3f}", fontsize=12, color='red')
            if bboxes is not None:
                for bbox in bboxes:
                    if bbox is None:
                        continue
                    x0, y0, x1, y1 = np.asarray(bbox)[:4]
                    # Closed outline through the four corners.
                    corners = np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1], [x0, y0]])
                    plt.plot(corners[:, 0], corners[:, 1], c='green')
            plt.colorbar()
            fig.savefig(title)
            plt.close(fig)

    def plot_sam_preds(self, masks, scores, image, input_point, input_label, input_box=None):
        """Save up to the first seven SAM masks over *image* as ``debug/sam_mask_<i>.png``."""
        if isinstance(image, torch.Tensor) and image.dim() == 3:
            image = image.permute(1, 2, 0)
        image = self._minmax_normalize(image)
        for i, (mask, score) in enumerate(zip(masks, scores)):
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            show_mask(mask, plt.gca())
            if input_point is not None:
                show_points(input_point, input_label, plt.gca())
            if input_box is not None:
                show_box(input_box, plt.gca())
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
            plt.savefig(f'debug/sam_mask_{i+1}.png')
            plt.close()
            if i > 5:
                break

    def get_sam_input_points(self, conn_components, output_p, get_neg_points=False, l=1):
        """Build SAM point prompts for every foreground connected component.

        Args:
            conn_components: output of the cca function, indexed as
                ``(num_labels, labels, stats, centroids)``; label 0 is background.
            output_p: probabilities indexed ``output_p[0, c]`` with c=0
                background, c=1 foreground (assumed (1, 2, H, W)).
            get_neg_points: also build negative (background) points.
            l: negative points taken from each component's outer ring.

        Returns:
            ``(points, point_labels, neg_points, neg_labels)``: ``points`` holds
            one (k, 2) xy array per component, stacked into one array when all
            have the same k; ``point_labels`` is the 1-based component index of
            every point; ``neg_points`` holds one (m, 2) array or None per
            component; ``neg_labels`` is zeros, one per component.
        """
        sam_input_points = []
        sam_neg_points = []
        fg_p = output_p[0, 1].detach().cpu()
        bg_p = output_p[0, 0].detach().cpu()
        glob_neg_points = None
        if get_neg_points:
            # Global negative point: the most confident pixel with p_bg >= 0.95.
            # Threshold a clone — on CPU, bg_p shares storage with output_p, so
            # thresholding in place would overwrite the caller's probabilities.
            confident_bg = bg_p.clone()
            confident_bg[confident_bg < 0.95] = 0
            bg_pred = torch.where(confident_bg > 0, 1, 0)
            glob_neg_points, _ = self.get_most_conf_points(confident_bg, bg_pred, 1)
            if self.debug:
                plt.figure()
                plt.imshow(confident_bg)
                plt.colorbar()
                plt.savefig('debug/bg_p_heatmap.png')
                plt.close()

        for cc_id in np.unique(conn_components[1]):
            if cc_id == 0:
                continue  # background
            pred = torch.tensor(conn_components[1] == cc_id).float()
            if self.point_mode == CONF_MODE:
                points, _ = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam)  # (N, 2)
            elif self.point_mode == CENTROID_MODE:
                points = conn_components[3][cc_id][None, :]  # (1, 2)
            elif self.point_mode == BOTH_MODE:
                points, _ = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam)
                points = np.vstack([points, conn_components[3][cc_id][None, :]])  # (N+1, 2)
            else:
                raise NotImplementedError(f"point mode {self.point_mode} not implemented")
            sam_input_points.append(np.array(points))

            if not get_neg_points:
                sam_neg_points.append(None)
                continue

            # Negative points come from a ring just outside the component:
            # dilate the mask and subtract the original.
            pred_uint8 = (pred.numpy() * 255).astype(np.uint8)
            kernel = np.ones((3, 3), np.uint8)
            dilated_mask = cv2.dilate(pred_uint8, kernel, iterations=10)
            boundary = torch.tensor(dilated_mask - pred_uint8).float() / 255
            neg_points, _ = self.get_most_conf_points(bg_p, boundary, l)
            if neg_points is None:
                neg_points = glob_neg_points
            elif glob_neg_points is not None:
                neg_points = np.vstack([neg_points, glob_neg_points])

            if self.debug and neg_points is not None:
                plt.figure()
                plt.subplot(121)
                plt.imshow(pred)
                plt.imshow(boundary, alpha=0.5)
                plt.scatter(neg_points[:, 0], neg_points[:, 1], marker='*', c='red')
                plt.subplot(122)
                plt.imshow(pred)
                plt.scatter(neg_points[:, 0], neg_points[:, 1], marker='*', c='red')
                plt.savefig('debug/pred_and_boundary.png')
                plt.close()
            sam_neg_points.append(neg_points)

        sam_input_labels = np.array([idx + 1 for idx, cc_points in enumerate(sam_input_points) for _ in range(len(cc_points))])
        # np.stack needs equal shapes; components can contribute different counts.
        if len({cc_points.shape for cc_points in sam_input_points}) == 1:
            sam_input_points = np.stack(sam_input_points)  # (num_cc, num_points, 2)
        # Negative points stay a per-component list: their lengths may differ
        # and entries may be None.
        sam_neg_input_labels = np.array([0] * len(sam_neg_points))
        return sam_input_points, sam_input_labels, sam_neg_points, sam_neg_input_labels

    def get_sam_input_mask(self, conn_components):
        """Return one binary float mask per foreground component and its id.

        Returns:
            ``(masks, ids)``: a (num_components, H, W) float array and the
            matching component ids (label 0 excluded).
        """
        labels = conn_components[1]
        cc_ids = [cc_id for cc_id in np.unique(labels) if cc_id != 0]
        masks = np.stack([torch.tensor(labels == cc_id).float() for cc_id in cc_ids])
        return masks, np.array(cc_ids)

    def predict_w_masks(self, sam_input_masks, qry_img, original_size):
        """Prompt SAM with one low-resolution mask per component.

        Args:
            sam_input_masks: iterable of (H, W) binary masks.
            qry_img: (H', W', 3) uint8 image for ``SamPredictor.set_image``.
            original_size: int or (H, W); only sizes the debug figure.

        Returns:
            ``(masks, scores)``: the highest-scoring SAM output per input mask.
        """
        assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
        # The image is identical for every prompt, so embed it only once.
        self.predictor.set_image(qry_img)
        if isinstance(original_size, int):
            debug_dsize = (original_size, original_size)
        else:
            debug_dsize = (original_size[1], original_size[0])  # cv2 expects (width, height)
        masks, scores = [], []
        for in_mask in sam_input_masks:
            low_res = cv2.resize(np.asarray(in_mask, dtype=np.float32), (256, 256), interpolation=cv2.INTER_NEAREST)
            # SAM reads mask_input as logits: positive inside, negative outside.
            # It must stay float — a uint8 cast wraps -8 to 248 (foreground).
            low_res[low_res == 1] = 10
            low_res[low_res == 0] = -8
            mask, score, _ = self.predictor.predict(
                mask_input=low_res[None, ...].astype(np.float32),
                multimask_output=True)
            if self.debug:
                fig, ax = plt.subplots(1, 4, figsize=(15, 5))
                for i in range(mask.shape[0]):
                    ax[i].imshow(qry_img)
                    ax[i].imshow(mask[i], alpha=0.5)
                    ax[i].set_title(f"Mask {i+1}, Score: {score[i]:.3f}", fontsize=18)
                ax[-1].imshow(cv2.resize(low_res, debug_dsize, interpolation=cv2.INTER_NEAREST))
                fig.savefig('debug/sam_mask_from_mask_prompts.png')
                plt.close(fig)
            max_index = score.argmax()
            masks.append(mask[max_index])
            scores.append(score[max_index])
        return masks, scores

    def predict_w_points_bbox(self, sam_input_points, bboxes, sam_neg_input_points, qry_img, pred, return_logits=False):
        """Run SAM once per component with point and/or box prompts.

        Args:
            sam_input_points: per-component (k, 2) xy arrays, or None entries.
            bboxes: per-component XYXY boxes, or None entries.
            sam_neg_input_points: per-component (m, 2) arrays or None; used
                only when ``use_neg_points``.
            qry_img: (H, W, 3) uint8 image for ``SamPredictor.set_image``.
            pred: coarse prediction, used only by debug plots.
            return_logits: forwarded to ``SamPredictor.predict``.

        Returns:
            ``(masks, scores)``, one entry per zipped component. SAM output 0 is
            kept even with multimask output (see ``best_pred_idx``).
        """
        assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
        self.predictor.set_image(qry_img)
        masks, scores = [], []
        for point, bbox_xyxy, neg_point in zip(sam_input_points, bboxes, sam_neg_input_points):
            points = point
            point_labels = np.array([1] * len(point)) if point is not None else None
            if self.use_neg_points and point is not None and neg_point is not None:
                # Keep only fully defined negative points.
                neg_points = [npoint for npoint in neg_point if None not in npoint]
                points = np.vstack([point, *neg_points])
                point_labels = np.array([1] * len(point) + [0] * len(neg_points))
            if self.debug and points is not None:
                self.plot_most_conf_points(points[:, None, ...], None, pred, qry_img, bboxes=bbox_xyxy[None, ...] if bbox_xyxy is not None else None, title="debug/pos_neg_points.png")  # TODO add plots for all points not just the first set of points
            mask, score, _ = self.predictor.predict(
                point_coords=points,
                point_labels=point_labels,
                box=bbox_xyxy,
                return_logits=return_logits,
                multimask_output=not self.use_cca,
            )
            # NOTE(review): the argmax over ``score`` is disabled upstream, so
            # the first mask is kept even with multimask output — confirm intent.
            best_pred_idx = 0
            masks.append(mask[best_pred_idx])
            scores.append(score[best_pred_idx])
            if self.debug:
                self.plot_sam_preds(mask, score, qry_img[..., 0], points.reshape(-1, 2) if points is not None else None, point_labels, input_box=bbox_xyxy)
        return masks, scores

    def _plot_coarse_prediction(self, query_image, rotated_img, output_logits, output_logits_rot, original_size):
        """Save heatmap / query / rotated-query overlays to ``debug/coarse_pred.png``."""
        output_p = output_logits.softmax(dim=1)
        _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
        plt.figure()
        plt.subplot(132)
        plt.imshow(query_image[0, 0].detach().cpu())
        plt.imshow(_pred, alpha=0.5)
        plt.subplot(131)
        # Heatmap of the foreground probability.
        plt.imshow(output_p[0, 1].detach().cpu())
        output_p_rot = output_logits_rot.softmax(dim=1)
        pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
        pred_rot = F.interpolate(torch.tensor(pred_rot)[None, None].float(), size=original_size, mode='nearest')[0][0]
        plt.subplot(133)
        plt.imshow(rotated_img[0, 0].detach().cpu())
        plt.imshow(pred_rot, alpha=0.5)
        plt.savefig('debug/coarse_pred.png')
        plt.close()

    def _coarse_only_output(self, output_logits, original_size):
        """Return the coarse prediction at *original_size* without running SAM."""
        if tuple(output_logits.shape[-2:]) != original_size:
            output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear')
        pred = output_logits.argmax(dim=1)[0]
        conf = get_confidence_from_logits(output_logits)
        if self.use_cca:
            _pred, conf = cca(np.array(pred.detach().cpu()), output_logits, return_conf=True)
            pred = torch.from_numpy(_pred)
        if self.training:
            return output_logits, [conf]
        # Float for consistency with the SAM branch's output.
        return pred.float(), [conf]

    def _prepare_sam_image(self, query_image):
        """Turn a (1, C, H, W) tensor into the min-max scaled HWC uint8 array SAM expects."""
        image = query_image[0]
        if self.sam_trans is not None:
            image = self.sam_trans.apply_image_torch(image)
            image = self.sam_trans.preprocess(image)
        image = image.permute(1, 2, 0).detach().cpu().numpy()
        return (self._minmax_normalize(image) * 255).astype(np.uint8)

    def forward(self, query_image, coarse_model_input, degrees_rotate=0):
        """Segment *query_image* with the coarse model, optionally refined by SAM.

        Args:
            query_image: (1, 3, H, W) tensor. Whether it should be mean/std
                normalised rather than scaled to [0, 1] is unconfirmed.
            coarse_model_input: ``SegmentationInput`` whose query is replaced
                by the (rotated) query image before the coarse model runs.
            degrees_rotate: rotation applied before the coarse model; the
                coarse logits are rotated back afterwards.

        Returns:
            ``(pred, scores)``.
            - With ``coarse_pred_only``: the coarse argmax as float at (H, W)
              (logits in training mode) and ``[conf]``.
            - Otherwise: an (H, W) float mask from the summed SAM outputs
              (binarised outside training) and the per-component SAM scores.
            - When nothing is predicted as foreground: an all-zero (H, W) mask.
        """
        # Full (H, W): using only the height produced (H, H) outputs for
        # non-square inputs.
        original_size = tuple(query_image.shape[-2:])

        rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
        coarse_model_input.set_query_images(rotated_img)
        output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
        if degrees_rotate != 0:
            output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
        else:
            output_logits = output_logits_rot

        if self.debug:
            self._plot_coarse_prediction(query_image, rotated_img, output_logits, output_logits_rot, original_size)

        if self.coarse_pred_only:
            return self._coarse_only_output(output_logits, original_size)

        target_size = tuple(self.image_size)
        if tuple(query_image.shape[-2:]) != target_size:
            query_image = F.interpolate(query_image, size=target_size, mode='bilinear')
            output_logits = F.interpolate(output_logits, size=target_size, mode='bilinear')
        output_p = output_logits.softmax(dim=1)
        pred = output_p.argmax(dim=1)[0]
        _pred = np.array(pred.detach().cpu())

        if self.use_cca:
            conn_components = cca(_pred, output_logits, return_cc=True)
            conf = None
        else:
            conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
        if self.debug:
            plot_connected_components(conn_components, query_image[0, 0].detach().cpu(), conf)
        if _pred.max() == 0:
            # No foreground: empty mask with the same shape/dtype as the SAM branch.
            return torch.zeros(original_size, device=output_p.device), [0]

        num_labels = conn_components[0]
        if self.use_bbox:
            try:
                bboxes = self.get_bbox_per_cc(conn_components)
            except Exception:  # best effort, e.g. a label with no pixels
                bboxes = [None] * num_labels
        else:
            bboxes = [None] * num_labels

        if self.use_points:
            sam_input_points, _, sam_neg_input_points, _ = self.get_sam_input_points(conn_components, output_p, get_neg_points=self.use_neg_points, l=1)
        else:
            sam_input_points = [None] * num_labels
            sam_neg_input_points = [None] * num_labels
        sam_input_masks = self.get_sam_input_mask(conn_components)[0] if self.use_mask else None

        if self.debug and self.use_points:
            title = 'debug/most_conf_points_cca.png' if self.use_cca else 'debug/most_conf_points.png'
            self.plot_most_conf_points(sam_input_points, None, _pred, query_image[0, 0].detach().cpu(), bboxes=bboxes, title=title)  # TODO add plots for all points not just the first set of points

        sam_image = self._prepare_sam_image(query_image)

        masks, scores = [], []
        if self.use_mask:
            masks, scores = self.predict_w_masks(sam_input_masks, sam_image, original_size)
        if self.use_points or self.use_bbox:
            # When enabled together with masks, point/box results take precedence.
            masks, scores = self.predict_w_points_bbox(sam_input_points, bboxes, sam_neg_input_points, sam_image, pred, return_logits=self.training)
        if not masks:
            return torch.zeros(original_size, device=output_p.device), scores

        combined = sum(masks)
        if not self.training:
            combined = combined > 0
        combined = torch.tensor(combined).float().to(output_p.device)
        combined = F.interpolate(combined[None, None], size=original_size, mode='nearest')[0][0]
        return combined, scores
def show_mask(mask, ax, random_color=False):
    """Overlay *mask* on *ax* as a translucent RGBA layer.

    The colour is random (alpha 0.6) when *random_color* is set, otherwise a
    fixed blue with alpha 0.6.
    """
    rgba = (np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
            if random_color
            else np.array([30 / 255, 144 / 255, 255 / 255, 0.6]))
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))
def show_points(coords, labels, ax, marker_size=375):
    """Scatter positive (label 1, green) then negative (label 0, red) points on *ax*."""
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
    """Draw an XYXY *box* on *ax* as a green, unfilled rectangle."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)
def need_softmax(tensor, dim=1):
    """Return True unless *tensor* already looks like probabilities along *dim*.

    A tensor counts as normalised when every entry is non-negative and all of
    its sums over ``dim`` are close to 1. The two conditions are reduced
    separately: combining the reduced-shape sum mask with the full-shape sign
    mask via ``&`` broadcasts over the wrong axes (or raises) whenever the
    batch size differs from the channel count.
    """
    sums = tensor.sum(dim=dim)
    sums_to_one = torch.all(torch.isclose(sums, torch.ones_like(sums)))
    non_negative = torch.all(tensor >= 0)
    return not bool(sums_to_one and non_negative)