File size: 559 Bytes
81d3845
 
 
 
 
 
 
 
 
32580c1
 
 
81d3845
32580c1
 
 
81d3845
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""
helper util to calculate dataset lengths
"""
import numpy as np


def get_dataset_lengths(dataset) -> np.ndarray:
    """Return the per-row sequence lengths of ``dataset`` as a NumPy array.

    The lengths come from the first available source, checked in this order:

    1. a ``"length"`` column, converted to an array as-is;
    2. a ``"position_ids"`` column, where a row's length is its last
       position id plus one;
    3. the ``"input_ids"`` column, where a row's length is ``len(row)``.

    Args:
        dataset: Object exposing ``dataset.data.column_names`` and
            ``dataset.data.column(name)``.
            NOTE(review): presumably a Hugging Face ``datasets.Dataset``
            backed by an Arrow table -- confirm against callers.

    Returns:
        Array holding one length per row. An empty ``"input_ids"`` column
        yields an empty integer array instead of raising.
    """
    table = dataset.data
    available = table.column_names

    if "length" in available:
        return np.array(table.column("length"))

    if "position_ids" in available:
        # NOTE(review): last id + 1 equals the row length only when position
        # ids are 0-based and contiguous, and every row is non-empty --
        # presumably guaranteed upstream; verify.
        return np.array([row[-1] + 1 for row in table.column("position_ids")])

    rows = np.array(table.column("input_ids"), dtype=object)
    # np.vectorize infers its output dtype by calling the function on the
    # first element, so without explicit ``otypes`` it raises ValueError on
    # an empty input. ``np.int_`` is the dtype it would infer from ``len``.
    return np.vectorize(len, otypes=[np.int_])(rows)