|  | import os | 
					
						
						|  | import yaml | 
					
						
						|  | import torch | 
					
						
						|  | from transformers import AlbertConfig, AlbertModel | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CustomAlbert(AlbertModel): | 
					
						
						|  | def forward(self, *args, **kwargs): | 
					
						
						|  |  | 
					
						
						|  | outputs = super().forward(*args, **kwargs) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return outputs.last_hidden_state | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_plbert(log_dir): | 
					
						
						|  | config_path = os.path.join(log_dir, "config.yml") | 
					
						
						|  | plbert_config = yaml.safe_load(open(config_path)) | 
					
						
						|  |  | 
					
						
						|  | albert_base_configuration = AlbertConfig(**plbert_config["model_params"]) | 
					
						
						|  | bert = CustomAlbert(albert_base_configuration) | 
					
						
						|  |  | 
					
						
						|  | files = os.listdir(log_dir) | 
					
						
						|  | ckpts = [] | 
					
						
						|  | for f in os.listdir(log_dir): | 
					
						
						|  | if f.startswith("step_"): | 
					
						
						|  | ckpts.append(f) | 
					
						
						|  |  | 
					
						
						|  | iters = [ | 
					
						
						|  | int(f.split("_")[-1].split(".")[0]) | 
					
						
						|  | for f in ckpts | 
					
						
						|  | if os.path.isfile(os.path.join(log_dir, f)) | 
					
						
						|  | ] | 
					
						
						|  | iters = sorted(iters)[-1] | 
					
						
						|  |  | 
					
						
						|  | checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location="cpu") | 
					
						
						|  | state_dict = checkpoint["net"] | 
					
						
						|  | from collections import OrderedDict | 
					
						
						|  |  | 
					
						
						|  | new_state_dict = OrderedDict() | 
					
						
						|  | for k, v in state_dict.items(): | 
					
						
						|  | name = k[7:] | 
					
						
						|  | if name.startswith("encoder."): | 
					
						
						|  | name = name[8:] | 
					
						
						|  | new_state_dict[name] = v | 
					
						
						|  | del new_state_dict["embeddings.position_ids"] | 
					
						
						|  | bert.load_state_dict(new_state_dict, strict=False) | 
					
						
						|  |  | 
					
						
						|  | return bert | 
					
						
						|  |  |