Commit: Update modeling_minGRULM.py (modeling_minGRULM.py changed, +7 −6 lines)
@@ -129,13 +129,14 @@ class MinGRULMForCausalLM(PreTrainedModel):
|
|
129 |
model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
130 |
return model
|
131 |
|
132 |
-
def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True):
|
133 |
"""
|
134 |
Save the model and configuration to a directory.
|
135 |
|
136 |
Args:
|
137 |
save_directory (str): Directory to save the model.
|
138 |
safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
|
|
|
139 |
"""
|
140 |
import os
|
141 |
os.makedirs(save_directory, exist_ok=True)
|
@@ -144,13 +145,13 @@ class MinGRULMForCausalLM(PreTrainedModel):
|
|
144 |
print("Saving with safe serialization.")
|
145 |
|
146 |
state_dict = {}
|
147 |
-
|
148 |
for name, param in self.model.min_gru_model.named_parameters():
|
149 |
state_dict[f"model.{name}"] = param
|
150 |
-
|
151 |
-
for name, param in self.
|
152 |
-
state_dict[f"
|
153 |
-
|
154 |
state_dict['config'] = self.config.__dict__
|
155 |
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
156 |
|
|
|
129 |
model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
130 |
return model
|
131 |
|
132 |
+
def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
    """
    Save the model weights and configuration to a directory.

    The checkpoint is written as a single ``pytorch_model.bin`` file that
    contains the backbone weights under ``model.*`` keys, the classifier
    weights under ``classifier.*`` keys, and the configuration dict under
    the ``'config'`` key (non-standard: HF normally writes a separate
    ``config.json``, but loaders of this checkpoint expect it inline).

    Args:
        save_directory (str): Directory to save the model into; created if
            it does not exist.
        safe_serialization (bool, optional): Accepted for signature
            compatibility with ``PreTrainedModel.save_pretrained``.
            NOTE(review): the file is always written with ``torch.save``
            (pickle); this flag does not switch to safetensors.
        **kwargs: Additional arguments such as ``max_shard_size``,
            accepted and ignored in this implementation.
    """
    import os
    os.makedirs(save_directory, exist_ok=True)
    if safe_serialization:
        print("Saving with safe serialization.")

    state_dict = {}
    # Use state_dict() rather than named_parameters(): named_parameters()
    # silently drops registered buffers. Detach to CPU so the checkpoint is
    # device-independent and carries no autograd state.
    for name, tensor in self.model.min_gru_model.state_dict().items():
        state_dict[f"model.{name}"] = tensor.detach().cpu()

    for name, tensor in self.classifier.state_dict().items():
        state_dict[f"classifier.{name}"] = tensor.detach().cpu()

    # Embed the config alongside the weights so a single file restores both.
    state_dict['config'] = self.config.__dict__
    torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))