# Spaces: Runtime error
import pickle | |
import numpy as np | |
class IndexedDataset:
    """Random access to pickled items written by ``IndexedDatasetBuilder``.

    The dataset is backed by two files sharing the prefix ``path``:

    * ``{path}.data``: concatenated ``pickle.dumps`` payloads.
    * ``{path}.idx``: a ``np.save``'d dict with ``'offsets'`` (cumulative byte
      offsets, one more entry than there are items) and, optionally,
      ``'id2pos'`` (a mapping from caller-supplied ids to item positions).

    When ``id2pos`` is non-empty, ``dataset[key]`` looks ``key`` up as an id.
    Otherwise ``key`` is a position. Iteration always walks the items in
    storage order, regardless of ``id2pos``.

    WARNING: items are decoded with ``pickle.loads``. Only open files you
    trust, because unpickling can execute arbitrary code.
    """

    def __init__(self, path):
        super().__init__()
        self.path = path
        # Set before anything that can raise, so close()/__del__ stay safe
        # even if loading the index fails.
        self.data_file = None
        index_data = np.load(f"{path}.idx", allow_pickle=True).item()
        self.byte_offsets = index_data['offsets']
        self.id2pos = index_data.get('id2pos', {})
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)

    def check_index(self, i):
        """Raise IndexError unless ``i`` is a valid item position."""
        if i < 0 or i >= len(self.byte_offsets) - 1:
            raise IndexError('index out of range')

    def _read(self, pos):
        """Return the item stored at position ``pos`` (never consults ``id2pos``)."""
        self.check_index(pos)
        start = self.byte_offsets[pos]
        end = self.byte_offsets[pos + 1]
        self.data_file.seek(start)
        return pickle.loads(self.data_file.read(end - start))

    def close(self):
        """Close the underlying data file. Safe to call more than once."""
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()
            self.data_file = None

    def __del__(self):
        self.close()

    def __getitem__(self, i):
        """Return the item for id ``i`` (if ``id2pos`` is set) or position ``i``.

        Raises:
            KeyError: ``id2pos`` is non-empty and ``i`` is not a known id.
            IndexError: the resolved position is out of range.
        """
        if self.id2pos:
            i = self.id2pos[i]
        return self._read(i)

    def __len__(self):
        return len(self.byte_offsets) - 1

    def __iter__(self):
        # NOTE: iteration state lives on the instance, so two simultaneous
        # iterations over the same dataset interfere with each other.
        self.iter_i = 0
        return self

    def __next__(self):
        if self.iter_i >= len(self):
            raise StopIteration
        # Read by position. Routing through __getitem__ would reinterpret the
        # position as an id whenever id2pos is populated.
        item = self._read(self.iter_i)
        self.iter_i += 1
        return item
class IndexedDatasetBuilder:
    """Write pickled items to ``{path}.data`` plus an index in ``{path}.idx``.

    The output is readable with ``IndexedDataset(path)``. The index is only
    written by ``finalize()``, so call it when done adding items.

    Args:
        path: Common prefix of the ``.data`` and ``.idx`` files.
        append: If true, extend an existing dataset (its ``.idx`` must exist)
            instead of truncating it.
    """

    def __init__(self, path, append=False):
        self.path = path
        if append:
            # Load the index before opening the data file, so a missing or
            # corrupt index does not leave an open handle behind.
            index_data = np.load(f"{path}.idx", allow_pickle=True).item()
            # list() so add_item can .append even if offsets were stored as an array.
            self.byte_offsets = list(index_data['offsets'])
            self.id2pos = index_data.get('id2pos') or {}
            self.data_file = open(f"{path}.data", 'ab')
        else:
            self.data_file = open(f"{path}.data", 'wb')
            self.byte_offsets = [0]
            self.id2pos = {}

    def add_item(self, item, id=None):
        """Pickle and append ``item``, optionally registering it under ``id``.

        Args:
            item: Any picklable object.
            id: Optional key recorded in ``id2pos`` and mapped to this item's
                position. It shadows the builtin but is kept for API
                compatibility.
        """
        payload = pickle.dumps(item)
        n_written = self.data_file.write(payload)
        if id is not None:
            self.id2pos[id] = len(self.byte_offsets) - 1
        self.byte_offsets.append(self.byte_offsets[-1] + n_written)

    def finalize(self):
        """Close the data file and write the index to ``{path}.idx``."""
        self.data_file.close()
        # Pass an open file object: given a str path, np.save would append
        # ".npy" to the name, but readers load "{path}.idx" exactly.
        with open(f"{self.path}.idx", 'wb') as index_file:
            np.save(index_file, {'offsets': self.byte_offsets, 'id2pos': self.id2pos})