import torch
import torch.nn as nn


class DecoderBase(nn.Module):
    """Abstract base class defining the decoder interface."""

    def __init__(self):
        super().__init__()

    def freeze(self):
        """Freeze all decoder parameters so they receive no gradient updates."""
        for param in self.parameters():
            param.requires_grad = False

    def decode(self, x, z):
        """Compute output logits for input x conditioned on latent codes z.

        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)

        Returns:
            Output logits with shape (batch_size * n_sample, seq_len, vocab_size)
        """
        raise NotImplementedError

    def reconstruct_error(self, x, z):
        """Compute the reconstruction loss.

        Args:
            x: (batch_size, *)
            z: (batch_size, n_sample, nz)

        Returns:
            loss: (batch_size, n_sample). The loss for each sentence
                paired with each of its latent samples z
        """
        raise NotImplementedError

    def beam_search_decode(self, z, K):
        """Beam search decoding conditioned on z.

        Args:
            z: (batch_size, nz)
            K: the beam size

        Returns:
            A list of decoded sentences, each a list of words
        """
        raise NotImplementedError

    def sample_decode(self, z):
        """Decode by sampling from the conditional distribution given z.

        Args:
            z: (batch_size, nz)

        Returns:
            A list of decoded sentences, each a list of words
        """
        raise NotImplementedError

    def greedy_decode(self, z):
        """Greedy decoding conditioned on z.

        Args:
            z: (batch_size, nz)

        Returns:
            A list of decoded sentences, each a list of words
        """
        raise NotImplementedError

    def log_probability(self, x, z):
        """Compute log p(x|z).

        Args:
            x: (batch_size, *)
            z: (batch_size, n_sample, nz)

        Returns:
            log_p: (batch_size, n_sample). log p(x|z) for each sentence
                paired with each of its latent samples z
        """
        raise NotImplementedError
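
# ---------------------------------------------------------------------------
# Illustrative sketch only: one way a concrete subclass might satisfy the
# DecoderBase contract. The LSTM architecture, the constructor arguments
# (vocab_size, embed_size, hidden_size, nz, pad_id), and the choice to
# condition on z by concatenating it to every input embedding are assumptions
# for demonstration, not the repository's actual implementation. The search
# based decoding methods (beam_search_decode, sample_decode, greedy_decode)
# are omitted for brevity.
# ---------------------------------------------------------------------------
import torch.nn.functional as F


class ToyLSTMDecoder(DecoderBase):
    """Minimal LSTM decoder sketch implementing part of the DecoderBase contract."""

    def __init__(self, vocab_size, embed_size, hidden_size, nz, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=pad_id)
        self.lstm = nn.LSTM(embed_size + nz, hidden_size, batch_first=True)
        self.proj = nn.Linear(hidden_size, vocab_size)

    def decode(self, x, z):
        _, n_sample, nz = z.size()
        seq_len = x.size(1)

        # Pair every sentence with each of its n_sample latent codes:
        # x -> (batch_size * n_sample, seq_len), z -> (batch_size * n_sample, nz)
        x = x.unsqueeze(1).expand(-1, n_sample, -1).reshape(-1, seq_len)
        z = z.reshape(-1, nz)

        # Condition on z by appending it to every word embedding.
        word_embed = self.embed(x)
        z_expand = z.unsqueeze(1).expand(-1, seq_len, -1)
        output, _ = self.lstm(torch.cat((word_embed, z_expand), dim=-1))

        # (batch_size * n_sample, seq_len, vocab_size)
        return self.proj(output)

    def reconstruct_error(self, x, z):
        # Assumes x is (batch_size, seq_len); teacher forcing predicts
        # x[:, 1:] from x[:, :-1].
        batch_size, n_sample, _ = z.size()
        logits = self.decode(x[:, :-1], z)
        tgt = x[:, 1:].unsqueeze(1).expand(-1, n_sample, -1).reshape(-1, x.size(1) - 1)

        # Token-level negative log-likelihood, summed over the sequence.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            tgt.reshape(-1),
            ignore_index=self.pad_id,
            reduction="none",
        )
        return loss.reshape(batch_size * n_sample, -1).sum(-1).reshape(batch_size, n_sample)

    def log_probability(self, x, z):
        # Under the definition above, log p(x|z) is the negated
        # reconstruction loss.
        return -self.reconstruct_error(x, z)


if __name__ == "__main__":
    # Smoke test with toy sizes.
    decoder = ToyLSTMDecoder(vocab_size=100, embed_size=8, hidden_size=16, nz=4)
    x = torch.randint(1, 100, (3, 10))            # 3 sentences of length 10
    z = torch.randn(3, 5, 4)                      # 5 latent samples per sentence
    print(decoder.decode(x, z).shape)             # torch.Size([15, 10, 100])
    print(decoder.reconstruct_error(x, z).shape)  # torch.Size([3, 5])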