Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Data Cache Utils

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
""" | |
import os | |
import SharedArray | |
try: | |
from multiprocessing.shared_memory import ShareableList | |
except ImportError: | |
import warnings | |
warnings.warn("Please update python version >= 3.8 to enable shared_memory") | |
import numpy as np | |
def shared_array(name, var=None):
    """Publish or fetch a NumPy array through a named SharedArray segment.

    Args:
        name: Segment name; the segment is addressed as ``shm://{name}``.
        var: Array to publish. When ``None`` the segment is only read.

    Returns:
        * ``var is None``: a private copy of the array stored in the segment.
        * ``var`` given and the segment already exists: the attached segment,
          returned as-is (``var`` is not written into it and the writeable
          flag is not touched here).
        * ``var`` given and the segment does not exist: the newly created
          segment, filled with ``var`` and marked read-only.
    """
    uri = f"shm://{name}"

    # Read-only request: hand back a copy so the caller can mutate freely.
    if var is None:
        return SharedArray.attach(uri).copy()

    # Fast path: the segment is already visible under /dev/shm.
    if os.path.exists(f"/dev/shm/{name}"):
        return SharedArray.attach(uri)

    try:
        data = SharedArray.create(uri, var.shape, dtype=var.dtype)
    except FileExistsError:
        # The segment appeared between the existence check and the create
        # call (e.g. a concurrent worker caching the same item). Reuse it,
        # exactly like the fast path, instead of failing.
        return SharedArray.attach(uri)

    data[...] = var[...]
    data.flags.writeable = False
    return data
def shared_dict(name, var=None):
    """Store or load a dict of NumPy arrays via shared memory.

    The list of cached keys is kept in a ``ShareableList`` named
    ``"{name}.keys"``; each array is handled by ``shared_array`` under the
    name ``"{name}.{key}"``.

    Args:
        name: Base name of the cache entry (converted with ``str``); it must
            not contain ``"."`` because that character separates the base
            name from the key.
        var: Dict to cache. Only values that are ``np.ndarray`` are stored;
            other entries are skipped. When ``None`` the previously stored
            entry is loaded instead.

    Returns:
        A dict mapping each cached key to the value returned by
        ``shared_array`` for it.
    """
    name = str(name)
    assert "." not in name  # '.' is used as sep flag

    # Load path: recover the key list, then fetch every array by name.
    if var is None:
        stored_keys = list(ShareableList(name=name + ".keys"))
        return {k: shared_array(name=f"{name}.{k}") for k in stored_keys}

    assert isinstance(var, dict)
    # current version only cache np.array
    array_keys = [k for k in var.keys() if isinstance(var[k], np.ndarray)]

    # Publish the key list first, then each array under "<name>.<key>".
    ShareableList(sequence=array_keys, name=name + ".keys")
    return {
        k: shared_array(name=f"{name}.{k}", var=var[k]) for k in array_keys
    }