# route-explainer / eval_solvers.py
# Author: daisuke.kikuta
# Commit: 719d0db ("first commit")
import os
from tqdm import tqdm
import multiprocessing
import numpy as np
from utils.utils import load_dataset, calc_tour_length
from models.solvers.general_solver import GeneralSolver
from models.classifiers.ground_truth.ground_truth import GroundTruth
def eval_solver(solver, instance):
    """Solve a single instance and measure the first tour the solver returns.

    Args:
        solver: object exposing ``solve(instance)``; the result is indexed with
            ``[0]`` to obtain the tour that gets measured.
        instance: mapping that must provide a ``"coords"`` entry, which is
            forwarded to ``calc_tour_length``.

    Returns:
        The value ``calc_tour_length`` computes for ``solver.solve(instance)[0]``.
    """
    first_tour = solver.solve(instance)[0]
    return calc_tour_length(first_tour, instance["coords"])
def _ground_truth_jobs(dataset):
    """Yield one ``GroundTruth.solve`` argument tuple per (instance, vehicle, step).

    Each tuple is ``(step, vehicle_tour, instance, tag)`` where ``vehicle_tour``
    is ``instance["tour"][vehicle_id]``, ``step`` runs from 1 to
    ``len(vehicle_tour) - 1`` and ``tag`` is ``"{i}-{vehicle_id}-{step}"`` with
    ``i`` the instance's position in ``dataset``.
    """
    for i, instance in enumerate(dataset):
        for vehicle_id, vehicle_tour in enumerate(instance["tour"]):
            for step in range(1, len(vehicle_tour)):
                yield step, vehicle_tour, instance, f"{i}-{vehicle_id}-{step}"


def eval(data_path, problem, solver_name, fix_edges, parallel):
    """Run a solver over a dataset and summarise the resulting tour lengths.

    NOTE(review): this name shadows the builtin ``eval``; it is kept as-is
    because the ``__main__`` block (and possibly other callers) use it.

    Args:
        data_path: path handed to ``load_dataset``.
        problem: problem type forwarded to the solver constructor.
        solver_name: solver identifier forwarded to the solver constructor.
        fix_edges: if True, build a ``GroundTruth`` solver and call its
            ``solve`` once per (instance, vehicle, step); the ``"tsp"`` entry of
            each returned mapping is what gets summarised. Otherwise build a
            ``GeneralSolver`` and solve every instance once via ``eval_solver``.
        parallel: if True, fan the work out over a ``multiprocessing.Pool``
            sized to ``os.cpu_count()``; otherwise run serially in-process.

    Returns:
        ``(avg_tour_length, std_tour_length, feasible_ratio, penalty)`` where the
        first two are ``np.mean`` / ``np.std`` (population std) of the collected
        lengths. ``feasible_ratio`` and ``penalty`` are placeholders that are
        always 0.0.

    Raises:
        ValueError: if there is nothing to solve (empty dataset, or no
            vehicle tour with at least two entries when ``fix_edges`` is set).
    """
    dataset = load_dataset(data_path)
    num_cpus = os.cpu_count()

    if fix_edges:
        solver = GroundTruth(problem, solver_name)
        jobs = list(_ground_truth_jobs(dataset))
        if not jobs:
            raise ValueError(f"No (instance, vehicle, step) jobs to solve in {data_path}")
        if parallel:
            with multiprocessing.Pool(num_cpus) as pool:
                # starmap only returns after every job has finished, so this bar
                # fills in a single jump instead of tracking live progress.
                tours = list(tqdm(pool.starmap(solver.solve, jobs), desc=f"Solving {data_path} with {solver_name}"))
        else:
            tours = [solver.solve(*job) for job in jobs]
        # Each solve() call returns a mapping; regroup the values into one list
        # per key, using the first result's keys (an unseen key raises KeyError).
        tour_length = {key: [] for key in tours[0].keys()}
        for tour in tours:
            for key, value in tour.items():
                tour_length[key].append(value)
        lengths = tour_length["tsp"]
    else:
        if len(dataset) == 0:
            raise ValueError(f"Dataset {data_path} is empty")
        solver = GeneralSolver(problem, solver_name)
        if parallel:
            with multiprocessing.Pool(num_cpus) as pool:
                lengths = list(tqdm(pool.starmap(eval_solver, [(solver, instance) for instance in dataset]), total=len(dataset), desc="Solving instances"))
        else:
            lengths = [eval_solver(solver, instance) for instance in tqdm(dataset, total=len(dataset), desc="Solving instances")]

    # Placeholders: feasibility and penalty are not computed by this script yet.
    feasible_ratio = 0.0
    penalty = 0.0
    # `lengths` is a flat sequence of numbers in both branches (the original
    # indexed a plain list with "tsp" in the GeneralSolver branch -> TypeError).
    avg_tour_length = np.mean(lengths)
    std_tour_length = np.std(lengths)
    return avg_tour_length, std_tour_length, feasible_ratio, penalty
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate a routing solver on a dataset and report tour-length statistics.")
    parser.add_argument("--problem", default="tsptw", type=str, help="Problem type: [tsptw, pctsp, pctsptw, cvrp]")
    parser.add_argument("--solver_name", type=str, default="ortools", help="Solver name passed to GeneralSolver / GroundTruth (default: ortools)")
    parser.add_argument("--data_path", type=str, help="Path to a dataset", required=True)
    parser.add_argument("--parallel", action="store_true", help="Solve with a multiprocessing pool using all CPUs")
    # NOTE(review): --all is parsed but never read; kept so existing command lines keep working.
    parser.add_argument("--all", action="store_true", help="Currently unused")
    parser.add_argument("--fix_edges", action="store_true", help="Use GroundTruth for every step of every vehicle tour instead of GeneralSolver")
    args = parser.parse_args()

    # feasible_ratio and penalty are returned by eval() but are placeholders (always 0.0).
    avg_tour_length, std_tour_length, feasible_ratio, penalty = eval(data_path=args.data_path,
                                                                     problem=args.problem,
                                                                     solver_name=args.solver_name,
                                                                     fix_edges=args.fix_edges,
                                                                     parallel=args.parallel)
    print(f"tour_length: {avg_tour_length} +/- {std_tour_length}")