# FunSR / tools / data_tools / create_lmdb_with_keys.py
# Author: KyanChen, commit 02c5426 ("add"), 1.83 kB
import glob
import os
import lmdb
import numpy as np
import pickle
import sys
import tqdm
import shutil
# Root images folder. Each file's parent sub-folder name is used as its label.
pre_path = r'H:\DataSet\SceneCls\UCMerced_LandUse\UCMerced_LandUse\Images'
# Every entry exactly two levels below pre_path, i.e. '<class_dir>/<file>'.
file_list = glob.glob(pre_path + '/*/*')
if not file_list:
    # Fail with a clear message instead of an IndexError on file_list[0] below.
    raise FileNotFoundError(f'no files matched {pre_path!r} + "/*/*"')
dataset_name = 'UCMerced'
# Per-file fields to store. Supported values: 'filename', 'img', 'gt_label'.
cache_keys = ['filename', 'gt_label']
# Output folder created next to the images folder: '<pre_path>/../<dataset_name>_lmdb'.
lmdb_path = os.path.abspath(pre_path + f'/../{dataset_name}_lmdb')
# To rebuild from scratch, uncomment to remove a previous output folder first:
# if os.path.exists(lmdb_path):
#     shutil.rmtree(lmdb_path)
os.makedirs(lmdb_path, exist_ok=True)
# Per-item size estimate from the first file. sys.getsizeof also counts the
# bytes-object header, so this slightly overestimates the raw file size.
with open(file_list[0], 'rb') as first_file:
    data_size_per_item = sys.getsizeof(first_file.read())
print(f'data size:{data_size_per_item}')
# map_size is an integer byte count: reserve room for 100,000 items of the
# estimated size. os.path.join keeps the path portable across platforms.
env = lmdb.open(
    os.path.join(lmdb_path, f'{os.path.basename(lmdb_path)}.lmdb'),
    map_size=data_size_per_item * 100_000,
)
# Write transaction that the loop below fills and commits in batches.
txn = env.begin(write=True)
# Write every file's fields into LMDB. Commits happen in batches so a single
# transaction does not span the whole dataset.
commit_interval = 5  # number of files written per committed transaction
keys_list = []
seen_keys = set()
for idx, file in enumerate(file_list):
    class_name = os.path.basename(os.path.dirname(file))  # parent folder name
    file_name = os.path.basename(file)
    # Key prefix: '<dataset_name>_<file name up to its first dot>'.
    key = f'{dataset_name}_{file_name.split(".")[0]}'
    if key in seen_keys:
        # Two files with the same key would silently overwrite each other's
        # entries and duplicate the key in keys_list; stop instead.
        raise ValueError(f'duplicate key {key!r} produced by {file!r}')
    seen_keys.add(key)
    keys_list.append(key)
    for cache_key in cache_keys:
        if cache_key == 'filename':
            # Path relative to pre_path: '<class_dir>/<file name>'.
            value = class_name + '/' + file_name
        elif cache_key == 'img':
            # Read the raw binary contents of the image file (not decoded).
            with open(file, 'rb') as f:
                value = f.read()
        elif cache_key == 'gt_label':
            # The label is the name of the folder containing the file.
            value = class_name
        else:
            # Without this, the previous iteration's value would be stored.
            raise ValueError(f'unsupported cache key: {cache_key!r}')
        # LMDB stores bytes. str values (names, labels) are UTF-8 encoded.
        if not isinstance(value, bytes):
            value = value.encode()
        txn.put(f'{key}_{cache_key}'.encode(), value)
    # Commit after every `commit_interval` files and open a fresh transaction.
    if (idx + 1) % commit_interval == 0:
        txn.commit()
        txn = env.begin(write=True)
txn.commit()  # flush the final, possibly partial, batch
env.close()
# Save all LMDB key prefixes, one per line, next to the images folder.
keys_list = np.array(keys_list)
# Use a context manager so the file is flushed and closed deterministically.
with open(pre_path + '/../keys_list.txt', 'w') as keys_file:
    np.savetxt(keys_file, keys_list, fmt='%s')
print('Finish writing!')