# UniMax Language Dataset Sampler with DDP support

This repository contains an unofficial PyTorch implementation of the UniMax sampling algorithm from ["UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining" by H. W. Chung et al. (ICLR 2023)](https://arxiv.org/abs/2304.09151). Given per-language character counts, a total character budget, and a cap on the number of epochs per language, UniMax produces a sampling distribution over languages. This is useful for training language models on datasets with imbalanced language distributions.
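For intuition, the core budget-allocation step can be summarized as follows. This is a minimal NumPy sketch following the paper's pseudocode, not the repository's actual implementation; the function and variable names are illustrative:

```python
import numpy as np

def unimax_distribution(char_counts, budget, max_epochs):
    """Sketch of UniMax budget allocation (after the paper's pseudocode).

    Languages are visited from smallest corpus to largest; each receives an
    equal share of the remaining budget, capped at `max_epochs` passes over
    its own data. Leftover budget rolls over to the larger languages.
    """
    counts = np.asarray(char_counts, dtype=np.float64)
    order = np.argsort(counts)          # smallest corpus first
    allocation = np.zeros_like(counts)
    remaining_budget = float(budget)
    for i, idx in enumerate(order):
        fair_share = remaining_budget / (len(counts) - i)         # uniform split of what is left
        allocation[idx] = min(fair_share, max_epochs * counts[idx])  # cap at the epoch budget
        remaining_budget -= allocation[idx]
    return allocation / allocation.sum()  # normalize to a probability distribution

# Example: five languages, a 1000-character budget, at most 2 epochs each
print(unimax_distribution([100, 200, 300, 400, 500], 1000, 2))
```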

## Contents

1. `unimax_sampler.py`: This Python file contains the `UnimaxSampler` class, a PyTorch `Sampler` that implements the UniMax algorithm.

2. `test_unimax_sampler.py`: This Python file contains unit tests that verify the behavior of the `UnimaxSampler` class.

## Usage

```python
from torch.utils.data import Dataset, DataLoader
from unimax_sampler import UnimaxSampler

# Per-language character counts, total character budget, and per-language epoch cap
language_character_counts = [100, 200, 300, 400, 500]
total_character_budget = 1000
num_epochs = 2

# Create the UnimaxSampler
unimax_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
```

Then pass the sampler as the `sampler` argument when creating a `DataLoader`.

```python
# Shuffling must be disabled when a custom sampler is supplied
data_loader = DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=unimax_sampler)
```

For DDP, select the distributed variant when a process group has been initialized:
```python
import torch

if torch.distributed.is_initialized():
    sampler = DistributedUnimaxSampler(...)
else:
    sampler = UnimaxSampler(...)
```
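
In a training loop, the distributed sampler should be advanced each epoch so that the per-rank ordering changes between epochs. Assuming `DistributedUnimaxSampler` follows the `torch.utils.data.distributed.DistributedSampler` convention of exposing `set_epoch` (not confirmed here), the usual pattern is:

```python
# Assumes DistributedUnimaxSampler exposes set_epoch(), as
# torch.utils.data.distributed.DistributedSampler does.
for epoch in range(num_train_epochs):
    sampler.set_epoch(epoch)   # re-seed per-epoch ordering across ranks
    for batch in data_loader:
        ...                    # forward / backward / optimizer step
```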

## Note
The initial version of this code was created by [ChatGPT (GPT-4)](https://chat.openai.com/), based on the pseudocode provided in the [UniMax](https://arxiv.org/abs/2304.09151) paper. The code was subsequently revised by hand to support PyTorch's Distributed Data Parallel ([DDP](https://pytorch.org/docs/stable/notes/ddp.html)) framework. The `DistributedSamplerWrapper` implementation is derived from an earlier version found in the [Catalyst](https://github.com/catalyst-team/catalyst) project.

## License
This project is licensed under the MIT License.