# NOTE(review): removed non-Python scrape artifacts ("Spaces:", "Runtime error" x2)
# that preceded the module; they were paste residue, not source code.
import math
import numpy as np
import torch
from torch import nn, Tensor
from torch.distributions import Categorical
class PositionalEncoding(nn.Module):
    """Classic sinusoidal positional-encoding table, looked up by index.

    Builds the (max_len, d_model) sin/cos table once and registers it as a
    buffer so it moves with the module's device but is never trained.
    Assumes d_model is even (odd d_model would break the cos column fill).
    """

    def __init__(self, d_model: int, max_len: int = 100):
        super().__init__()
        # Geometric frequency ladder: 1/10000^(2i/d_model) for even dims i.
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pos = torch.arange(max_len).unsqueeze(1)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(pos * freqs)  # even columns
        table[:, 1::2] = torch.cos(pos * freqs)  # odd columns
        self.register_buffer('pe', table)

    def forward(self, positions: Tensor) -> Tensor:
        """Gather encodings for integer `positions` (any shape); output gains a trailing d_model axis."""
        return self.pe[positions]
class Actor(nn.Module):
    """Transformer policy head producing one logit per job plus a pooled no-op logit.

    Observation features are embedded per token, positionally encoded within
    each job's interval sequence, encoded by a first transformer layer over
    intervals, mean-pooled, then related across jobs by a second transformer
    layer whose attention is restricted by a job/resource mask.
    """

    def __init__(self, pos_encoder):
        super().__init__()
        self.activation = nn.Tanh()
        # Projects the 4-feature token embedding into the model width (8).
        self.project = nn.Linear(4, 8)
        nn.init.xavier_uniform_(self.project.weight, gain=1.0)
        nn.init.constant_(self.project.bias, 0)
        self.pos_encoder = pos_encoder
        # Binary flags are embedded to 1-d each; special start/end tokens get
        # a full 4-d embedding that overwrites the regular feature vector.
        self.embedding_fixed = nn.Embedding(2, 1)
        self.embedding_legal_op = nn.Embedding(2, 1)
        self.tokens_start_end = nn.Embedding(3, 4)
        self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0,
                                               batch_first=True, norm_first=True)
        self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0,
                                               batch_first=True, norm_first=True)
        # Per-job scalar logit head.
        self.final_tmp = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )
        # No-op logit head (averaged over jobs into a single extra logit).
        self.no_op = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )

    def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
        # assumes obs is (batch, jobs, intervals, 4) — TODO confirm against caller.
        feats = torch.cat(
            (self.embedding_fixed(obs[:, :, :, 0].long()),
             obs[:, :, :, 1:3],
             self.embedding_legal_op(obs[:, :, :, 3].long())),
            dim=3,
        )
        # Start/end marker positions get their dedicated embedding and no
        # positional encoding.
        special = tokens_start_end != 0
        feats[special] = self.tokens_start_end(tokens_start_end[special].long())
        pos = self.pos_encoder(indexes_inter.long())
        pos[special] = 0
        x = self.project(feats) + pos

        # Encode intervals per job: fold (batch, jobs) into one batch axis.
        flat = x.view(-1, x.shape[2], x.shape[3])
        pad = attention_interval_mask.view(-1, attention_interval_mask.shape[-1]) == 1
        flat = self.enc1(flat, src_key_padding_mask=pad)
        pooled = flat.view(x.shape).mean(dim=2)

        # Cross-job attention, restricted where job_resource == 0, residual add.
        attn_mask = job_resource[:, :-1, :-1] == 0
        h = self.enc2(pooled, src_mask=attn_mask) + pooled

        logits = torch.cat((self.final_tmp(h).squeeze(2), self.no_op(h).mean(dim=1)), dim=1)
        # Illegal actions (mask == 0) are pushed to float32 -inf surrogate.
        return logits.masked_fill(mask == 0, -3.4028234663852886e+38)
class Agent(nn.Module):
    """Wraps the Actor policy and exposes sampling / log-prob / logits helpers."""

    def __init__(self):
        super().__init__()
        self.pos_encoder = PositionalEncoding(8)
        self.actor = Actor(self.pos_encoder)

    def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
                action=None):
        """With `action`: return (logits, log_prob(action), entropy).
        Without: sample a full no-replacement ranking of actions and return
        (sampled actions, log-probabilities, entropy)."""
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        dist = Categorical(logits=logits)
        if action is not None:
            return logits, dist.log_prob(action), dist.entropy()
        p = dist.probs
        # Draws every action once (without replacement) → a ranked action list.
        sampled = torch.multinomial(p, p.shape[1])
        return sampled, torch.log(p), dist.entropy()

    def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        """Sample a single action per batch row from the masked policy."""
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        return Categorical(logits=logits).sample()

    def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        """Return the raw (masked) policy logits."""
        return self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialize `layer.weight` with gain `std`, set the bias
    (if any) to `bias_const`, and return the layer for inline use."""
    nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        nn.init.constant_(layer.bias, bias_const)
    return layer