# SAM-DiffSR / sam_diffsr/utils_sr/indexed_datasets.py
# Traly — init (193c713)
import pickle
import numpy as np
class IndexedDataset:
    """Random-access reader for pickled records stored as a ``.data``/``.idx`` pair.

    Files read:
      * ``{path}.data`` -- concatenated ``pickle`` payloads, read by byte range.
      * ``{path}.idx``  -- a ``np.save``-d dict holding ``'offsets'`` (byte
        offsets with one more entry than there are records; record ``k``
        spans ``offsets[k]:offsets[k + 1]``) and optionally ``'id2pos'``
        (a mapping from caller-chosen ids to record positions).

    When ``id2pos`` is non-empty, ``dataset[key]`` resolves ``key`` through it;
    otherwise ``key`` is the record position. Iteration always yields records
    in storage order, whether or not ``id2pos`` is present.

    NOTE(review): ``allow_pickle=True`` and ``pickle.loads`` can execute
    arbitrary code embedded in the files -- only open trusted datasets.
    """

    def __init__(self, path):
        super().__init__()
        self.path = path
        # Set first so __del__/close() are safe if loading the index fails.
        self.data_file = None
        index_data = np.load(f"{path}.idx", allow_pickle=True).item()
        self.byte_offsets = index_data['offsets']
        self.id2pos = index_data.get('id2pos', {})
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)

    def check_index(self, i):
        """Raise IndexError unless ``i`` is a valid record position."""
        if i < 0 or i >= len(self.byte_offsets) - 1:
            raise IndexError('index out of range')

    def close(self):
        """Close the underlying data file; safe to call more than once."""
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()

    def __del__(self):
        self.close()

    def _read_at(self, pos):
        """Unpickle and return the record stored at position ``pos``.

        Raises:
            IndexError: if ``pos`` is not a valid record position.
        """
        self.check_index(pos)
        start = self.byte_offsets[pos]
        self.data_file.seek(start)
        payload = self.data_file.read(self.byte_offsets[pos + 1] - start)
        return pickle.loads(payload)

    def __getitem__(self, i):
        """Return the record for ``i``.

        ``i`` is looked up in ``id2pos`` when that mapping is non-empty
        (raising KeyError for unknown ids); otherwise it is a position.
        """
        if self.id2pos is not None and len(self.id2pos) > 0:
            i = self.id2pos[i]
        return self._read_at(i)

    def __len__(self):
        return len(self.byte_offsets) - 1

    def __iter__(self):
        self.iter_i = 0
        return self

    def __next__(self):
        if self.iter_i >= len(self):
            raise StopIteration
        # Read by position: iter_i counts storage slots, so it must not be
        # routed through id2pos the way __getitem__ routes its key.
        item = self._read_at(self.iter_i)
        self.iter_i += 1
        return item
class IndexedDatasetBuilder:
    """Writer for the ``.data``/``.idx`` record format read by ``IndexedDataset``.

    Each ``add_item`` pickles one record and appends it to ``{path}.data``.
    ``finalize`` closes the data file and saves ``{path}.idx``, a dict with
    ``'offsets'`` (cumulative byte offsets, starting at 0) and ``'id2pos'``
    (id -> record position for items added with an ``id``).

    Args:
        path: file prefix; ``.data`` and ``.idx`` are appended to it.
        append: if True, continue an existing dataset by loading its index
            and appending to its data file. Otherwise start a new dataset,
            truncating any existing ``.data`` file.
    """

    def __init__(self, path, append=False):
        self.path = path
        if append:
            # Load the index before opening the data file so that a missing
            # or unreadable index does not leave a file handle open.
            index_data = np.load(f"{path}.idx", allow_pickle=True).item()
            # list() keeps .append working even if offsets were stored as an array.
            self.byte_offsets = list(index_data['offsets'])
            self.id2pos = index_data.get('id2pos', {})
            # NOTE(review): new offsets continue from offsets[-1]. If .data
            # holds bytes past that point (e.g. an earlier run never called
            # finalize), the new offsets will not match the file contents.
            self.data_file = open(f"{path}.data", 'ab')
        else:
            self.data_file = open(f"{path}.data", 'wb')
            self.byte_offsets = [0]
            self.id2pos = {}

    def add_item(self, item, id=None):
        """Pickle ``item`` and append it to the data file.

        Args:
            item: any picklable object.
            id: optional key recorded in ``id2pos`` and mapped to this
                record's position. The name shadows the builtin but is kept
                for backward compatibility with keyword callers.
        """
        payload = pickle.dumps(item)
        n_written = self.data_file.write(payload)
        if id is not None:
            self.id2pos[id] = len(self.byte_offsets) - 1
        self.byte_offsets.append(self.byte_offsets[-1] + n_written)

    def finalize(self):
        """Close the data file and write the index to ``{path}.idx``."""
        self.data_file.close()
        # Pass an open file object rather than a path: np.save appends
        # '.npy' to string paths lacking it, which would rename the index.
        # The `with` block closes the handle deterministically.
        with open(f"{self.path}.idx", 'wb') as index_file:
            np.save(index_file,
                    {'offsets': self.byte_offsets, 'id2pos': self.id2pos})