# JobShopCPRL / Network.py
# Author: Pierre Tassel
import math
import numpy as np
import torch
from torch import nn, Tensor
from torch.distributions import Categorical


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding, looked up by position index."""

    def __init__(self, d_model: int, max_len: int = 100):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, positions: Tensor) -> Tensor:
        # positions: integer tensor of indices; returns the matching encodings.
        return self.pe[positions]


class Actor(nn.Module):
    def __init__(self, pos_encoder):
        super(Actor, self).__init__()
        self.activation = nn.Tanh()
        # Project the 4 per-interval features into the transformer dimension (8).
        self.project = nn.Linear(4, 8)
        nn.init.xavier_uniform_(self.project.weight, gain=1.0)
        nn.init.constant_(self.project.bias, 0)
        self.pos_encoder = pos_encoder
        # Learned embeddings for the binary "fixed" and "legal operation" flags.
        self.embedding_fixed = nn.Embedding(2, 1)
        self.embedding_legal_op = nn.Embedding(2, 1)
        # Dedicated start/end tokens that replace the raw features where they occur.
        self.tokens_start_end = nn.Embedding(3, 4)
        # self.conv_transform = nn.Conv1d(5, 1, 1)
        # nn.init.kaiming_normal_(self.conv_transform.weight, mode="fan_out", nonlinearity="relu")
        # nn.init.constant_(self.conv_transform.bias, 0)
        # Two single-head, pre-norm transformer layers: enc1 attends over the
        # intervals of each job, enc2 attends across jobs.
        self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)
        self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)
        # Per-job action head and a single no-op head.
        self.final_tmp = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )
        self.no_op = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )

    def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
        # Build per-interval features: embedded binary flags plus the two raw scalar features.
        embedded_obs = torch.cat((self.embedding_fixed(obs[:, :, :, 0].long()), obs[:, :, :, 1:3],
                                  self.embedding_legal_op(obs[:, :, :, 3].long())), dim=3)
        # Overwrite start/end positions with their dedicated token embeddings,
        # and zero out their positional encoding.
        non_zero_tokens = tokens_start_end != 0
        t = tokens_start_end[non_zero_tokens].long()
        embedded_obs[non_zero_tokens] = self.tokens_start_end(t)
        pos_encoder = self.pos_encoder(indexes_inter.long())
        pos_encoder[non_zero_tokens] = 0
        obs = self.project(embedded_obs) + pos_encoder
        # First encoder: attend over the intervals of each job, with padding masked out.
        transformed_obs = obs.view(-1, obs.shape[2], obs.shape[3])
        attention_interval_mask = attention_interval_mask.view(-1, attention_interval_mask.shape[-1])
        transformed_obs = self.enc1(transformed_obs, src_key_padding_mask=attention_interval_mask == 1)
        transformed_obs = transformed_obs.view(obs.shape)
        # Mean-pool the intervals into one embedding per job.
        obs = transformed_obs.mean(dim=2)
        # Second encoder: attend across jobs, masked by the job/resource relation.
        job_resource = job_resource[:, :-1, :-1] == 0
        obs_action = self.enc2(obs, src_mask=job_resource) + obs
        # One logit per job plus a no-op logit; illegal actions are filled with -float32 max.
        logits = torch.cat((self.final_tmp(obs_action).squeeze(2), self.no_op(obs_action).mean(dim=1)), dim=1)
        return logits.masked_fill(mask == 0, -3.4028234663852886e+38)


class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.pos_encoder = PositionalEncoding(8)
        self.actor = Actor(self.pos_encoder)

    def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
                action=None):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        if action is None:
            # Sample all actions without replacement (an ordering of the actions)
            # and return them with the log-probabilities and the entropy.
            probabilities = probs.probs
            actions = torch.multinomial(probabilities, probabilities.shape[1])
            return actions, torch.log(probabilities), probs.entropy()
        else:
            # Evaluate the provided actions: return logits, their log-probabilities and the entropy.
            return logits, probs.log_prob(action), probs.entropy()

    def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        return probs.sample()

    def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        return logits


def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
    # Orthogonal weight initialization (gain=std) with a constant bias.
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
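

# The block below is not part of the original file: it is a minimal smoke-test
# sketch showing one plausible way to drive Agent with dummy tensors. The shapes
# (2 instances, 3 jobs, 5 intervals per job) and the tensor layout are assumptions
# inferred from Actor.forward, not the real environment's interface.
if __name__ == "__main__":
    B, J, I = 2, 3, 5
    agent = Agent()
    obs = torch.zeros(B, J, I, 4)
    obs[:, :, :, 0] = torch.randint(0, 2, (B, J, I)).float()   # assumed "fixed" flag
    obs[:, :, :, 1:3] = torch.rand(B, J, I, 2)                 # two scalar features
    obs[:, :, :, 3] = torch.randint(0, 2, (B, J, I)).float()   # assumed "legal operation" flag
    attention_interval_mask = torch.zeros(B, J, I)             # no padded intervals
    job_resource = torch.ones(B, J + 1, J + 1)                 # no job pair masked in enc2
    mask = torch.ones(B, J + 1)                                # every action (incl. no-op) legal
    indexes_inter = torch.arange(I).expand(B, J, I)            # interval positions per job
    tokens_start_end = torch.zeros(B, J, I)                    # no start/end tokens
    actions, log_probs, entropy = agent(obs, attention_interval_mask, job_resource,
                                        mask, indexes_inter, tokens_start_end)
    print(actions.shape, log_probs.shape, entropy.shape)       # (2, 4), (2, 4), (2,)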