Spaces:
Build error
Build error
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import logging

import torch
import torch.nn as nn
logger = logging.getLogger('global')


class Features(nn.Module):
    """Abstract base for feature-extractor backbones.

    Subclasses must implement ``forward``. ``feature_size`` is initialised to
    -1 here; presumably concrete subclasses overwrite it (not shown in this
    module).
    """

    def __init__(self):
        super(Features, self).__init__()
        self.feature_size = -1

    def forward(self, x):
        """Compute features for ``x``; must be overridden by subclasses."""
        raise NotImplementedError

    def param_groups(self, start_lr, feature_mult=1):
        """Build optimizer parameter groups for this module.

        Args:
            start_lr: base learning rate.
            feature_mult: multiplier applied to ``start_lr`` for this module.

        Returns:
            A one-element list ``[{'params': [...], 'lr': start_lr * feature_mult}]``
            whose ``'params'`` holds only parameters with ``requires_grad=True``.
        """
        # A list rather than a bare ``filter`` iterator: the iterator could be
        # consumed only once, so inspecting the group before handing it to an
        # optimizer would silently leave the optimizer with no parameters.
        params = [p for p in self.parameters() if p.requires_grad]
        return [{'params': params, 'lr': start_lr * feature_mult}]

    def load_model(self, f='pretrain.model'):
        """Copy matching weights from a saved state dict into this module.

        Checkpoint entries whose key is not in ``self.state_dict()`` are
        ignored. Entries of this module that are missing from the checkpoint
        keep their current values.

        Args:
            f: path to a checkpoint, assumed to be a ``torch.save``-d state dict
               (TODO confirm for the pretrained files actually used).
        """
        # Pass the path straight to torch.load: it opens the file in binary
        # mode itself. The previous text-mode ``open(f)`` broke unpickling.
        pretrained_dict = torch.load(f)
        model_dict = self.state_dict()
        logger.info('checkpoint keys: %s', list(pretrained_dict.keys()))
        # Keep only keys this module has, so extra entries in the checkpoint
        # do not make load_state_dict fail on unexpected keys.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        logger.info('loaded keys: %s', list(pretrained_dict.keys()))
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
class MultiStageFeature(Features):
    """Feature extractor whose set of trainable layers is changed in stages.

    ``train_layers()`` returns the first ``train_num`` entries of ``layers``.
    ``change_point`` and ``train_nums`` are parallel lists consumed by
    ``unfix``: when the progress ``ratio`` reaches ``change_point[i]``,
    ``train_nums[i]`` layers become trainable. All three lists start empty
    here and are presumably filled by subclasses (TODO confirm).
    """

    def __init__(self):
        super(MultiStageFeature, self).__init__()
        self.layers = []
        # -1 marks "unfix() not called yet"; the first unfix() sets it to 0.
        self.train_num = -1
        self.change_point = []
        self.train_nums = []

    def unfix(self, ratio=0.0):
        """Update which layers are trainable for training progress ``ratio``.

        On the first call, freeze every parameter (``train_num = 0``) and
        switch the module to eval mode. Then scan ``change_point`` from the
        end and take the first entry with ``ratio >= point``. If its layer
        count differs from the current ``train_num``, unlock that many layers.

        Returns:
            True if a stage change was applied, False otherwise. The initial
            freeze on the first call does not by itself count as a change.
        """
        if self.train_num == -1:
            self.train_num = 0
            self.unlock()
            self.eval()
        for point, num in reversed(list(zip(self.change_point, self.train_nums))):
            if ratio < point:
                continue
            if self.train_num == num:
                # Already at this stage; stop rather than falling through to an
                # earlier stage further down the reversed list.
                return False
            self.train_num = num
            self.unlock()
            return True
        return False

    def train_layers(self):
        """Return the first ``train_num`` entries of ``layers``.

        NOTE(review): before the first ``unfix()`` call ``train_num`` is -1,
        so this slice returns every layer except the last one.
        """
        return self.layers[:self.train_num]

    def unlock(self):
        """Freeze all parameters, then re-enable grads for ``train_layers()``."""
        for p in self.parameters():
            p.requires_grad = False

        # The original format string had one placeholder for two arguments,
        # which silently dropped the layer listing.
        logger.info('Current training %s layers:\n\t%s', self.train_num, self.train_layers())
        for m in self.train_layers():
            for p in m.parameters():
                p.requires_grad = True

    def train(self, mode=True):
        """Set training mode, moving only ``train_layers()`` into training.

        ``mode`` defaults to True, matching ``nn.Module.train``, so a plain
        ``model.train()`` works.

        A falsy ``mode`` delegates to ``nn.Module.train(False)``, which puts
        every submodule in eval mode. A truthy ``mode`` sets ``self.training``
        and calls ``train(True)`` only on the layers from ``train_layers()``.
        Other submodules keep their current mode, presumably so frozen layers
        (e.g. batch-norm) stay in eval mode.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        self.training = mode
        if not mode:
            super(MultiStageFeature, self).train(False)
        else:
            for m in self.train_layers():
                m.train(True)
        return self