suayptalha commited on
Commit
40f8b85
·
verified ·
1 Parent(s): 6a0b190

Update modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +7 -6
modeling_minGRULM.py CHANGED
@@ -129,13 +129,14 @@ class MinGRULMForCausalLM(PreTrainedModel):
129
  model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
130
  return model
131
 
132
- def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True):
133
  """
134
  Save the model and configuration to a directory.
135
 
136
  Args:
137
  save_directory (str): Directory to save the model.
138
  safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
 
139
  """
140
  import os
141
  os.makedirs(save_directory, exist_ok=True)
@@ -144,13 +145,13 @@ class MinGRULMForCausalLM(PreTrainedModel):
144
  print("Saving with safe serialization.")
145
 
146
  state_dict = {}
147
-
148
  for name, param in self.model.min_gru_model.named_parameters():
149
  state_dict[f"model.{name}"] = param
150
-
151
- for name, param in self.lm_head.named_parameters():
152
- state_dict[f"lm_head.{name}"] = param
153
-
154
  state_dict['config'] = self.config.__dict__
155
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
156
 
 
129
  model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
130
  return model
131
 
132
+ def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
133
  """
134
  Save the model and configuration to a directory.
135
 
136
  Args:
137
  save_directory (str): Directory to save the model.
138
  safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
139
+ kwargs: Additional arguments like max_shard_size (ignored in this implementation).
140
  """
141
  import os
142
  os.makedirs(save_directory, exist_ok=True)
 
145
  print("Saving with safe serialization.")
146
 
147
  state_dict = {}
148
+
149
  for name, param in self.model.min_gru_model.named_parameters():
150
  state_dict[f"model.{name}"] = param
151
+
152
+ for name, param in self.classifier.named_parameters():
153
+ state_dict[f"classifier.{name}"] = param
154
+
155
  state_dict['config'] = self.config.__dict__
156
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
157