KublaiKhan1 commited on
Commit
b542e1d
·
verified ·
1 Parent(s): 78bad1e

Delete dt0_1/targets_shortcut.py

Browse files
Files changed (1) hide show
  1. dt0_1/targets_shortcut.py +0 -136
dt0_1/targets_shortcut.py DELETED
@@ -1,136 +0,0 @@
1
- import jax
2
- import jax.numpy as jnp
3
- import numpy as np
4
-
5
- def get_targets(FLAGS, key, train_state, images, labels, force_t=-1, force_dt=-1):
6
- label_key, time_key, noise_key = jax.random.split(key, 3)
7
- info = {}
8
-
9
- #Convert dt_base to 0-1
10
- #Make everything be continuous
11
- #So if everything is continuous, then we don't sample dt_base like this
12
- #We just allow us to sample anywhere
13
- #And we say that two small steps = 1 big step
14
- #But the biggest step, full = 1.0
15
- #So two small steps to equal one big in log 2...???
16
-
17
-
18
- # 1) =========== Sample dt. ============
19
- bootstrap_batchsize = FLAGS.batch_size // FLAGS.model['bootstrap_every']
20
- log2_sections = np.log2(FLAGS.model['denoise_timesteps']).astype(np.int32)
21
- if FLAGS.model['bootstrap_dt_bias'] == 0:
22
- dt_base = jnp.repeat(log2_sections - 1 - jnp.arange(log2_sections), bootstrap_batchsize // log2_sections)
23
- dt_base = jnp.concatenate([dt_base, jnp.zeros(bootstrap_batchsize-dt_base.shape[0],)])
24
- num_dt_cfg = bootstrap_batchsize // log2_sections
25
- else:
26
- dt_base = jnp.repeat(log2_sections - 1 - jnp.arange(log2_sections-2), (bootstrap_batchsize // 2) // log2_sections)
27
- dt_base = jnp.concatenate([dt_base, jnp.ones(bootstrap_batchsize // 4), jnp.zeros(bootstrap_batchsize // 4)])
28
- dt_base = jnp.concatenate([dt_base, jnp.zeros(bootstrap_batchsize-dt_base.shape[0],)])
29
- num_dt_cfg = (bootstrap_batchsize // 2) // log2_sections
30
- force_dt_vec = jnp.ones(bootstrap_batchsize, dtype=jnp.float32) * force_dt
31
- dt_base = jnp.where(force_dt_vec != -1, force_dt_vec, dt_base)
32
-
33
- #Continuous time is easy
34
- #And then just divide by 7 as needed for 0-1 log space.
35
- #I guess we can also just have a special embedding for maximum or something
36
- if False:
37
- #dt_base = jnp.randint(0,7)#7 because exclusive.
38
- dt_base = jax.random.uniform(0,1)*6
39
- dt_base = dt_base / 7#First step
40
-
41
- dt = 1 / (2 ** (dt_base)) # [1, 1/2, 1/4, 1/8, 1/16, 1/32]
42
- dt_base_bootstrap = dt_base + 1
43
- dt_bootstrap = dt / 2
44
-
45
- # 2) =========== Sample t. ============
46
- dt_sections = jnp.power(2, dt_base) # [1, 2, 4, 8, 16, 32]
47
- t = jax.random.randint(time_key, (bootstrap_batchsize,), minval=0, maxval=dt_sections).astype(jnp.float32)
48
- t = t / dt_sections # Between 0 and 1.
49
- force_t_vec = jnp.ones(bootstrap_batchsize, dtype=jnp.float32) * force_t
50
- t = jnp.where(force_t_vec != -1, force_t_vec, t)
51
- t_full = t[:, None, None, None]
52
-
53
- # 3) =========== Generate Bootstrap Targets ============
54
- x_1 = images[:bootstrap_batchsize]
55
- x_0 = jax.random.normal(noise_key, x_1.shape)
56
- x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
57
- bst_labels = labels[:bootstrap_batchsize]
58
- call_model_fn = train_state.call_model if FLAGS.model['bootstrap_ema'] == 0 else train_state.call_model_ema
59
- if not FLAGS.model['bootstrap_cfg']:
60
- #We should just have dt_base /= 7
61
- v_b1 = call_model_fn(x_t, t, dt_base_bootstrap, bst_labels, train=False)
62
- t2 = t + dt_bootstrap
63
- x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1
64
- x_t2 = jnp.clip(x_t2, -4, 4)
65
- v_b2 = call_model_fn(x_t2, t2, dt_base_bootstrap, bst_labels, train=False)
66
- v_target = (v_b1 + v_b2) / 2
67
- else:
68
- x_t_extra = jnp.concatenate([x_t, x_t[:num_dt_cfg]], axis=0)
69
- t_extra = jnp.concatenate([t, t[:num_dt_cfg]], axis=0)
70
- dt_base_extra = jnp.concatenate([dt_base_bootstrap, dt_base_bootstrap[:num_dt_cfg]], axis=0)
71
- labels_extra = jnp.concatenate([bst_labels, jnp.ones(num_dt_cfg, dtype=jnp.int32) * FLAGS.model['num_classes']], axis=0)
72
- #step 1
73
- dt_base_extra = dt_base_extra / 7.0
74
- v_b1_raw = call_model_fn(x_t_extra, t_extra, dt_base_extra, labels_extra, train=False)
75
- v_b_cond = v_b1_raw[:x_1.shape[0]]
76
- v_b_uncond = v_b1_raw[x_1.shape[0]:]
77
- v_cfg = v_b_uncond + FLAGS.model['cfg_scale'] * (v_b_cond[:num_dt_cfg] - v_b_uncond)
78
- v_b1 = jnp.concatenate([v_cfg, v_b_cond[num_dt_cfg:]], axis=0)
79
-
80
- t2 = t + dt_bootstrap
81
- x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1
82
- x_t2 = jnp.clip(x_t2, -4, 4)
83
- x_t2_extra = jnp.concatenate([x_t2, x_t2[:num_dt_cfg]], axis=0)
84
- t2_extra = jnp.concatenate([t2, t2[:num_dt_cfg]], axis=0)
85
- v_b2_raw = call_model_fn(x_t2_extra, t2_extra, dt_base_extra, labels_extra, train=False)
86
- v_b2_cond = v_b2_raw[:x_1.shape[0]]
87
- v_b2_uncond = v_b2_raw[x_1.shape[0]:]
88
- v_b2_cfg = v_b2_uncond + FLAGS.model['cfg_scale'] * (v_b2_cond[:num_dt_cfg] - v_b2_uncond)
89
- v_b2 = jnp.concatenate([v_b2_cfg, v_b2_cond[num_dt_cfg:]], axis=0)
90
- v_target = (v_b1 + v_b2) / 2
91
-
92
- v_target = jnp.clip(v_target, -4, 4)
93
- bst_v = v_target
94
- bst_dt = dt_base
95
- bst_t = t
96
- bst_xt = x_t
97
- bst_l = bst_labels
98
-
99
- # 4) =========== Generate Flow-Matching Targets ============
100
-
101
- labels_dropout = jax.random.bernoulli(label_key, FLAGS.model['class_dropout_prob'], (labels.shape[0],))
102
- labels_dropped = jnp.where(labels_dropout, FLAGS.model['num_classes'], labels)
103
- info['dropped_ratio'] = jnp.mean(labels_dropped == FLAGS.model['num_classes'])
104
-
105
- # Sample t.
106
- t = jax.random.randint(time_key, (images.shape[0],), minval=0, maxval=FLAGS.model['denoise_timesteps']).astype(jnp.float32)
107
- t /= FLAGS.model['denoise_timesteps']
108
- force_t_vec = jnp.ones(images.shape[0], dtype=jnp.float32) * force_t
109
- t = jnp.where(force_t_vec != -1, force_t_vec, t) # If force_t is not -1, then use force_t.
110
- t_full = t[:, None, None, None] # [batch, 1, 1, 1]
111
-
112
- # Sample flow pairs x_t, v_t.
113
- x_0 = jax.random.normal(noise_key, images.shape)
114
- x_1 = images
115
- x_t = x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
116
- v_t = v_t = x_1 - (1 - 1e-5) * x_0
117
- dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
118
- dt_base = jnp.ones(images.shape[0], dtype=jnp.int32) * dt_flow
119
-
120
- # ==== 5) Merge Flow+Bootstrap ====
121
- bst_size = FLAGS.batch_size // FLAGS.model['bootstrap_every']
122
- bst_size_data = FLAGS.batch_size - bst_size
123
- x_t = jnp.concatenate([bst_xt, x_t[:bst_size_data]], axis=0)
124
- t = jnp.concatenate([bst_t, t[:bst_size_data]], axis=0)
125
- dt_base = jnp.concatenate([bst_dt, dt_base[:bst_size_data]], axis=0)
126
- v_t = jnp.concatenate([bst_v, v_t[:bst_size_data]], axis=0)
127
- labels_dropped = jnp.concatenate([bst_l, labels_dropped[:bst_size_data]], axis=0)
128
- info['bootstrap_ratio'] = jnp.mean(dt_base != dt_flow)
129
-
130
- info['v_magnitude_bootstrap'] = jnp.sqrt(jnp.mean(jnp.square(bst_v)))
131
- info['v_magnitude_b1'] = jnp.sqrt(jnp.mean(jnp.square(v_b1)))
132
- info['v_magnitude_b2'] = jnp.sqrt(jnp.mean(jnp.square(v_b2)))
133
-
134
- dt_base = dt_base / 7.0
135
- #print("dt base", dt_base)
136
- return x_t, v_t, t, dt_base, labels_dropped, info