Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from tqdm import tqdm | |
from typing import * | |
from src import logger | |
def predict_mask( | |
data: Any, | |
device: Any, | |
model: nn.Module, | |
inference: bool, | |
valid_loader=None, | |
criterion=None, | |
): | |
""" | |
predicts mask for the image | |
Args: | |
data (Any): image data for predicting | |
model (nn.Module): model for training | |
device (0/'cud'/'cpu'/Any): name of device | |
inference (bool): Whether to evaluate or predict | |
valid_Loader (nn.Module): test loader for training | |
criterion (nn.Module): loss criteria | |
Example: | |
>>> train( | |
>>> data = torch.FloatTensor, | |
>>> model=model, | |
>>> device=0/'cuda'/'cpu' | |
>>> ingerence=0 | |
>>> valid_loader= test_loader | |
>>> criterion= fn_loss | |
""" | |
if inference: | |
with torch.no_grad(): | |
image = data.type(torch.FloatTensor).to(device) | |
model = model.to(device) | |
pred = model(image) | |
pred = torch.sigmoid(pred) | |
mask = (pred > 0.6).float() | |
return mask.cpu().detach() | |
else: | |
with torch.no_grad(): | |
val_Loss = 0 | |
val_Dicescore = 0 | |
model.eval() | |
for x, y in tqdm(valid_loader): | |
x = x.type(torch.cuda.FloatTensor).to(device) | |
y = y.type(torch.cuda.FloatTensor).to(device) | |
predict = model(x) | |
loss = criterion(predict, y) | |
val_Loss += loss.item() | |
predict = torch.sigmoid(predict) | |
predict = (predict > 0.5).float() | |
dice_score = (2 * (y*predict).sum() + 1e-8)/((y+predict).sum() + 1e-8) | |
try: | |
val_Dicescore += dice_score.cpu().item() | |
except: | |
val_Dicescore += dice_score | |
val_Loss /= len(valid_loader) | |
val_Dicescore /= len(valid_loader) | |
logger.info(f"Test Loss: {val_Loss} - Dice Score: {val_Dicescore}") |