File size: 2,388 Bytes
0d7e8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

import torch


class PromptAlignment(ABC):
    """Interface for laying out a batch of variable-length prompts.

    A concrete alignment decides how token-id lists of different lengths are
    padded into one rectangular tensor, which position ``start_index``
    reports for the batch, and how ``postprocess_inputs`` may adjust a tensor
    using the originally prepared one.
    """

    @abstractmethod
    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return the batch-wide start position derived from ``input_ids``."""

    @abstractmethod
    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        """Pad the prompts in ``input_ids`` into a single tensor."""

    @abstractmethod
    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``inputs``, possibly adjusted using ``original_inputs``."""


class AlignPromptRight(PromptAlignment):
    """Right-align prompts: shorter prompts are padded on the left.

    Every prompt ends in the same column, so ``start_index`` is the length of
    the longest prompt and ``postprocess_inputs`` leaves the tensor unchanged.
    """

    def __init__(self, pad_id: int):
        # Token id written into the left-hand padding positions.
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return the length of the longest prompt in ``input_ids``.

        Raises:
            ValueError: if ``input_ids`` is empty (propagated from ``max``).
        """
        return max(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        """Left-pad each prompt with ``pad_id`` up to the longest prompt.

        Args:
            input_ids: one list of token ids per batch element.

        Returns:
            A ``(len(input_ids), max_length)`` tensor of dtype ``torch.long``.

        Raises:
            ValueError: if ``input_ids`` is empty (propagated from ``max``).
        """
        max_length = max(len(sublist) for sublist in input_ids)
        return torch.tensor(
            [
                [self.pad_id] * (max_length - len(sublist)) + sublist
                for sublist in input_ids
            ],
            # Explicit dtype: if every prompt is empty the nested list holds
            # no ints, and torch would otherwise infer float32.
            dtype=torch.long,
            requires_grad=False,
        )

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``inputs`` unchanged; ``original_inputs`` is not used."""
        return inputs


class AlignPromptLeft(PromptAlignment):
    """Left-align prompts: shorter prompts are padded on the right.

    Every prompt starts in column 0, so ``start_index`` is the length of the
    shortest prompt. Longer prompts still have real prompt tokens past that
    column; ``postprocess_inputs`` copies those tokens back into a tensor
    using the padded tensor built by ``prepare_inputs``.
    """

    def __init__(self, pad_id: int = -1):
        # Token id used for right-hand padding. It also marks which cells of
        # the prepared tensor are NOT prompt tokens in postprocess_inputs.
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return the length of the shortest prompt in ``input_ids``.

        Raises:
            ValueError: if ``input_ids`` is empty (propagated from ``min``).
        """
        return min(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        """Right-pad each prompt with ``pad_id`` up to the longest prompt.

        Args:
            input_ids: one list of token ids per batch element.

        Returns:
            A ``(len(input_ids), max_length)`` tensor of dtype ``torch.long``.

        Raises:
            ValueError: if ``input_ids`` is empty (propagated from ``max``).
        """
        max_length = max(len(sublist) for sublist in input_ids)
        return torch.tensor(
            [
                sublist + [self.pad_id] * (max_length - len(sublist))
                for sublist in input_ids
            ],
            # Explicit dtype: if every prompt is empty the nested list holds
            # no ints, and torch would otherwise infer float32.
            dtype=torch.long,
            requires_grad=False,
        )

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Copy non-padding cells of ``original_inputs`` back into ``inputs``.

        ``inputs`` is modified IN PLACE and also returned. Only the first
        ``inputs.shape[1]`` columns of ``original_inputs`` are consulted. Once
        ``inputs`` is wider than ``original_inputs`` this is a no-op.
        NOTE(review): that relies on earlier calls having already restored
        every prompt column — confirm with the caller. Boolean-mask
        assignment also assumes both tensors share a device and batch size.

        Args:
            inputs: tensor to patch, presumably the growing token sequence.
            original_inputs: the padded tensor produced by ``prepare_inputs``.

        Returns:
            ``inputs`` (the same object), with prompt tokens restored.
        """
        max_init_len = original_inputs.shape[1]
        current_len = inputs.shape[1]
        if current_len > max_init_len:
            # No columns of original_inputs overlap the new positions.
            return inputs

        original_inputs_limited = original_inputs[:, :current_len]
        # True where the prepared tensor holds a real prompt token.
        mask = original_inputs_limited != self.pad_id
        inputs[mask] = original_inputs_limited[mask]
        return inputs