Question on fine-tuning the model without a downstream task

#110
by abayegan - opened

Hi, great work! Do you have any suggestions for updating the pre-trained model without a downstream task? Would it be legit to use BertForMaskedLM.from_pretrained(<path to your pretrained model>) and continue training it on a new dataset? Which data collator should I use?

Thank you for your interest in Geneformer! Yes, I would suggest following the example script for pretraining, but substituting "model = BertForMaskedLM(config)" with "model = BertForMaskedLM.from_pretrained(/path/to/pretrained_model/)"
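For concreteness, a minimal sketch of that substitution (the checkpoint path is a placeholder; the rest of the example pretraining script stays the same):

from transformers import BertForMaskedLM

# Instead of initializing a fresh model from the config:
# model = BertForMaskedLM(config)
# load the pretrained weights and continue training on the new dataset:
model = BertForMaskedLM.from_pretrained("/path/to/pretrained_model/")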

ctheodoris changed discussion status to closed

Thank you for the suggestion. I encountered the 'Invalid key' error when I substituted

config = {
    "hidden_size": num_embed_dim,
    "num_hidden_layers": num_layers,
    "initializer_range": initializer_range,
    "layer_norm_eps": layer_norm_eps,
    "attention_probs_dropout_prob": attention_probs_dropout_prob,
    "hidden_dropout_prob": hidden_dropout_prob,
    "intermediate_size": intermed_size,
    "hidden_act": activ_fn,
    "max_position_embeddings": max_input_size,
    "model_type": model_type,
    "num_attention_heads": num_attn_heads,
    "pad_token_id": token_dictionary.get("<pad>"),
    "vocab_size": len(token_dictionary),  # genes+2 for <mask> and <pad> tokens
}

config = BertConfig(**config)
model = BertForMaskedLM(config)

to

model = BertForMaskedLM.from_pretrained("/mnt/c/Users/pc/Downloads/Geneformer/geneformer-12L-30M",
                                        output_attentions=False,
                                        output_hidden_states=False)

It seems that 25196624 is the number of cells used for the pre-trained model, while 53950 is the number of cells in my further fine-tuning dataset. I think the problem might be my use of an improper lengths_file (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/genecorpus_30M_2048_lengths.pkl). Could you kindly point me to how to generate my own lengths_file.pkl? Thanks.

error:

DESKTOP-6FHRRIO:5553:5700 [0] NCCL INFO comm 0x55c7ab4ecae0 rank 0 nranks 1 cudaDev 0 busId 3000 - Init COMPLETE
  0%|                                                                                       | 0/6851556 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/mnt/c/Users/pc/Downloads/pretrain_geneformer_w_deepspeed.py", line 168, in <module>
    trainer.train()
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/transformers/trainer.py", line 1645, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/transformers/trainer.py", line 1916, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 2796, in __getitems__
    batch = self.__getitem__(keys)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 2792, in __getitem__
    return self._getitem(key)
           ^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 2776, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/datasets/formatting/formatting.py", line 583, in query_table
    _check_valid_index_key(key, size)
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/datasets/formatting/formatting.py", line 536, in _check_valid_index_key
    _check_valid_index_key(int(max(key)), size=size)
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/datasets/formatting/formatting.py", line 526, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 25196624 is out of bounds for size 53950
  0%|                                                                                       | 0/6851556 [00:18<?, ?it/s]

Thank you for your question. The lengths file is the list of lengths of each rank value encoding in the pretraining corpus. The transcriptomic tokenizer already maps the lengths so they should be in a column in the resulting .dataset. They can be extracted with:
dataset_lengths = dataset["length"]

Please also see the relevant closed discussion here: https://huggingface.co/ctheodoris/Geneformer/discussions/61

If you did not use the transcriptome tokenizer, you can map the lengths as shown in the tokenizer:

def measure_length(example):
    example["length"] = len(example["input_ids"])
    return example

dataset = dataset.map(measure_length, num_proc=nproc)

Thank you for your answer. I successfully saved the lengths file as a .pkl with the following code:

import pickle

data = train_dataset['length']
file_path = '/mnt/c/Users/pc/Downloads/train_dataset_length.pkl'

# Save data as a pickle file
with open(file_path, 'wb') as file:
    pickle.dump(data, file)

print("Data saved as pickle file:", file_path)

However, another error occurred:

DESKTOP-6FHRRIO:4363:4517 [0] NCCL INFO comm 0x555c70f150e0 rank 0 nranks 1 cudaDev 0 busId 3000 - Init COMPLETE
  0%|                                                                                         | 0/13488 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/mnt/c/Users/pc/Downloads/pretrain_geneformer_w_deepspeed.py", line 192, in <module>
    trainer.train()
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/transformers/trainer.py", line 1645, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/transformers/trainer.py", line 1916, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/transformers/data/data_collator.py", line 45, in __call__
    return self.torch_call(features)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/transformers/data/data_collator.py", line 732, in torch_call
    batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/geneformer/pretrainer.py", line 397, in pad
    padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pc/miniconda3/envs/Transformers/lib/python3.11/site-packages/geneformer/pretrainer.py", line 250, in _get_padding_truncation_strategies
    not self.pad_token or self.pad_token_id < 0
                          ^^^^^^^^^^^^^^^^^^^^^
TypeError: '<' not supported between instances of 'NoneType' and 'int'
  0%|                                                                                         | 0/13488 [00:00<?, ?it/s]
DESKTOP-6FHRRIO:4363:4518 [0] NCCL INFO [Service thread] Connection closed by localRank 0
DESKTOP-6FHRRIO:4363:4363 [0] NCCL INFO comm 0x555c70f150e0 rank 0 nranks 1 cudaDev 0 busId 3000 - Abort COMPLETE
[2023-07-17 09:13:53,853] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 4363
[2023-07-17 09:13:53,853] [ERROR] [launch.py:321:sigkill_handler] ['/home/pc/miniconda3/envs/Transformers/bin/python3.11', '-u', '/mnt/c/Users/pc/Downloads/pretrain_geneformer_w_deepspeed.py', '--local_rank=0', '--deepspeed', '/home/pc/transformers/tests/deepspeed/ds_config_zero3.json'] exits with return code = 1

It was due to line 250, not self.pad_token or self.pad_token_id < 0, in pretrainer.py. Was that due to something wrong with the token dictionary, the original pretrained model, or my dataset? (Below is the head of the token_dictionary.)

In [5]: token_dictionary
Out[5]:
{0: '<pad>',
 1: '<mask>',
 2: 'ENSG00000000003',
 3: 'ENSG00000000005',
 4: 'ENSG00000000419',
 5: 'ENSG00000000457',
 6: 'ENSG00000000460',
 7: 'ENSG00000000938',
 8: 'ENSG00000000971',
 9: 'ENSG00000001036',
 10: 'ENSG00000001084',
 11: 'ENSG00000001167',
 12: 'ENSG00000001460',
 13: 'ENSG00000001461',
 14: 'ENSG00000001497',
 15: 'ENSG00000001561',
 16: 'ENSG00000001617',
 17: 'ENSG00000001626',
 18: 'ENSG00000001629',
 19: 'ENSG00000001630',
 20: 'ENSG00000001631',
 21: 'ENSG00000002016',
.
.

Thank you for your question. The token dictionary in the repository is the inverse of the one you pasted in your comment: it maps gene/token strings to integer IDs rather than integer IDs to gene strings. Please pull the current repository to ensure you are not using an outdated version.
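As an illustration of the expected orientation (not a substitute for pulling the updated repository file), the dictionary pasted above could be inverted as follows; token_dictionary here is assumed to be the ID-to-gene mapping shown in the previous comment:

# token_dictionary is assumed to be the ID -> gene dict pasted above;
# the repository's dictionary has the opposite orientation (gene/token string -> integer ID)
inverted_dictionary = {gene: idx for idx, gene in token_dictionary.items()}
print(inverted_dictionary["<pad>"])   # 0, per the pasted dictionary
print(inverted_dictionary["<mask>"])  # 1, per the pasted dictionary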

The update worked. Thank you.

Two more questions:

  1. If I have 3 new datasets (A, B, and C), would you recommend sequential fine-tuning (A first, then B, and finally C) or fine-tuning on all 3 together? In other words, will the classification accuracy on A be worse after fine-tuning on C if I do it sequentially?
  2. If I cannot train at a batch size of 12, is a smaller batch size detrimental to this sort of "fine-tuning"?

Thank you for your question. The best route depends on the data and your scientific question. If the three datasets are equivalent, it would likely be better to fine-tune with them together so the model gains a more generalizable understanding of the data. If there is a specific reason to fine-tune sequentially (e.g. the first dataset is more general to give a baseline view of the informational space, the second fine-tunes toward a more specific question, etc.), then sequential fine-tuning may be helpful. As for whether the last dataset will be better remembered than the first when fine-tuning sequentially, this can happen, but it also depends on other factors, such as the size of each dataset and how many layers you freeze vs. allow to be tunable at each step.
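If it helps as a starting point, below is a minimal sketch of the "together" route combined with layer freezing, using the standard transformers/datasets APIs rather than any Geneformer-specific wrapper; ds_a, ds_b, ds_c, the checkpoint path, N, and num_classes are placeholders:

from datasets import concatenate_datasets
from transformers import BertForSequenceClassification

# Joint fine-tuning: combine the three tokenized datasets so the model sees them together.
# ds_a, ds_b, ds_c are assumed to be already-tokenized Hugging Face datasets with identical columns.
combined = concatenate_datasets([ds_a, ds_b, ds_c]).shuffle(seed=42)

# Optionally freeze the first N encoder layers so earlier knowledge is less likely to be overwritten.
# How many layers to freeze (if any) depends on your data and question.
num_classes = 3
model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/", num_labels=num_classes)
N = 6
for layer in model.bert.encoder.layer[:N]:
    for param in layer.parameters():
        param.requires_grad = False

The same freezing pattern applies to sequential fine-tuning, where freezing more layers at later steps can reduce how much the earlier datasets are forgotten.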
