File size: 4,653 Bytes
19dfa7a
 
93d0d1a
 
 
 
 
19dfa7a
 
 
 
 
 
b93c8a7
 
19dfa7a
93d0d1a
 
 
19dfa7a
 
 
93d0d1a
19dfa7a
 
93d0d1a
b93c8a7
 
19dfa7a
93d0d1a
19dfa7a
93d0d1a
19dfa7a
b93c8a7
93d0d1a
19dfa7a
93d0d1a
 
 
19dfa7a
93d0d1a
19dfa7a
b93c8a7
 
 
 
 
93d0d1a
 
19dfa7a
 
 
 
 
93d0d1a
 
19dfa7a
 
 
93d0d1a
 
 
 
 
 
 
 
 
 
 
19dfa7a
93d0d1a
19dfa7a
93d0d1a
 
 
 
 
 
 
 
 
 
 
 
 
19dfa7a
93d0d1a
 
 
 
 
fda141d
93d0d1a
 
b93c8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dfa7a
b93c8a7
 
 
 
 
fda141d
 
 
 
 
 
 
 
 
 
 
 
 
b93c8a7
 
 
fda141d
 
 
 
 
 
 
 
 
b93c8a7
 
 
 
fda141d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from abc import ABC, abstractmethod

import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal


class MammalObjectBroker:
    """Lazily loads and caches a Mammal model and its tokenizer from one path.

    The model and tokenizer are loaded on first access through the ``model``
    and ``tokenizer_op`` properties. If ``force_preload`` is set, both are
    loaded during construction instead.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
        *,
        force_preload: bool = False,
    ) -> None:
        """
        Args:
            model_path: path/identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
            name: name for this broker; defaults to ``model_path``.
            task_list: task names associated with this model. The list is
                copied, so later changes to the caller's list do not leak in.
            force_preload: if True, load the model and tokenizer immediately
                instead of on first access.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name

        # Copy rather than alias the caller's list.
        self.tasks: list[str] = [] if task_list is None else list(task_list)
        self._model: Mammal | None = None
        self._tokenizer_op: ModularTokenizerOp | None = None
        if force_preload:
            self.force_preload()

    @property
    def model(self) -> Mammal:
        """The Mammal model, loaded from ``model_path`` on first access.

        ``eval()`` is called on the model before it is cached.
        """
        if self._model is None:
            loaded = Mammal.from_pretrained(self.model_path)
            loaded.eval()
            # Cache only after eval() succeeded. Otherwise a failure would leave
            # a model that never went through eval() cached for later calls.
            self._model = loaded
        return self._model

    @property
    def tokenizer_op(self) -> ModularTokenizerOp:
        """The tokenizer op, loaded from ``model_path`` on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op

    def force_preload(self) -> None:
        """Load the model and then the tokenizer now, instead of lazily."""
        _ = self.model
        _ = self.tokenizer_op


class MammalTask(ABC):
    """Base class for a task served through the Mammal gradio demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``.
    They typically also override ``run_model`` and ``create_demo``, whose base
    versions raise ``NotImplementedError``.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """
        Args:
            name: task name (``TaskRegistry.register_task`` uses it as the key).
            model_dict: mapping from model name to the broker that loads it.
        """
        self.name = name
        self.description = None
        self._demo = None  # cached result of create_demo(), filled in by demo()
        self.model_dict = model_dict

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Format the task inputs to match the pre-training prompt syntax.

        NOTE: "crate" is a misspelling of "create". The name is kept because
        subclasses override it by this name.

        Args:
            sample_inputs: raw task inputs (the schema is defined by each subclass).
            model_holder: broker giving access to the model/tokenizer to format for.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    # Intentionally not abstract: subclasses only need it if they are run.
    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on a prepared ``sample_dict``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    # NOTE(review): ``gr.component`` may not be gradio's Component class
    # (presumably ``gr.components.Component`` was meant); confirm before changing.
    def create_demo(self, model_name_widget: gr.component) -> gr.Group:
        """Create a gradio demo group for this task.

        Args:
            model_name_widget: widget holding the name of the model to use. It is
                needed so gradio actions can take the current model name as an
                input.

        Returns:
            gr.Group: the demo group.

        Raises:
            NotImplementedError: always, in the base class; subclasses override it.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.component = None):
        """Return this task's demo group, creating it on the first call.

        The result of ``create_demo`` is cached. ``model_name_widgit`` is
        therefore only used on the first call; later calls return the cached
        group.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Decode ``batch_dict`` into a list of outputs (format defined by subclasses)."""
        raise NotImplementedError()

    # classification helpers
    @staticmethod
    def positive_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Token id of the positive-binding label token ``<1>``.

        Args:
            tokenizer_op: tokenizer used to look up the token id.

        Returns:
            int: id of the positive binding token.
        """
        return tokenizer_op.get_token_id("<1>")

    @staticmethod
    def negative_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Token id of the negative-binding label token ``<0>``.

        Args:
            tokenizer_op: tokenizer used to look up the token id.

        Returns:
            int: id of the negative binding token.
        """
        return tokenizer_op.get_token_id("<0>")

    @staticmethod
    def get_label_from_token(tokenizer_op: ModularTokenizerOp, token_id):
        """Translate a class token id into ``"negative"`` / ``"positive"``.

        Args:
            tokenizer_op: tokenizer used to resolve the ``<0>``/``<1>`` token ids.
            token_id: token id to translate.

        Returns:
            ``"negative"`` for the ``<0>`` id, ``"positive"`` for the ``<1>`` id,
            otherwise ``token_id`` unchanged.
        """
        # NOTE(review): this is a dict lookup, so ``token_id`` presumably must be
        # a plain hashable id (e.g. an int) to match. Confirm with callers.
        label_mapping = {
            MammalTask.negative_token_id(tokenizer_op): "negative",
            MammalTask.positive_token_id(tokenizer_op): "positive",
        }
        return label_mapping.get(token_id, token_id)


class TaskRegistry(dict[str, MammalTask]):
    """Tasks keyed by task name, plus a helper to register one."""

    def register_task(self, task: MammalTask):
        """Store ``task`` under ``task.name``, replacing any existing entry.

        Returns:
            the name the task was stored under.
        """
        key = task.name
        self[key] = task
        return key


class ModelRegistry(dict[str, MammalObjectBroker]):
    """Model brokers keyed by model name, plus a helper to register one."""

    def register_model(
        self, model_path, task_list=None, name=None, *, force_preload=False
    ):
        """Create a ``MammalObjectBroker``, store it in this registry, and return its name.

        Args:
            model_path: path/identifier the broker loads the model and tokenizer from.
            task_list (optional list[str]): task names to attach to the broker.
            name (optional str): explicit name for the model. When omitted, the
                broker uses ``model_path``.
            force_preload (bool): if True, the broker loads the model and
                tokenizer immediately.

        Returns:
            str: model name, which is also its key in this registry.
        """
        broker_kwargs = dict(
            model_path=model_path,
            task_list=task_list,
            name=name,
            force_preload=force_preload,
        )
        broker = MammalObjectBroker(**broker_kwargs)
        registered_name = broker.name
        self[registered_name] = broker
        return registered_name