File size: 1,637 Bytes
c5ca37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import math
import torch
import torch.nn as nn

from ..utils import log_sum_exp

class EncoderBase(nn.Module):
    """docstring for EncoderBase"""
    def __init__(self):
        super(EncoderBase, self).__init__()

    def forward(self, x):
        """
        Args:
            x: (batch_size, *)
        Returns: the tensors required to parameterize a distribution.
        E.g. for Gaussian encoder it returns the mean and variance tensors
        """

        raise NotImplementedError

    def sample(self, input, nsamples):
        """sampling from the encoder
        Returns: Tensor1
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
        """

        raise NotImplementedError

    def encode(self, input, nsamples):
        """perform the encoding and compute the KL term
        Returns: Tensor1, Tensor2
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tenor of KL for each x with shape [batch]
        """

        raise NotImplementedError


    def eval_inference_dist(self, x, z, param=None):
        """this function computes log q(z | x)
        Args:
            z: tensor
                different z points that will be evaluated, with
                shape [batch, nsamples, nz]
        Returns: Tensor1
            Tensor1: log q(z|x) with shape [batch, nsamples]
        """

        raise NotImplementedError

    def calc_mi(self, x):
        """Approximate the mutual information between x and z
        I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))
        Returns: Float
        """

        raise NotImplementedError