# coding=utf-8 # Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Dataset to distilled models adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) """ import numpy as np import torch from torch.utils.data import Dataset from utils import logger class LmSeqsDataset(Dataset): """Custom Dataset wrapping language modeling sequences. Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths. Input: ------ params: `NameSpace` parameters data: `List[np.array[int]] """ def __init__(self, params, data): self.params = params self.token_ids = np.array(data) self.lengths = np.array([len(t) for t in data]) self.check() self.remove_long_sequences() self.remove_empty_sequences() self.remove_unknown_sequences() self.check() self.print_statistics() def __getitem__(self, index): return (self.token_ids[index], self.lengths[index]) def __len__(self): return len(self.lengths) def check(self): """ Some sanity checks """ assert len(self.token_ids) == len(self.lengths) assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths))) def remove_long_sequences(self): """ Sequences that are too long are split by chunk of max_model_input_size. """ max_len = self.params.max_model_input_size indices = self.lengths > max_len logger.info(f"Splitting {sum(indices)} too long sequences.") def divide_chunks(l, n): return [l[i : i + n] for i in range(0, len(l), n)] new_tok_ids = [] new_lengths = [] if self.params.mlm: cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"] else: cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"] for seq_, len_ in zip(self.token_ids, self.lengths): assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_ if len_ <= max_len: new_tok_ids.append(seq_) new_lengths.append(len_) else: sub_seqs = [] for sub_s in divide_chunks(seq_, max_len - 2): if sub_s[0] != cls_id: sub_s = np.insert(sub_s, 0, cls_id) if sub_s[-1] != sep_id: sub_s = np.insert(sub_s, len(sub_s), sep_id) assert len(sub_s) <= max_len assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s sub_seqs.append(sub_s) new_tok_ids.extend(sub_seqs) new_lengths.extend([len(l) for l in sub_seqs]) self.token_ids = np.array(new_tok_ids) self.lengths = np.array(new_lengths) def remove_empty_sequences(self): """ Too short sequences are simply removed. This could be tuned. """ init_size = len(self) indices = self.lengths > 11 self.token_ids = self.token_ids[indices] self.lengths = self.lengths[indices] new_size = len(self) logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.") def remove_unknown_sequences(self): """ Remove sequences with a (too) high level of unknown tokens. """ if "unk_token" not in self.params.special_tok_ids: return else: unk_token_id = self.params.special_tok_ids["unk_token"] init_size = len(self) unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids]) indices = (unk_occs / self.lengths) < 0.5 self.token_ids = self.token_ids[indices] self.lengths = self.lengths[indices] new_size = len(self) logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).") def print_statistics(self): """ Print some statistics on the corpus. Only the master process. """ if not self.params.is_master: return logger.info(f"{len(self)} sequences") # data_len = sum(self.lengths) # nb_unique_tokens = len(Counter(list(chain(*self.token_ids)))) # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)') # unk_idx = self.params.special_tok_ids['unk_token'] # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids]) # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)') def batch_sequences(self, batch): """ Do the padding and transform into torch.tensor. """ token_ids = [t[0] for t in batch] lengths = [t[1] for t in batch] assert len(token_ids) == len(lengths) # Max for paddings max_seq_len_ = max(lengths) # Pad token ids if self.params.mlm: pad_idx = self.params.special_tok_ids["pad_token"] else: pad_idx = self.params.special_tok_ids["unk_token"] tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids] assert len(tk_) == len(token_ids) assert all(len(t) == max_seq_len_ for t in tk_) tk_t = torch.tensor(tk_) # (bs, max_seq_len_) lg_t = torch.tensor(lengths) # (bs) return tk_t, lg_t