import torch

from src.drl.ac_models import SoftActor
from src.drl.nets import esmb_sample


class PMOESoftActor(SoftActor):
    """Soft actor with a probabilistic mixture-of-experts (PMOE) policy head:
    the network produces per-component Gaussian parameters and mixing weights."""

    def __init__(self, net_constructor, tar_ent=None):
        super().__init__(net_constructor, tar_ent)

    def forward(self, obs, grad=True, mono=True):
        # Sample an action and its log-probability from the mixture policy.
        if grad:
            return self.net(obs, mono)
        with torch.no_grad():
            return self.net(obs, mono)

    def backward_policy(self, critic, obs):
        # Per-component Gaussian parameters and mixing weights.
        muss, stdss, betas = self.net.get_intermediate(obs)
        # Draw one action (and its log-prob) per mixture component.
        actss, logpss, _ = esmb_sample(muss, stdss, betas, False)
        # Broadcast each observation across the component dimension so the
        # critic can score every component's action.
        obss = obs.unsqueeze(1).expand(-1, actss.shape[1], -1)
        qvaluess = critic.forward(obss, actss)
        # SAC-style policy loss, summed over components:
        # L_pri = E[alpha * log pi(a|s) - Q(s, a)]
        # (alpha_coe applies the entropy temperature without tracking its grad).
        l_pri = (self.alpha_coe(logpss, grad=False) - qvaluess).sum(dim=-1).mean()
        # Regularizer pulling the mixing weights toward a one-hot vector on the
        # component with the highest Q-value:
        # L_frep = ||onehot(argmax_k Q_k) - beta||^2
        t = qvaluess - torch.max(qvaluess, dim=-1, keepdim=True).values
        v = torch.where(t == 0., 1., 0.) - betas
        l_frep = (v * v).sum(-1).mean()
        (l_pri + l_frep).backward()

    def backward_alpha(self, obs):
        # Temperature loss: adjust alpha so the policy entropy tracks the
        # target entropy `tar_ent`.
        _, logps = self.forward(obs, grad=False)
        loss_alpha = -(self.alpha_coe(logps + self.tar_ent)).mean()
        loss_alpha.backward()
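

# A minimal usage sketch (hypothetical names: `make_pmoe_net`, `act_dim`,
# `critic`, and `obs` are assumptions, not part of this repo). The real
# `net_constructor` must build a network whose forward(obs, mono) returns
# (action, log_prob) and which exposes
# get_intermediate(obs) -> (mus, stds, betas):
#
#     actor = PMOESoftActor(make_pmoe_net, tar_ent=-act_dim)
#     acts, logps = actor.forward(obs, grad=False)  # sample without gradients
#     actor.backward_policy(critic, obs)            # accumulate policy grads
#     actor.backward_alpha(obs)                     # accumulate temperature grads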