File size: 5,515 Bytes
4cb4fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from flax.core import FrozenDict
import flax.linen as nn
import jax
import jax.numpy as jnp
from functools import partial


# --- Base functions ---


def scale(state: jnp.ndarray) -> jnp.ndarray:
    """Divide every entry of ``state`` by 255.

    Presumably this maps 8-bit pixel intensities into [0, 1]; the function
    itself only performs the division and does not check the input range.

    Args:
        state: Array (or scalar) to rescale.

    Returns:
        ``state / 255.0``.
    """
    max_pixel_value = 255.0
    normalized = state / max_pixel_value
    return normalized


class Torso(nn.Module):
    """Convolutional feature extractor: three Conv + ReLU layers, then a flatten.

    Attributes:
        initialization_type: Selects the kernel initializer shared by the
            three convolutions. Must be ``"dqn"`` or ``"iqn"``.
    """

    initialization_type: str

    @nn.compact
    def __call__(self, state):
        """Compute the flattened convolutional features of ``state``.

        Args:
            state: Input passed to the first convolution.
                NOTE(review): the final ``flatten()`` collapses every axis, so
                a leading batch axis would be merged into the features; this
                looks like it expects a single unbatched state -- confirm
                against the callers.

        Returns:
            A 1-D array holding the output of the last ReLU.

        Raises:
            ValueError: If ``initialization_type`` is neither ``"dqn"`` nor
                ``"iqn"``.
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Without this branch `initializer` would be unbound and the first
            # nn.Conv call would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(state)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
        x = nn.relu(x)

        return x.flatten()


class Head(nn.Module):
    """Two-layer dense head: Dense(512) + ReLU, then Dense(n_actions).

    Attributes:
        n_actions: Number of output units of the final dense layer.
        initialization_type: Selects the kernel initializer shared by both
            dense layers. Must be ``"dqn"`` or ``"iqn"``.
    """

    n_actions: int
    initialization_type: str

    @nn.compact
    def __call__(self, x):
        """Map features ``x`` to ``n_actions`` outputs.

        Args:
            x: Feature array fed to the first dense layer.

        Returns:
            The output of the final dense layer; its last axis has size
            ``n_actions``.

        Raises:
            ValueError: If ``initialization_type`` is neither ``"dqn"`` nor
                ``"iqn"``.
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Without this branch `initializer` would be unbound and the first
            # nn.Dense call would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Dense(features=512, kernel_init=initializer)(x)
        x = nn.relu(x)

        return nn.Dense(features=self.n_actions, kernel_init=initializer)(x)


class QuantileEmbedding(nn.Module):
    """Cosine embedding of uniformly sampled quantile levels.

    Attributes:
        n_features: Width of the dense layer applied to the cosine features.
        quantile_embedding_dim: Number of cosine frequencies (1..dim).
    """

    n_features: int = 7744
    quantile_embedding_dim: int = 64

    @nn.compact
    def __call__(self, key, n_quantiles):
        """Sample ``n_quantiles`` quantile levels and embed them.

        Args:
            key: PRNG key handed to ``jax.random.uniform``.
            n_quantiles: Number of quantile levels to draw.

        Returns:
            A pair ``(embedding, quantiles)`` with shapes
            ``(n_quantiles, n_features)`` and ``(n_quantiles,)``.
        """
        dim = self.quantile_embedding_dim

        # One uniform sample per quantile, kept as a column for the product below.
        taus = jax.random.uniform(key, shape=(n_quantiles, 1))
        frequencies = jnp.arange(1, dim + 1).reshape((1, dim))

        # (n_quantiles, 1) @ (1, dim): entry (i, k) is cos(pi * tau_i * (k + 1)).
        cosines = jnp.cos(jnp.matmul(jnp.pi * taus, frequencies))

        kernel_init = nn.initializers.variance_scaling(
            scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
        )
        embedding = nn.Dense(features=self.n_features, kernel_init=kernel_init)(cosines)

        return (nn.relu(embedding), jnp.squeeze(taus, axis=1))


# --- i-DQN networks ---


class AtariSharediDQNNet:
    """DQN-initialized torso and head applied with per-head parameter sets.

    Head 0 uses the torso parameters stored under ``"torso_params_0"``; every
    head with index >= 1 uses ``"torso_params_1"``. Each head has its own
    ``"head_params_{idx}"`` entry.
    """

    def __init__(self, n_actions: int) -> None:
        """Build the torso and head modules for ``n_actions`` outputs."""
        self.n_actions = n_actions
        self.n_heads = 5
        self.torso = Torso("dqn")
        self.head = Head(self.n_actions, "dqn")

    def apply(self, params: FrozenDict, idx_head: int, state: jnp.ndarray) -> jnp.ndarray:
        """Evaluate head ``idx_head`` on ``state``.

        Args:
            params: Mapping holding the ``torso_params_*`` and ``head_params_*`` entries.
            idx_head: Index of the head to evaluate; it goes through Python's
                ``min`` and an f-string, so it must be a concrete integer.
            state: Input forwarded to the torso.

        Returns:
            The output of the selected head.
        """
        # Clamp to 1: all heads past the first share one torso parameter set.
        torso_idx = min(idx_head, 1)
        torso_params = params[f"torso_params_{torso_idx}"]
        feature = self.torso.apply(torso_params, state)

        head_params = params[f"head_params_{idx_head}"]
        return self.head.apply(head_params, feature)


