Spaces:
Sleeping
Sleeping
| import os | |
| import pytest | |
| import time | |
| import random | |
| import functools | |
| import tempfile | |
| from typing import Callable | |
| from ding.data.buffer import DequeBuffer | |
| from ding.data.buffer.buffer import BufferedData | |
| from torch.utils.data import DataLoader | |
class RateLimit:
    r"""
    Middleware that caps the number of accepted ``push`` calls inside a
    sliding time window; all other actions pass straight through.
    """

    def __init__(self, max_rate: int = float("inf"), window_seconds: int = 30) -> None:
        # max_rate: maximum pushes accepted per window (default: unlimited).
        # window_seconds: sliding-window length in seconds.
        self.max_rate = max_rate
        self.window_seconds = window_seconds
        self.buffered = []

    def __call__(self, action: str, chain: Callable, *args, **kwargs):
        # Only "push" is rate limited; everything else is forwarded as-is.
        if action != "push":
            return chain(*args, **kwargs)
        return self.push(chain, *args, **kwargs)

    def push(self, chain, data, *args, **kwargs) -> None:
        now = time.time()
        cutoff = now - self.window_seconds
        # Drop timestamps that have aged out of the window.
        self.buffered = [stamp for stamp in self.buffered if stamp > cutoff]
        if len(self.buffered) >= self.max_rate:
            # Window is full: swallow this push.
            return None
        self.buffered.append(now)
        return chain(data, *args, **kwargs)
