sino committed
Commit
0b6f771
1 Parent(s): a1edc98

Update modeling_maelm.py

Files changed (1)
  1. modeling_maelm.py +13 -13
modeling_maelm.py CHANGED
@@ -192,9 +192,9 @@ class MAEForCausalLM(PreTrainedModel):
         if bk_name == 'MAEViT':
             ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
             self.backbone = MAEViT(**backbone)
-            if ckpt_path is not None:
-                ckpt = torch.load( ckpt_path,'cpu')
-                self.backbone.load_state_dict(ckpt['state_dict'])
+            #if ckpt_path is not None:
+            #    ckpt = torch.load( ckpt_path,'cpu')
+            #    self.backbone.load_state_dict(ckpt['state_dict'])

         elif bk_name == 'HTSAT':
             ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
@@ -239,16 +239,16 @@ class MAEForCausalLM(PreTrainedModel):
         # float32 --> bfloat16
         for p in self.parameters():
             p.data = p.data.to(torch.bfloat16)
-        if config.resume_from_checkpoint is not None:
-            drain_loader = True
-            accelerator.load_state(config.resume_from_checkpoint, load_module_strict=False)
-            # start_epoch, start_step, all_step = [int(_.split('_')[1]) for _ in args.resume_from_checkpoint.split('/')[-2].split('-')]
-        elif config.resume_from_pth is not None:
-            print(f'###########loading##########{config.resume_from_pth}###########loading##########')
-            ckpt = torch.load(config.resume_from_pth, map_location='cpu')
-            ckpt_copy = {k[7:]: v for k, v in ckpt.items()}
-            self.load_state_dict(ckpt_copy, strict=False)
-            print(f'###########loaded##########{config.resume_from_pth}###########loaded##########')
+        #if config.resume_from_checkpoint is not None:
+        #    drain_loader = True
+        #    accelerator.load_state(config.resume_from_checkpoint, load_module_strict=False)
+        #    # start_epoch, start_step, all_step = [int(_.split('_')[1]) for _ in args.resume_from_checkpoint.split('/')[-2].split('-')]
+        #elif config.resume_from_pth is not None:
+        #    print(f'###########loading##########{config.resume_from_pth}###########loading##########')
+        #    ckpt = torch.load(config.resume_from_pth, map_location='cpu')
+        #    ckpt_copy = {k[7:]: v for k, v in ckpt.items()}
+        #    self.load_state_dict(ckpt_copy, strict=False)
+        #    print(f'###########loaded##########{config.resume_from_pth}###########loaded##########')

         if False:
             self.patch_llm()
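
For reference, the now-disabled MAEViT branch loaded pretrained backbone weights from a checkpoint dict keyed by 'state_dict'; note that torch.load's second positional argument is map_location, so torch.load(ckpt_path, 'cpu') is the same as torch.load(ckpt_path, map_location='cpu'). A minimal sketch of that pattern, with a hypothetical toy module standing in for MAEViT:

import torch
import torch.nn as nn

# Hypothetical stand-in for the MAEViT backbone.
backbone = nn.Linear(8, 8)

# Save in the layout the disabled branch expects: weights under 'state_dict'.
torch.save({'state_dict': backbone.state_dict()}, 'mae_backbone.pth')

# The second positional argument of torch.load is map_location, so 'cpu'
# keeps the loaded tensors off the GPU until the model is moved explicitly.
ckpt = torch.load('mae_backbone.pth', map_location='cpu')
backbone.load_state_dict(ckpt['state_dict'])

With this path commented out, the backbone presumably receives its weights through the full-model checkpoint instead of a separate backbone file.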
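Similarly, the disabled resume_from_pth branch remaps checkpoint keys with k[7:], which drops the 7-character 'module.' prefix that torch.nn.DataParallel (and DistributedDataParallel) wrappers prepend to every state-dict key when saving. A minimal sketch of that remapping on a hypothetical toy module, using str.removeprefix as a safer equivalent of the slice:

import torch
import torch.nn as nn

# Hypothetical stand-in for MAEForCausalLM.
model = nn.Linear(4, 2)

# A DataParallel wrapper prefixes every state-dict key with 'module.',
# e.g. 'module.weight' and 'module.bias'.
wrapped = nn.DataParallel(model)
torch.save(wrapped.state_dict(), 'full_model.pth')

ckpt = torch.load('full_model.pth', map_location='cpu')
# Equivalent to the diff's {k[7:]: v ...}: strip the 'module.' prefix so
# the keys match the unwrapped module; strict=False tolerates missing or
# extra keys, as in the commented-out code.
ckpt_copy = {k.removeprefix('module.'): v for k, v in ckpt.items()}
model.load_state_dict(ckpt_copy, strict=False)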