Spaces:
Runtime error
Runtime error
File size: 2,640 Bytes
dea6c7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import os
import json
import numpy as np
import random
# Randomly sample a subset of prompts for benchmarking
def main(
    prompt_path,
    overwrite_inputs=False,
    target_prompts_count=800,
    save_path="./t2v_vbench_800.json",
    remaining_save_path="./t2v_vbench_remain_1000.json",
):
    """Split a prompt file into a dimension-weighted sample and the remainder.

    The input JSON must be an object mapping a prompt key to a record that
    has a ``"dimension"`` entry: an iterable of dimension names. The
    ``target_prompts_count`` draws are distributed over the dimensions with a
    multinomial whose probabilities are proportional to how many prompts are
    tagged with each dimension. For every dimension that many prompts are then
    picked at random from the ones not already sampled.

    Args:
        prompt_path: Path of the JSON file holding all prompts.
        overwrite_inputs: When true, rewrite the output files even if
            ``save_path`` already exists.
        target_prompts_count: Number of prompts to sample. If fewer prompts
            are available, every available prompt is sampled.
        save_path: Destination of the sampled prompts.
        remaining_save_path: Destination of the prompts that were not sampled.

    Returns:
        None. The results are written as JSON to ``save_path`` and
        ``remaining_save_path``.
    """
    # A context manager closes the input file instead of leaking the handle.
    with open(prompt_path, "r") as f:
        prompts = json.load(f)

    # construct dimension_count map
    dimension_prompt_idx_map = _group_keys_by_dimension(prompts)
    dimension_count_map = {
        dimension: len(keys) for dimension, keys in dimension_prompt_idx_map.items()
    }
    dimensions_count = sum(dimension_count_map.values())
    print(
        "Dimensions count (each prompt can contribute to more than one dimension count):",
        dimensions_count,
    )
    print(dimension_count_map)

    # sample prompts based on the distribution of dimensions
    if dimensions_count:
        dimension_probs = (
            np.array(list(dimension_count_map.values())) / dimensions_count
        )
        sample_counts = np.random.multinomial(target_prompts_count, dimension_probs)
    else:
        # No prompt carries a dimension, so there is nothing to distribute
        # (and the probabilities would be 0 / 0).
        sample_counts = np.zeros(0, dtype=int)
    print(np.sum(sample_counts))
    print(sample_counts)

    sampled_prompts = _sample_by_dimension(
        prompts, dimension_prompt_idx_map, sample_counts
    )
    # Build the remainder only after all sampling is done, so the two outputs
    # are disjoint and together cover every input prompt.
    remaining_prompts = {
        key: prompt for key, prompt in prompts.items() if key not in sampled_prompts
    }
    print(len(sampled_prompts))

    if overwrite_inputs or not os.path.exists(save_path):
        with open(save_path, "w") as f:
            json.dump(sampled_prompts, f, indent=4)
        with open(remaining_save_path, "w") as f:
            json.dump(remaining_prompts, f, indent=4)
    else:
        print("Dataset already exists, skipping generation")


def _group_keys_by_dimension(prompts):
    """Map each dimension name to the keys of the prompts tagged with it."""
    keys_by_dimension = {}
    for key, prompt in prompts.items():
        for dimension in prompt["dimension"]:
            keys_by_dimension.setdefault(dimension, []).append(key)
    return keys_by_dimension


def _sample_by_dimension(prompts, keys_by_dimension, sample_counts):
    """Pick up to ``count`` not-yet-sampled prompts for every dimension."""
    sampled = {}
    for dimension_keys, count in zip(keys_by_dimension.values(), sample_counts):
        # Only prompts that no earlier dimension has claimed are eligible;
        # dict.fromkeys drops repeated keys while keeping their order.
        candidates = [
            key for key in dict.fromkeys(dimension_keys) if key not in sampled
        ]
        # Clamp the quota: random.sample raises ValueError when asked for
        # more items than the population holds.
        quota = min(int(count), len(candidates))
        chosen = set(random.sample(candidates, quota))
        # Insert in candidate order so the output follows the input ordering.
        for key in candidates:
            if key in chosen:
                sampled[key] = prompts[key]
    return sampled
if __name__ == "__main__":
    # Script entry point: run the split on the prompt file named below.
    # Alternative input kept for reference:
    # main(prompt_path="VBench_full_info.json")
    main(prompt_path="t2v_vbench_remain_200.json")