def add_10() -> Callable:
    """
    Middleware factory: add 10 to every sampled datum while keeping the
    record's index and meta untouched.
    """

    def _transform(chain: Callable, size: int, replace: bool = False, *args, **kwargs):
        items = chain(size, replace, *args, **kwargs)
        return [BufferedData(data=entry.data + 10, index=entry.index, meta=entry.meta) for entry in items]

    def _subview(action: str, chain: Callable, *args, **kwargs):
        # Only "sample" results are transformed; other actions pass through.
        if action == "sample":
            return _transform(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _subview
def test_naive_push_sample():
    """Basic push/sample behaviour: overflow eviction, clear, replacement, slicing."""
    buffer = DequeBuffer(size=10)
    # Pushing 20 items into a size-10 buffer keeps only the newest 10.
    for value in range(20):
        buffer.push(value)
    assert buffer.count() == 10
    assert 0 not in [entry.data for entry in buffer.sample(10)]
    # Clearing empties the buffer completely.
    buffer.clear()
    assert buffer.count() == 0
    # Sampling with replacement may return more items than are stored.
    for value in range(5):
        buffer.push(value)
    assert buffer.count() == 5
    assert len(buffer.sample(10, replace=True)) == 10
    # Range-restricted sampling draws only from the requested slice.
    buffer.clear()
    for value in range(10):
        buffer.push(value)
    assert len(buffer.sample(5, sample_range=slice(5, 10))) == 5
    assert 0 not in [entry.data for entry in buffer.sample(5, sample_range=slice(5, 10))]
def test_rate_limit_push_sample():
    """Only ``max_rate`` pushes are accepted; later values never reach storage."""
    buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
    for i in range(10):
        buffer.push(i)
    assert buffer.count() == 5
    # BUGFIX: ``5 not in buffer.sample(5)`` compared an int against a list of
    # BufferedData wrappers and was vacuously true; compare the payloads.
    assert 5 not in [item.data for item in buffer.sample(5)]
def test_load_and_save():
    """Data and meta indices survive a save/load round trip."""
    buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
    buffer.meta_index = {"label": []}
    for i in range(10):
        buffer.push(i, meta={"label": i})
    assert buffer.count() == 5
    # BUGFIX: ``5 not in buffer.sample(5)`` compared an int against
    # BufferedData wrappers and was vacuously true; compare the payloads.
    assert 5 not in [item.data for item in buffer.sample(5)]
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_file = os.path.join(tmpdirname, "data.hkl")
        buffer.save_data(test_file)
        buffer_new = DequeBuffer(size=10).use(RateLimit(max_rate=5))
        # Load while the temporary directory still exists.
        buffer_new.load_data(test_file)
        assert buffer_new.count() == 5
        assert 5 not in [item.data for item in buffer_new.sample(5)]
    # The meta index tracks only the records that passed the rate limit.
    assert len(buffer.meta_index["label"]) == 5
    assert all(index < 5 for index in buffer.meta_index["label"])
def test_buffer_view():
    """A view shares storage with its parent but keeps its own middleware."""
    root = DequeBuffer(size=10)
    for value in range(1):
        root.push(value)
    assert root.count() == 1
    view = root.view().use(RateLimit(max_rate=5)).use(add_10())
    for value in range(10):
        view.push(value)
    # Storage holds 1 record from the root plus 5 rate-limited view pushes.
    assert len(root._middleware) == 0
    assert root.count() == 6
    # ``add_10`` lifts every value sampled through the view ...
    assert all(entry.data >= 10 for entry in view.sample(5))
    # ... while the raw storage still holds the original small values.
    assert all(entry.data < 10 for entry in root.sample(5))
def test_sample_with_index():
    """Sampling by explicit indices returns exactly the requested records."""
    buf = DequeBuffer(size=10)
    for i in range(10):
        buf.push({"data": i}, {"meta": i})
    # Draw every record once and remember its index.
    indices = [entry.index for entry in buf.sample(10)]
    assert len(indices) == 10
    random.shuffle(indices)
    chosen = indices[:5]
    # Resampling by those indices must yield indices from the same set.
    resampled = [entry.index for entry in buf.sample(indices=chosen)]
    assert len(resampled) == len(chosen)
    assert all(index in chosen for index in resampled)
def test_update():
    """Updating by index mutates stored data; invalid indices are rejected."""
    buf = DequeBuffer(size=10)
    for i in range(1):
        buf.push({"data": i}, {"meta": i})
    # Mutate one sampled record and write it back.
    [item] = buf.sample(1)
    item.data["new_prop"] = "any"
    success = buf.update(item.index, item.data, item.meta)
    assert success
    # Resample and confirm the mutation is visible.
    [item] = buf.sample(1)
    assert "new_prop" in item.data
    # BUGFIX: removed the dead ``meta = None`` / ``assert meta is None`` pair,
    # which asserted on a local variable rather than on buffer state.
    # Updating a non-existent index fails gracefully.
    success = buf.update("invalidindex", {}, None)
    assert not success
    # After overflowing the buffer, the index map stays aligned with storage.
    for i in range(20):
        buf.push({"data": i})
    assert len(buf.indices) == 10
    assert len(buf.storage) == 10
    for i in range(10):
        index = buf.storage[i].index
        assert buf.indices.get(index) == i
def test_delete():
    """Deleting sampled indices shrinks the buffer and keeps indices aligned."""
    maxlen, cumlen, dellen = 100, 40, 20
    buf = DequeBuffer(size=maxlen)
    for value in range(cumlen):
        buf.push(value)
    # Remove a random batch of records by index.
    victims = [entry.index for entry in buf.sample(dellen)]
    buf.delete(victims)
    # Append a few more records afterwards.
    for value in range(10):
        buf.push(value)
    remaining = min(cumlen, maxlen) - dellen + 10
    assert len(buf.indices) == remaining
    assert len(buf.storage) == remaining
    # Every index must still map to its correct storage slot.
    for slot in range(remaining):
        assert buf.indices.get(buf.storage[slot].index) == slot
def test_ignore_insufficient():
    """Oversized samples either raise or return empty, per ``ignore_insufficient``."""
    buffer = DequeBuffer(size=10)
    for value in range(2):
        buffer.push(value)
    # Requesting more than is stored raises by default ...
    with pytest.raises(ValueError):
        buffer.sample(3, ignore_insufficient=False)
    # ... but yields an empty result when the shortfall is ignored.
    assert len(buffer.sample(3, ignore_insufficient=True)) == 0
def test_independence():
    """Duplicated samples must be deep copies, not aliases of one record."""
    # Duplicates produced by replacement sampling.
    buffer = DequeBuffer(size=1)
    original = {"key": "origin"}
    buffer.push(original)
    pair = buffer.sample(2, replace=True)
    assert len(pair) == 2
    pair[0].data["key"] = "new"
    assert pair[1].data["key"] == "origin"
    # Duplicates produced by repeating the same index.
    buffer = DequeBuffer(size=1)
    original = {"key": "origin"}
    buffered = buffer.push(original)
    pair = buffer.sample(indices=[buffered.index, buffered.index])
    assert len(pair) == 2
    pair[0].data["key"] = "new"
    assert pair[1].data["key"] == "origin"
def test_groupby():
    """Grouped sampling returns one list of records per distinct meta group."""
    buffer = DequeBuffer(size=3)
    buffer.push("a", {"group": 1})
    buffer.push("b", {"group": 2})
    buffer.push("c", {"group": 2})
    sampled_data = buffer.sample(2, groupby="group")
    assert len(sampled_data) == 2
    # Tell the groups apart by size: group 1 has one record, group 2 has two.
    group1 = sampled_data[0] if len(sampled_data[0]) == 1 else sampled_data[1]
    group2 = sampled_data[0] if len(sampled_data[0]) == 2 else sampled_data[1]
    assert group1[0].data == "a"
    values = [entry.data for entry in group2]  # ["b", "c"]
    assert "b" in values
    assert "c" in values
    # Pushing "d" evicts "a", so all remaining records belong to group 2.
    buffer.push("d", {"group": 2})
    sampled_data = buffer.sample(1, groupby="group")
    assert len(sampled_data) == 1
    assert len(sampled_data[0]) == 3
    assert "d" in [entry.data for entry in sampled_data[0]]
    # Re-labelling the first record back to group 1 restores two groups.
    first: BufferedData = buffer.storage[0]
    buffer.update(first.index, first.data, {"group": 1})
    sampled_data = buffer.sample(2, groupby="group")
    assert len(sampled_data) == 2
    # Deleting the last record leaves exactly one record per group.
    last: BufferedData = buffer.storage[-1]
    buffer.delete(last.index)
    sampled_data = buffer.sample(2, groupby="group")
    assert len(sampled_data) == 2
def test_dataset():
    """The buffer can be iterated as a torch Dataset through a DataLoader."""
    buffer = DequeBuffer(size=10)
    for value in range(10):
        buffer.push(value)
    loader = DataLoader(buffer, batch_size=6, shuffle=True, collate_fn=lambda batch: batch)
    # 10 items split into batches of 6 produce one batch of 6 and one of 4.
    for batch in loader:
        assert len(batch) in (4, 6)
def test_unroll_len_in_group():
    """Grouped sampling with ``unroll_len`` returns continuous per-env slices."""
    buffer = DequeBuffer(size=100)
    for step in range(10):
        for env_id in list("ABC"):
            buffer.push(step, {"env": env_id})
    sampled_data = buffer.sample(3, groupby="env", unroll_len=4)
    assert len(sampled_data) == 3
    for grouped_data in sampled_data:
        assert len(grouped_data) == 4
        # Every sample in a group must come from the same env.
        assert len({sample.meta["env"] for sample in grouped_data}) == 1
        # The reduce chain stays truthy only if each datum is exactly one
        # greater than its predecessor, i.e. the slice is continuous.
        result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
        assert isinstance(result, BufferedData), "Not continuous"
def test_insufficient_unroll_len_in_group():
    """Short groups raise without ``replace``; with it, sampling still succeeds."""
    buffer = DequeBuffer(size=100)
    num = 3  # Items in group A,B,C is 3,4,5
    for env_id in list("ABC"):
        for i in range(num):
            buffer.push(i, {"env": env_id})
        num += 1
    # Not every group can supply 4 continuous samples, so sampling raises.
    with pytest.raises(ValueError) as exc_info:
        buffer.sample(3, groupby="env", unroll_len=4)
    # BUGFIX: use the public ``exc_info.value`` accessor instead of poking
    # at pytest's private ``_excinfo`` tuple.
    assert "There are less than" in str(exc_info.value)
    # Sampling with replacement tolerates the shortfall.
    sampled_data = buffer.sample(3, groupby="env", unroll_len=4, replace=True)
    assert len(sampled_data) == 3
    for grouped_data in sampled_data:
        assert len(grouped_data) == 4
        # Every sample in a group must come from the same env.
        env_ids = set(map(lambda sample: sample.meta["env"], grouped_data))
        assert len(env_ids) == 1
        # The reduce chain stays truthy only for strictly consecutive data.
        result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
        assert isinstance(result, BufferedData), "Not continuous"
def test_slice_unroll_len_in_group():
    """With ``sliced=True``, unroll windows start only at aligned offsets."""
    buffer = DequeBuffer(size=100, sliced=True)
    data_len = 10
    unroll_len = 4
    # Valid window starts: every multiple of unroll_len plus the tail window.
    start_index = list(range(0, data_len, unroll_len)) + [data_len - unroll_len]
    for step in range(data_len):
        for env_id in list("ABC"):
            buffer.push(step, {"env": env_id})
    sampled_data = buffer.sample(3, groupby="env", unroll_len=unroll_len)
    assert len(sampled_data) == 3
    for grouped_data in sampled_data:
        assert len(grouped_data) == 4
        # Every sample in a group must come from the same env.
        assert len({sample.meta["env"] for sample in grouped_data}) == 1
        # The reduce chain stays truthy only for strictly consecutive data.
        result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
        assert isinstance(result, BufferedData), "Not continuous"
        # Each sliced window must begin at one of the allowed start indices.
        assert grouped_data[0].data in start_index