File size: 2,058 Bytes
719d0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import matplotlib.pyplot as plt
import numpy as np
from utils.util_calc import calc_tour_length

def add_arrow(tour, coords, step, color, ax):
    """
    Draw an arrow along one edge of a tour.

    The arrow starts at the node visited at position `step` of `tour` and
    points to the node visited at position `step + 1`. Nothing is drawn when
    the tour has fewer than two nodes or when `step` has no successor.

    Parameters
    ----------
    tour: 1d sequence of node indices [seq_length]
    coords: 2d list or array [num_nodes x coordinates]; columns 0 and 1 are
        used as x and y
    step: position in `tour` of the arrow's tail
    color: face and edge color of the arrow
    ax: object exposing a matplotlib-style `annotate` method
    """
    # No edge to draw: tour too short, or `step` is already the last position.
    if len(tour) < 2 or step + 1 >= len(tour):
        return

    # Callers may pass plain nested lists; column slicing needs an array.
    points = np.asarray(coords)
    x = points[:, 0]
    y = points[:, 1]
    tail_node = tour[step]
    head_node = tour[step + 1]
    x0, y0 = x[tail_node], y[tail_node]
    x1, y1 = x[head_node], y[head_node]

    # An empty label with arrowprops draws only the arrow from xytext to xy.
    ax.annotate('', xy=[x1, y1], xytext=[x0, y0],
                arrowprops=dict(shrink=0, width=1, headwidth=8,
                                headlength=10, connectionstyle="arc3",
                                facecolor=color, edgecolor=color))

def visualize_tsp_tour(coords, tour, ax, linestyle="--"):
    """
    Plot the nodes of a TSP instance and the path of a tour on `ax`.

    All nodes are drawn as black dots, the tour is drawn as a black line
    through the nodes in visiting order, and an arrow marks the first edge.

    Parameters
    ----------
    coords: 2d list [num_nodes x coordinates]
    tour: 1d list [seq_length] of 0-based node indices; may be empty, in
        which case only the nodes are drawn
    ax: matplotlib axes to draw on
    linestyle: format string for the path, passed positionally to `ax.plot`
    """
    points = np.asarray(coords)
    # Force an integer dtype: an empty list would otherwise become a float
    # array, which NumPy rejects when used as an index.
    tour = np.asarray(tour, dtype=int)
    x = points[:, 0]
    y = points[:, 1]

    # visualize points
    ax.scatter(x, y, c="black", zorder=2)

    # visualize paths (drawn below the points)
    ax.plot(x[tour], y[tour], linestyle, c='black', zorder=1)

    # add an arrow indicating initial direction
    add_arrow(tour, points, 0, "black", ax)

def visualize_factual_and_cf_tours(factual_tour, cf_tour, coords, cf_step, vis_filename):
    """
    Save a side-by-side figure comparing a factual and a counterfactual tour.

    The left panel shows the factual tour and the right panel the
    counterfactual tour. In each panel the full tour is dashed, the first
    `cf_step` visits are overdrawn with a solid line, and the edge from
    position `cf_step - 1` to position `cf_step` is highlighted with a
    colored arrow (red for factual, blue for counterfactual). Each panel's
    title reports the tour length returned by `calc_tour_length`.

    Parameters
    ----------
    factual_tour: 1d list [seq_length]
    cf_tour: 1d list [seq_length]
    coords: 2d list [num_nodes x coordinates]
    cf_step: position in the tours where the highlighted visit happens
        (presumably >= 1; confirm against callers)
    vis_filename: destination passed to matplotlib's `savefig`
    """
    # add_arrow slices coords by column, so nested lists must become an array.
    points = np.asarray(coords)

    fig = plt.figure(figsize=(20, 10))
    try:
        ax1 = fig.add_subplot(1, 2, 1)
        ax2 = fig.add_subplot(1, 2, 2)

        # Full tours (dashed), then the shared prefix up to cf_step (solid).
        visualize_tsp_tour(coords, factual_tour, ax1)
        visualize_tsp_tour(coords, cf_tour, ax2)
        visualize_tsp_tour(coords, factual_tour[:cf_step], ax1, linestyle="-")
        visualize_tsp_tour(coords, cf_tour[:cf_step], ax2, linestyle="-")

        add_arrow(factual_tour, points, cf_step-1, "red", ax1) # factual visit
        add_arrow(cf_tour, points, cf_step-1, "blue", ax2) # counterfactual visit

        factual_tour_length = calc_tour_length(factual_tour, coords)
        cf_tour_length = calc_tour_length(cf_tour, coords)
        ax1.set_title(f"Factual tour\nTour length={factual_tour_length:.3f}")
        ax2.set_title(f"Counterfactual tour\nTour length={cf_tour_length:.3f}")

        fig.savefig(vis_filename)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)