import pickle

import numpy as np


class IndexedDataset:
    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        index_data = np.load(f"{path}.idx", allow_pickle=True).item()
        self.byte_offsets = index_data['offsets']
        self.id2pos = index_data.get('id2pos', {})
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)
    
    def check_index(self, i):
        if i < 0 or i >= len(self.byte_offsets) - 1:
            raise IndexError('index out of range')
    
    def __del__(self):
        if self.data_file:
            self.data_file.close()
    
    def __getitem__(self, i):
        if self.id2pos is not None and len(self.id2pos) > 0:
            i = self.id2pos[i]
        self.check_index(i)
        self.data_file.seek(self.byte_offsets[i])
        b = self.data_file.read(self.byte_offsets[i + 1] - self.byte_offsets[i])
        item = pickle.loads(b)
        return item
    
    def __len__(self):
        return len(self.byte_offsets) - 1
    
    def __iter__(self):
        self.iter_i = 0
        return self
    
    def __next__(self):
        if self.iter_i == len(self):
            raise StopIteration
        else:
            item = self[self.iter_i]
            self.iter_i += 1
            return item


class IndexedDatasetBuilder:
    def __init__(self, path, append=False):
        self.path = path
        if append:
            self.data_file = open(f"{path}.data", 'ab')
            index_data = np.load(f"{path}.idx", allow_pickle=True).item()
            self.byte_offsets = index_data['offsets']
            self.id2pos = index_data.get('id2pos', {})
        else:
            self.data_file = open(f"{path}.data", 'wb')
            self.byte_offsets = [0]
            self.id2pos = {}
    
    def add_item(self, item, id=None):
        s = pickle.dumps(item)
        bytes = self.data_file.write(s)
        if id is not None:
            self.id2pos[id] = len(self.byte_offsets) - 1
        self.byte_offsets.append(self.byte_offsets[-1] + bytes)
    
    def finalize(self):
        self.data_file.close()
        np.save(open(f"{self.path}.idx", 'wb'),
                {'offsets': self.byte_offsets, 'id2pos': self.id2pos})