# Deep learning
import torch

# Data
from pubchem_encoder import Encoder
from datasets import load_dataset, concatenate_datasets

# Standard library
import os
import getpass
import glob

class MoleculeModule:
    def __init__(self, max_len, dataset, data_path):
        super().__init__()
        self.dataset = dataset
        self.data_path = data_path
        self.text_encoder = Encoder(max_len)

    def prepare_data(self):
        pass

    def get_vocab(self):
        # using a home-made tokenizer; should look into existing tokenizers
        return self.text_encoder.char2id

    def get_cache(self):
        return self.cache_files
    def setup(self, stage=None):
        # using the huggingface dataloader;
        # create the cache in the tmp directory of the local machine under the
        # current user's name to prevent locking issues
        pubchem_path = {'train': self.data_path}
        if 'canonical' in pubchem_path['train'].lower():
            pubchem_script = './pubchem_canon_script.py'
        else:
            pubchem_script = './pubchem_script.py'
        zinc_path = './data/ZINC'
        global dataset_dict
        if 'zinc' in self.dataset.lower():
            zinc_files = glob.glob(os.path.join(zinc_path, '*.smi'))
            for zfile in zinc_files:
                print(zfile)
            self.dataset = {'train': zinc_files}
            dataset_dict = load_dataset('./zinc_script.py',
                                        data_files=self.dataset,
                                        cache_dir=os.path.join('/tmp', getpass.getuser(), 'zinc'),
                                        split='train',
                                        trust_remote_code=True)
        elif 'pubchem' in self.dataset:
            dataset_dict = load_dataset(pubchem_script,
                                        data_files=pubchem_path,
                                        cache_dir=os.path.join('/tmp', getpass.getuser(), 'pubchem'),
                                        split='train',
                                        trust_remote_code=True)
        elif 'both' in self.dataset.lower():
            dataset_dict_pubchem = load_dataset(pubchem_script,
                                                data_files=pubchem_path,
                                                cache_dir=os.path.join('/tmp', getpass.getuser(), 'pubchem'),
                                                split='train',
                                                trust_remote_code=True)
            zinc_files = glob.glob(os.path.join(zinc_path, '*.smi'))
            for zfile in zinc_files:
                print(zfile)
            self.dataset = {'train': zinc_files}
            dataset_dict_zinc = load_dataset('./zinc_script.py',
                                             data_files=self.dataset,
                                             cache_dir=os.path.join('/tmp', getpass.getuser(), 'zinc'),
                                             split='train',
                                             trust_remote_code=True)
            dataset_dict = concatenate_datasets([dataset_dict_zinc, dataset_dict_pubchem])
        self.pubchem = dataset_dict
        print(dataset_dict.cache_files)
        self.cache_files = []
        for cache in dataset_dict.cache_files:
            tmp = '/'.join(cache['filename'].split('/')[:4])
            self.cache_files.append(tmp)
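
# Usage sketch (not from the original file): how the data module might be
# wired up. 'max_len', the dataset name, and the data path below are
# illustrative assumptions; substitute values that match your setup.
def _example_build_datamodule():
    dm = MoleculeModule(max_len=202,                            # assumed sequence length
                        dataset='pubchem',                      # or 'zinc' / 'both'
                        data_path='./data/pubchem/CID-SMILES')  # hypothetical path
    dm.prepare_data()
    dm.setup()
    print('vocab size:', len(dm.get_vocab()))
    print('cache dirs:', dm.get_cache())
    return dm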

def get_optim_groups(module):
    # set up the optimizer:
    # separate out all parameters into those that will and won't experience
    # regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in module.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in module.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, \
        'parameters %s made it into both decay/no_decay sets' % (str(inter_params),)
    assert len(param_dict.keys() - union_params) == 0, \
        'parameters %s were not separated into either decay/no_decay set' \
        % (str(param_dict.keys() - union_params),)

    # create the pytorch optimizer parameter groups; note that weight_decay is
    # 0.0 for both groups here, so decay is effectively disabled: give the
    # first group a nonzero value to actually apply it
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
    return optim_groups
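
# Usage sketch (not from the original file): the groups returned by
# get_optim_groups plug directly into a standard PyTorch optimizer. The
# learning rate and betas below are illustrative placeholders.
def _example_make_optimizer(model):
    optim_groups = get_optim_groups(model)
    return torch.optim.AdamW(optim_groups, lr=1.6e-4, betas=(0.9, 0.99))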