class AtariiDQN:
    """Greedy policy built on one head of ``AtariSharediDQNNet``."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        """Create the shared network and remember which head to query.

        Args:
            n_actions: Number of actions the network outputs.
            idx_head: Index of the head used by ``best_action``.
        """
        self.network = AtariSharediDQNNet(n_actions)
        self.idx_head = idx_head

    # `key` is annotated as jnp.ndarray rather than jax.random.PRNGKeyArray:
    # that alias was deprecated and later removed from JAX, and because
    # annotations are evaluated at definition time it would make this class
    # body raise AttributeError on import with current JAX releases.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the index of the highest-valued action for ``state``.

        Args:
            params: Parameters forwarded to the network's ``apply``.
            state: Raw state; it is divided by 255 via ``scale`` before the
                network is applied.
            key: PRNG key. Unused here; presumably kept so the signature
                matches the IQN variant -- confirm against callers.

        Returns:
            The argmax of the network output, cast to ``int8``.
        """
        return jnp.argmax(self.network.apply(params, self.idx_head, scale(state))).astype(jnp.int8)


# --- i-IQN networks ---


class AtariSharediIQNNet:
    """IQN-initialized torso, quantile embedding and head with stacked parameters.

    Parameters are indexed along their leading axis: the torso and
    quantile-embedding parameters use index 0 for head 0 and index 1 for every
    later head, while the head parameters are indexed by ``idx_head`` directly.
    """

    def __init__(self, n_actions: int) -> None:
        """Build the torso, quantile embedding and head for ``n_actions`` outputs."""
        self.n_actions = n_actions
        self.n_heads = 4
        self.torso = Torso("iqn")
        self.quantile_embedding = QuantileEmbedding()
        self.head = Head(self.n_actions, "iqn")

    def apply(
        self, params: FrozenDict, idx_head: int, state: jnp.ndarray, key: jax.random.PRNGKey, n_quantiles: int
    ) -> jnp.ndarray:
        """Evaluate head ``idx_head`` on ``state`` for freshly sampled quantiles.

        Args:
            params: Mapping with ``"torso_params"``, ``"quantiles_params"`` and
                ``"head_params"`` entries whose leaves are indexed on axis 0.
            idx_head: Index of the head to evaluate.
            state: Input forwarded to the torso.
            key: PRNG key used to sample the quantile levels.
            n_quantiles: Number of quantile levels to sample.

        Returns:
            Array of shape ``(n_quantiles, n_actions)``.
        """

        def select(tree, index):
            # Slice every leaf of `tree` at `index` along its leading axis.
            return jax.tree_util.tree_map(lambda leaf: leaf[index], tree)

        # 0 for the first head, 1 for all the others.
        shared_idx = jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)

        # (n_features)
        state_feature = self.torso.apply(select(params["torso_params"], shared_idx), state)

        # (n_quantiles, n_features); the sampled quantile levels are not needed here.
        quantiles_feature, _ = self.quantile_embedding.apply(
            select(params["quantiles_params"], shared_idx), key, n_quantiles
        )

        # Multiply each quantile embedding with the state features: (n_quantiles, n_features)
        feature = jax.vmap(jnp.multiply, in_axes=(0, None))(quantiles_feature, state_feature)

        # (n_quantiles, n_actions)
        return self.head.apply(select(params["head_params"], idx_head), feature)


class AtariiIQN:
    """Greedy policy built on one head of ``AtariSharediIQNNet``."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        """Create the shared network and remember which head to query.

        Args:
            n_actions: Number of actions the network outputs.
            idx_head: Index of the head used by ``best_action``.
        """
        self.network = AtariSharediIQNNet(n_actions)
        self.idx_head = idx_head
        # Number of quantile levels sampled when picking an action.
        self.n_quantiles_policy = 32

    # `key` is annotated as jnp.ndarray rather than jax.random.PRNGKeyArray:
    # that alias was deprecated and later removed from JAX, and because
    # annotations are evaluated at definition time it would make this class
    # body raise AttributeError on import with current JAX releases.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the action with the highest mean quantile value for ``state``.

        Args:
            params: Parameters forwarded to the network's ``apply``.
            state: Raw state; it is divided by 255 via ``scale`` before the
                network is applied.
            key: PRNG key used to sample the quantile levels.

        Returns:
            The argmax over actions of the quantile-averaged values, cast to
            ``int8``.
        """
        # output (n_quantiles, n_actions)
        q_quantiles = self.network.apply(params, self.idx_head, scale(state), key, self.n_quantiles_policy)
        # Average over the quantile axis to get one value per action.
        q_values = jnp.mean(q_quantiles, axis=0)

        return jnp.argmax(q_values).astype(jnp.int8)