# Spaces: Running
import os
import random
from typing import Any, Dict, Mapping, Optional, Sequence

import numpy as np
import torch
import yaml
class BaseConfig:
    """Load a YAML config file and prepare the runtime environment.

    Construction reads the config, seeds the global RNGs and applies the GPU
    settings found under the ``gpu`` key (see :meth:`setup_environment`).
    """

    def __init__(self, config_path: Optional[str] = None) -> None:
        """Read the YAML config and set up the environment.

        Args:
            config_path: Path of the YAML file to load. Defaults to
                ``config.yml`` in the same directory as this module.

        Raises:
            OSError: If the config file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
        """
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), "config.yml")
        with open(config_path, "r", encoding="utf-8") as file:
            # safe_load only constructs plain Python objects, never arbitrary classes.
            self.config = yaml.safe_load(file)
        self.setup_environment()

    def setup_environment(self, seed: int = 42) -> None:
        """Seed the RNGs and apply the GPU settings from ``self.config``.

        Side effects: sets ``os.environ["CUDA_VISIBLE_DEVICES"]``, sets
        ``self.device`` and changes PyTorch's global float32 matmul precision.

        Args:
            seed: Seed for Python's ``random``, NumPy and PyTorch.

        Raises:
            KeyError: If ``self.config`` lacks ``gpu.visible_device`` or
                ``gpu.device``.
        """
        gpu_cfg = self.config["gpu"]

        visible = gpu_cfg["visible_device"]
        # os.environ only accepts str, but YAML parses `0` as an int and
        # `[0, 1]` as a list; convert those to CUDA's comma-separated form.
        if isinstance(visible, (list, tuple)):
            visible = ",".join(str(dev) for dev in visible)
        elif isinstance(visible, int):
            visible = str(visible)
        # Set before any CUDA call below: CUDA reads this variable when it
        # initializes, and later changes have no effect in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = visible

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        # The device string is taken verbatim from the config (e.g. "cuda:0" or "cpu").
        self.device = torch.device(gpu_cfg["device"])
        torch.set_float32_matmul_precision("medium")

    def custom_collate(
        self, batch: Sequence[Mapping[str, Any]]
    ) -> Dict[str, torch.Tensor]:
        """Collate samples, zero-padding each image along its first dimension.

        Images shorter than ``self.config["data"]["collate"]`` along dim 0 are
        padded with zeros up to that length. Longer images are left unchanged,
        so ``torch.stack`` requires them to already have matching shapes.

        Args:
            batch: Samples, each with an ``"image"`` tensor and a ``"label"``
                (a tensor or anything ``torch.as_tensor`` accepts).

        Returns:
            ``{"image": stacked images, "label": stacked labels}``.
        """
        images = [item["image"] for item in batch]
        labels = [torch.as_tensor(item["label"]) for item in batch]
        max_len = self.config["data"]["collate"]

        padded_images = []
        for img in images:
            pad_size = max_len - img.shape[0]
            if pad_size > 0:
                # new_zeros matches the image's dtype and device, so torch.cat
                # neither fails on GPU tensors nor promotes integer images to float.
                padding = img.new_zeros((pad_size,) + tuple(img.shape[1:]))
                img = torch.cat([img, padding], dim=0)
            padded_images.append(img)

        return {"image": torch.stack(padded_images, dim=0), "label": torch.stack(labels)}

    def get_config(self) -> Dict[str, Any]:
        """Return the parsed YAML configuration (the same object, not a copy)."""
        return self.config