sino
commited on
Commit
·
0b6f771
1
Parent(s):
a1edc98
Update modeling_maelm.py
Browse files- modeling_maelm.py +13 -13
modeling_maelm.py
CHANGED
@@ -192,9 +192,9 @@ class MAEForCausalLM(PreTrainedModel):
|
|
192 |
if bk_name == 'MAEViT':
|
193 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
194 |
self.backbone = MAEViT(**backbone)
|
195 |
-
if ckpt_path is not None:
|
196 |
-
|
197 |
-
|
198 |
|
199 |
elif bk_name == 'HTSAT':
|
200 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
@@ -239,16 +239,16 @@ class MAEForCausalLM(PreTrainedModel):
|
|
239 |
# float32 --> bfloat16
|
240 |
for p in self.parameters():
|
241 |
p.data = p.data.to(torch.bfloat16)
|
242 |
-
if config.resume_from_checkpoint is not None:
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
elif config.resume_from_pth is not None:
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
|
253 |
if False:
|
254 |
self.patch_llm()
|
|
|
192 |
if bk_name == 'MAEViT':
|
193 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
194 |
self.backbone = MAEViT(**backbone)
|
195 |
+
#if ckpt_path is not None:
|
196 |
+
# ckpt = torch.load( ckpt_path,'cpu')
|
197 |
+
# self.backbone.load_state_dict(ckpt['state_dict'])
|
198 |
|
199 |
elif bk_name == 'HTSAT':
|
200 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
|
|
239 |
# float32 --> bfloat16
|
240 |
for p in self.parameters():
|
241 |
p.data = p.data.to(torch.bfloat16)
|
242 |
+
#if config.resume_from_checkpoint is not None:
|
243 |
+
# drain_loader = True
|
244 |
+
# accelerator.load_state(config.resume_from_checkpoint, load_module_strict=False)
|
245 |
+
# # start_epoch, start_step, all_step = [int(_.split('_')[1]) for _ in args.resume_from_checkpoint.split('/')[-2].split('-')]
|
246 |
+
#elif config.resume_from_pth is not None:
|
247 |
+
# print(f'###########loading##########{config.resume_from_pth}###########loading##########')
|
248 |
+
# ckpt = torch.load(config.resume_from_pth, map_location='cpu')
|
249 |
+
# ckpt_copy = {k[7:]: v for k, v in ckpt.items()}
|
250 |
+
# self.load_state_dict(ckpt_copy, strict=False)
|
251 |
+
# print(f'###########loaded##########{config.resume_from_pth}###########loaded##########')
|
252 |
|
253 |
if False:
|
254 |
self.patch_llm()
|