Spaces:
Sleeping
Sleeping
File size: 4,653 Bytes
19dfa7a 93d0d1a 19dfa7a b93c8a7 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a b93c8a7 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a b93c8a7 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a b93c8a7 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a fda141d 93d0d1a b93c8a7 19dfa7a b93c8a7 fda141d b93c8a7 fda141d b93c8a7 fda141d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from abc import ABC, abstractmethod
import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal
class MammalObjectBroker:
    """Holder that lazily creates a Mammal model and its tokenizer op.

    Both objects are built from ``model_path`` the first time they are
    requested and are cached on the instance afterwards.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
        *,
        force_preload=False,
    ) -> None:
        """Record the model location and, optionally, load everything now.

        Args:
            model_path: value handed to ``from_pretrained`` for both the
                model and the tokenizer op.
            name: label for this holder; falls back to ``model_path``
                when omitted.
            task_list: strings stored in ``self.tasks`` (presumably the names
                of the tasks this model supports - confirm with callers).
                An empty list is used when omitted.
            force_preload: when truthy, load the model and tokenizer op
                immediately instead of on first access.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # The caller's list object is stored as-is; no copy is made.
        self.tasks: list[str] = task_list if task_list is not None else []
        # Caches filled in by the `model` / `tokenizer_op` properties.
        self._model: Mammal | None = None
        self._tokenizer_op = None
        if force_preload:
            self.force_preload()

    @property
    def model(self) -> Mammal:
        """The Mammal model, loaded on first access and put in eval mode."""
        if self._model is not None:
            return self._model
        self._model = Mammal.from_pretrained(self.model_path)
        self._model.eval()
        return self._model

    @property
    def tokenizer_op(self):
        """The ModularTokenizerOp for ``model_path``, loaded on first access."""
        if self._tokenizer_op is not None:
            return self._tokenizer_op
        self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op

    def force_preload(self):
        """Load the model first and the tokenizer op second, caching both."""
        # Touching each property is enough to trigger (and cache) the load.
        _ = self.model
        _ = self.tokenizer_op
class MammalTask(ABC):
    """Abstract base for a task that can be run with registered models.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``.
    ``run_model`` and ``create_demo`` raise ``NotImplementedError`` unless
    overridden.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """Store the task name and the mapping of model brokers.

        Args:
            name: identifier of the task.
            model_dict: mapping from a string key to a ``MammalObjectBroker``.
        """
        self.name = name
        self.description = None
        # Cache for the object returned by `create_demo` (see `demo`).
        self._demo = None
        self.model_dict = model_dict

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Turn raw inputs into the sample dict that is fed to the model.

        The prompt is expected to be formatted to match the pre-training
        syntax.

        Args:
            sample_inputs: the raw inputs for one sample.
            model_holder: broker of the model the sample is prepared for.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    # The @abstractmethod decorator is disabled for this method, so a
    # subclass can be instantiated without overriding it.
    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; must be overridden to be usable.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def create_demo(self, model_name_widget: gr.component) -> gr.Group:
        """Build the gradio demo group for this task.

        Args:
            model_name_widget: widget holding the name of the model to use.
                It is needed so gradio actions can take the current model
                name as an input.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.component = None):
        """Return the task's demo, building it on the first call only.

        Args:
            model_name_widgit: forwarded to ``create_demo`` as
                ``model_name_widget`` when the demo has not been built yet;
                ignored on later calls.
        """
        if self._demo is not None:
            return self._demo
        self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Convert the model output held in ``batch_dict`` into a list."""
        raise NotImplementedError()

    # classification helpers

    @staticmethod
    def positive_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Id of the token used for a positive result.

        Args:
            tokenizer_op: tokenizer op used to look up the token.

        Returns:
            int: id of the ``"<1>"`` token.
        """
        return tokenizer_op.get_token_id("<1>")

    @staticmethod
    def negative_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Id of the token used for a negative result.

        Args:
            tokenizer_op: tokenizer op used to look up the token.

        Returns:
            int: id of the ``"<0>"`` token.
        """
        return tokenizer_op.get_token_id("<0>")

    @staticmethod
    def get_label_from_token(tokenizer_op: ModularTokenizerOp, token_id):
        """Map a token id to ``"negative"`` / ``"positive"``.

        Args:
            tokenizer_op: tokenizer op used to look up the two label tokens.
            token_id: id to translate.

        Returns:
            The matching label string, or ``token_id`` itself when it is
            neither of the two label tokens.
        """
        negative_id = MammalTask.negative_token_id(tokenizer_op)
        positive_id = MammalTask.positive_token_id(tokenizer_op)
        labels = {negative_id: "negative", positive_id: "positive"}
        return labels.get(token_id, token_id)
class TaskRegistry(dict[str, MammalTask]):
    """A plain dictionary of tasks, keyed by task name, with a register helper."""

    def register_task(self, task: MammalTask):
        """Store ``task`` under ``task.name`` and return that name.

        An entry already stored under the same name is replaced.
        """
        task_name = task.name
        self[task_name] = task
        return task_name
class ModelRegistry(dict[str, MammalObjectBroker]):
    """A plain dictionary of model brokers, keyed by model name, with a register helper."""

    def register_model(
        self, model_path, task_list=None, name=None, *, force_preload=False
    ):
        """Create a broker for a model, store it, and return its name.

        Args:
            model_path: passed to ``MammalObjectBroker`` as ``model_path``.
            task_list: optional list passed to the broker as ``task_list``.
            name: optional explicit name for the model; the broker falls
                back to ``model_path`` when it is omitted.
            force_preload: passed to the broker; when truthy the broker
                loads its model and tokenizer op right away.

        Returns:
            str: the name under which the broker was stored. An entry
            already stored under that name is replaced.
        """
        broker = MammalObjectBroker(
            model_path=model_path,
            name=name,
            task_list=task_list,
            force_preload=force_preload,
        )
        registered_name = broker.name
        self[registered_name] = broker
        return registered_name
|