Spaces:
Runtime error
Runtime error
import os | |
import json | |
import numpy as np | |
import random | |
# Randomly sample a subset of prompts for benchmarking | |
def main(prompt_path, overwrite_inputs=False): | |
prompts = json.load(open(prompt_path, "r")) | |
# construct dimension_count map | |
dimension_count_map = {} | |
dimension_prompt_idx_map = {} | |
dimensions_count = 0 | |
for i in range(len(prompts)): | |
prompt = prompts[i] | |
dimensions = prompt["dimension"] | |
for dimension in dimensions: | |
if dimension not in dimension_prompt_idx_map: | |
dimension_prompt_idx_map[dimension] = [] | |
dimension_prompt_idx_map[dimension].append(i) | |
if dimension not in dimension_count_map: | |
dimension_count_map[dimension] = 0 | |
dimension_count_map[dimension] += 1 | |
dimensions_count += 1 | |
print( | |
"Dimensions count (each prompt can contribute to more than one dimension count):", | |
dimensions_count, | |
) | |
print(dimension_count_map) | |
target_prompts_count = 800 | |
# sample prompts based on the distribution of dimensions | |
sampled_prompts = list() | |
remaining_prompts = list() | |
dimension_probs = np.array(list(dimension_count_map.values())) / dimensions_count | |
dimensions = list(dimension_count_map.keys()) | |
sample_counts = np.random.multinomial(target_prompts_count, dimension_probs) | |
print(sample_counts) | |
for dimension, count in zip(dimensions, sample_counts): | |
sampled_prompts_idx = random.sample(dimension_prompt_idx_map[dimension], count) | |
for idx in range(len(prompts)): | |
if idx in sampled_prompts_idx: | |
sampled_prompts.append(prompts[idx]) | |
else: | |
remaining_prompts.append(prompts[idx]) | |
save_path = "./t2v_vbench_1000.json" | |
remaing_data_save_path = "./t2v_vbench_remain_1000.json" | |
if overwrite_inputs or not os.path.exists(save_path): | |
# if not os.path.exists(os.path.join(result_folder, experiment_name)): | |
# os.makedirs(os.path.join(result_folder, experiment_name)) | |
with open(save_path, "w") as f: | |
json.dump(sampled_prompts, f, indent=4) | |
with open(remaing_data_save_path, "w") as f: | |
json.dump(remaining_prompts, f, indent=4) | |
else: | |
print("Dataset already exists, skipping generation") | |
if __name__ == "__main__": | |
main(prompt_path="VBench_full_info.json") | |
# main(prompt_path="t2v_vbench_remain_200.json") |