import os
import json
import numpy as np
import random

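def main(prompt_path, overwrite_inputs=False, seed=None):
    """Randomly sample a subset of prompts for benchmarking.

    The subset roughly follows the per-dimension distribution of the input
    prompt file. Pass ``seed`` to make the split reproducible.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Use a context manager so the file handle is closed after loading.
    with open(prompt_path, "r") as f:
        prompts = json.load(f)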

    # For every dimension, record how many prompts cover it and which ones.
    # A prompt tagged with several dimensions is counted once per dimension.
    dimension_count_map = {}
    dimension_prompt_idx_map = {}
    dimensions_count = 0
    for key, prompt in prompts.items():
        for dimension in prompt["dimension"]:
            dimension_prompt_idx_map.setdefault(dimension, []).append(key)
            dimension_count_map[dimension] = dimension_count_map.get(dimension, 0) + 1
            dimensions_count += 1

    print(
        "Dimensions count (each prompt can contribute to more than one dimension count):",
        dimensions_count,
    )
    print("Prompts per dimension:", dimension_count_map)

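    target_prompts_count = 800
    # Split the budget across dimensions in proportion to how often each
    # dimension occurs in the input prompt set.
    dimension_probs = np.array(list(dimension_count_map.values())) / dimensions_count
    dimensions = list(dimension_count_map.keys())
    sample_counts = np.random.multinomial(target_prompts_count, dimension_probs)
    print("Total allocated:", np.sum(sample_counts))
    print("Per-dimension allocation:", dict(zip(dimensions, sample_counts.tolist())))

    sampled_keys = set()
    for dimension, count in zip(dimensions, sample_counts):
        # Draw only from prompts that no earlier dimension has already taken, so a
        # prompt tagged with several dimensions is never selected twice.
        candidates = [k for k in dimension_prompt_idx_map[dimension] if k not in sampled_keys]
        if count > len(candidates):
            # random.sample raises ValueError when asked for more items than exist.
            print(
                f"Warning: dimension {dimension!r} has {len(candidates)} unused prompts "
                f"but {count} were allocated; taking all of them"
            )
            count = len(candidates)
        sampled_keys.update(random.sample(candidates, int(count)))

    # Keep the input file order in both splits. The splits are disjoint and
    # together cover every prompt exactly once.
    sampled_prompts = {k: v for k, v in prompts.items() if k in sampled_keys}
    remaining_prompts = {k: v for k, v in prompts.items() if k not in sampled_keys}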

    save_path = "./t2v_vbench_800.json"
    remaining_data_save_path = "./t2v_vbench_remain_1000.json"
    print("Sampled prompts:", len(sampled_prompts))
    if overwrite_inputs or not os.path.exists(save_path):
        with open(save_path, "w") as f:
            json.dump(sampled_prompts, f, indent=4)

        with open(remaining_data_save_path, "w") as f:
            json.dump(remaining_prompts, f, indent=4)
    else:
        print("Dataset already exists, skipping generation")

if __name__ == "__main__":
    # main(prompt_path="VBench_full_info.json")
    main(prompt_path="t2v_vbench_remain_200.json")