import math
import random
from collections import namedtuple, deque
from itertools import count

import cv2
import gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import retro
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym import spaces
from torch.utils.tensorboard import SummaryWriter

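
# Atari-style preprocessing wrappers for the retro environment (frame skipping,
# grayscale/resize, channel-first conversion, frame stacking, reward clipping),
# following the baselines.common.atari_wrappers pattern referenced below.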
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame."""
        gym.Wrapper.__init__(self, env)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between observations are only stored once.
        It exists purely to optimize memory usage, which can be huge for DQN's 1M-frame
        replay buffers.
        This object should only be converted to a numpy array before being passed to the model."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(shp[0], shp[1], shp[2] * k),
                                            dtype=env.observation_space.dtype)

    def reset(self):
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))

class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)

class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis first (HWC -> CHW) so frames match PyTorch's Conv2d layout."""

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        # Pixel values are left as raw 0-255 intensities; no scaling is applied here,
        # so the declared space matches what observation() actually returns.
        self.observation_space = gym.spaces.Box(low=0, high=255,
                                                shape=(old_shape[-1], old_shape[0], old_shape[1]),
                                                dtype=np.uint8)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)

class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

class AirstrikerDiscretizer(gym.ActionWrapper):
    """Map a small discrete action set onto the Genesis controller's 12-button array."""

    def __init__(self, env):
        super(AirstrikerDiscretizer, self).__init__(env)
        buttons = ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
        actions = [['LEFT'], ['RIGHT'], ['B']]
        self._actions = []
        for action in actions:
            arr = np.array([False] * 12)
            for button in action:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a):
        return self._actions[a].copy()

env = retro.make(game='Airstriker-Genesis')
env = MaxAndSkipEnv(env)
env = WarpFrame(env)
env = ImageToPyTorch(env)
env = FrameStack(env, 4)
env = AirstrikerDiscretizer(env)
env = ClipRewardEnv(env)
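
# Resulting observation shape: WarpFrame gives (84, 84, 1), ImageToPyTorch moves the
# channel first to (1, 84, 84), and FrameStack/LazyFrames concatenates the 4 frames
# along axis 2, so the network sees a single-channel (1, 84, 336) image rather than
# the more common (4, 84, 84) stack. The hard-coded sizes in DQN below assume this layout.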

is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

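# Uniform experience replay: transitions are stored in a bounded deque and
# sampled uniformly at random for each optimization step.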
class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

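# The Q-network below uses the Nature-DQN convolution stack (8x8/4, 4x4/2, 3x3/1)
# followed by two fully connected layers; n_observations is the number of input
# channels and n_actions the number of discrete actions.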
class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Conv2d(in_channels=n_observations, out_channels=32, kernel_size=8, stride=4)
        self.layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.layer3 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
                                    nn.ReLU(),
                                    nn.Flatten())
        self.layer4 = nn.Linear(17024, 512)
        self.layer5 = nn.Linear(512, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)  # layer3 already applies ReLU and flattens
        x = F.relu(self.layer4(x))
        return self.layer5(x)
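
# Where the hard-coded 17024 comes from, given a (1, 84, 336) input
# (conv output size = floor((in - kernel) / stride) + 1):
#   layer1 (8x8, stride 4): 84 -> 20, 336 -> 83   => 32 x 20 x 83
#   layer2 (4x4, stride 2): 20 -> 9,  83 -> 40    => 64 x 9 x 40
#   layer3 (3x3, stride 1): 9 -> 7,   40 -> 38    => 64 * 7 * 38 = 17024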

BATCH_SIZE = 512
GAMMA = 0.99        # discount factor
EPS_START = 1.0     # initial exploration rate
EPS_END = 0.01      # final exploration rate
EPS_DECAY = 10000   # time constant (in steps) of the exponential epsilon decay
TAU = 0.005         # soft-update rate for the target network
LR = 0.00025        # AdamW learning rate

n_actions = env.action_space.n
state = env.reset()
# len() of the stacked LazyFrames observation is its first (channel) dimension,
# which is 1 with the wrapper order used above.
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0

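# Epsilon-greedy action selection: with probability eps_threshold take a random
# action, otherwise take the greedy action from policy_net. The threshold decays
# exponentially from EPS_START toward EPS_END with time constant EPS_DECAY steps.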
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1), eps_threshold
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long), eps_threshold

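# plot_durations draws the recorded episode durations and, once 100+ episodes have
# completed, their 100-episode moving average; in this script it is only called
# once at the end with show_result=True.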
episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    # Plot the 100-episode moving average once enough episodes have completed.
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch: convert a list of Transitions into a Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Mask of non-final states and the corresponding next states
    # (a final state would be marked by next_state being None).
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a): the policy network's value for the action that was actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) = max_a Q(s_{t+1}, a) from the target network; zero for final states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

    # Expected Q values from the Bellman target.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss between predicted and expected Q values.
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the policy network with gradient value clipping.
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

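# Training loop: for each episode, act epsilon-greedily, store transitions in the
# replay buffer, run one optimization step per environment step, softly update the
# target network (theta' <- tau*theta + (1 - tau)*theta'), and log the episode
# reward and current epsilon to TensorBoard.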
with SummaryWriter() as writer:
    if torch.cuda.is_available():
        num_episodes = 600
    else:
        num_episodes = 50
    epsilon = 1
    episode_rewards = []
    for i_episode in range(num_episodes):
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        episode_reward = 0
        for t in count():
            action, epsilon = select_action(state)
            observation, reward, done, info = env.step(action.item())
            reward = torch.tensor([reward], device=device)

            # The Airstriker-Genesis integration also reports a 'gameover' flag in info.
            done = done or info["gameover"] == 1
            if done:
                # Note: breaking here means terminal transitions are never pushed to
                # memory, so next_state is never None in optimize_model's batches.
                episode_durations.append(t + 1)
                print(f"Episode {i_episode} done")
                break

            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

            # Store the transition in memory and move to the next state.
            memory.push(state, action, next_state, reward)
            episode_reward += reward
            state = next_state

            # Perform one step of the optimization on the policy network.
            optimize_model()

            # Soft update of the target network's weights.
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in policy_net_state_dict:
                target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
            target_net.load_state_dict(target_net_state_dict)

        writer.add_scalar("Rewards/Episode", episode_reward, i_episode)
        writer.add_scalar("Epsilon", epsilon, i_episode)
        writer.flush()

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()