import random
import torch
import torch.distributed
import torch.utils.data

from . import data_transforms
from .model import FrameFieldModel
from .evaluator import Evaluator

from lydorn_utils import print_utils

try:
    import apex
    from apex import amp
    APEX_AVAILABLE = True
except ModuleNotFoundError:
    APEX_AVAILABLE = False


def evaluate(gpu: int, config: dict, shared_dict, barrier, eval_ds, backbone):
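    """Evaluation worker for one GPU in a distributed (DDP) run.

    gpu: local GPU index of this process.
    config: run configuration dictionary.
    shared_dict, barrier: shared objects passed through to the Evaluator.
    eval_ds: evaluation dataset.
    backbone: backbone network wrapped by FrameFieldModel.
    """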
    # --- Setup DistributedDataParallel --- #
    rank = config["nr"] * config["gpus"] + gpu
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=config["world_size"],
        rank=rank
    )

    if gpu == 0:
        print("# --- Start evaluating --- #")

    # Choose device
    torch.cuda.set_device(gpu)

    # --- Online transform performed on the device (GPU):
    eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(config)

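    # Optionally restrict evaluation to a fixed random subset of the dataset
    # (seeded for reproducibility) when config specifies a number of "samples":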
    if "samples" in config:
        rng_samples = random.Random(0)
        eval_ds = torch.utils.data.Subset(eval_ds, rng_samples.sample(range(len(eval_ds)), config["samples"]))
        # eval_ds = torch.utils.data.Subset(eval_ds, range(config["samples"]))

    # Each process evaluates only its own shard of the dataset
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_ds, num_replicas=config["world_size"], rank=rank)

    eval_ds = torch.utils.data.DataLoader(
        eval_ds, batch_size=config["optim_params"]["eval_batch_size"],
        pin_memory=True, sampler=eval_sampler, num_workers=config["num_workers"])

    model = FrameFieldModel(config, backbone=backbone, eval_transform=eval_online_cuda_transform)
    model.cuda(gpu)

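    # Mixed-precision inference with NVIDIA apex (amp), if requested and available: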
    if config["use_amp"] and APEX_AVAILABLE:
        amp.register_float_function(torch, 'sigmoid')
        model = amp.initialize(model, opt_level="O1")
    elif config["use_amp"] and not APEX_AVAILABLE and gpu == 0:
        print_utils.print_warning("WARNING: Cannot use amp because the apex library is not available!")

    # Wrap the model for distributed training
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

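    # Run evaluation on the split given by config["fold"]: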
    evaluator = Evaluator(gpu, config, shared_dict, barrier, model, run_dirpath=config["eval_params"]["run_dirpath"])
    split_name = config["fold"][0]
    evaluator.evaluate(split_name, eval_ds)
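

# --- Illustrative launcher sketch (not part of the original module) --- #
# A minimal example, under assumptions, of how `evaluate` could be spawned with
# one process per GPU via torch.multiprocessing.spawn. The config keys ("gpus",
# "world_size", etc.) follow their usage above; the MASTER_ADDR/MASTER_PORT
# values and the helper name `launch_eval` are hypothetical, and the project's
# actual launcher may differ.
def launch_eval(config, shared_dict, barrier, eval_ds, backbone):
    import os
    import torch.multiprocessing

    # init_method='env://' in evaluate() reads these environment variables
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    # spawn passes the process index as the first argument (the `gpu` parameter)
    torch.multiprocessing.spawn(
        evaluate,
        nprocs=config["gpus"],
        args=(config, shared_dict, barrier, eval_ds, backbone),
    )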