feat: sinkhorn in lse mode (#155)

src/dalle_mini/model/modeling.py (+18 -26)

@@ -187,9 +187,11 @@ def dot_product_attention_weights(
    dtype: Any = jnp.float32,
    precision: PrecisionLike = None,
    sinkhorn_iters: int = 1,
+   causal: bool = False,
):
    """
    Computes dot-product attention weights given query and key.
+   mask is included into the bias.

    Adapted from flax.linen.attention.dot_product_attention_weights"
    """

@@ -207,33 +209,22 @@ def dot_product_attention_weights(
    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias
-   # apply attention mask
-   if mask is not None:
-       big_neg = jnp.finfo(dtype).min
-       attn_weights = jnp.where(mask, attn_weights, big_neg)

    # normalize the attention weights
-   ... 14 removed lines (old lines 216-229) not shown in the source view ...
-               0.0,
-           )
-       else:
-           attn_weights = attn_weights / (
-               1e-5
-               + jax.lax.stop_gradient(jnp.sum(attn_weights, axis=axis, keepdims=True))
-           )
+   if causal or sinkhorn_iters == 1:
+       # sinkhorn does not work for causal (leaks info of future tokens into past)
+       attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+   else:
+       # adapted from https://github.com/lucidrains/sinkhorn-transformer
+       for i in range(sinkhorn_iters):
+           # when causal, some attn_weights have been set to -inf through bias
+           if i % 2 == 0:
+               attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
+           else:
+               attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
+           if mask is not None:
+               attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
+       attn_weights = jnp.exp(attn_weights).astype(dtype)

    # apply attention dropout
    if not deterministic and dropout_rate > 0.0:
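
The core of the change is the normalization loop: instead of repeatedly renormalizing probabilities with divisions, the logits are shifted by a logsumexp over alternating axes and only exponentiated once at the end. Below is a minimal standalone sketch of that idea; the helper name sinkhorn_lse_normalize and the toy shapes are illustrative, not part of the repo.

import jax
import jax.numpy as jnp


def sinkhorn_lse_normalize(logits, sinkhorn_iters=5, mask=None):
    # Alternate log-space normalizations over the last two axes:
    # subtracting logsumexp over an axis makes that axis sum to 1 after
    # exponentiation, and alternating axes drives the matrix toward a
    # doubly stochastic one (Sinkhorn-Knopp in the log domain).
    for i in range(sinkhorn_iters):
        axis = -1 if i % 2 == 0 else -2
        logits = logits - jax.nn.logsumexp(logits, axis=axis, keepdims=True)
        if mask is not None:
            # re-mask after every pass, as in the loop above
            logits = jnp.where(mask, logits, -jnp.inf)
    return jnp.exp(logits)


attn = sinkhorn_lse_normalize(
    jax.random.normal(jax.random.PRNGKey(0), (4, 4)), sinkhorn_iters=20
)
print(jnp.sum(attn, axis=-1))  # row sums: close to 1
print(jnp.sum(attn, axis=-2))  # column sums: 1, since the last pass normalized this axis

Staying in log space until the final jnp.exp also removes the need for the 1e-5 guard that the deleted division-based update relied on.
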
@@ -392,7 +383,7 @@ class FlaxBartAttention(FlaxBartAttention):
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-               jnp.full(attention_mask.shape,
+               jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
            )
        else:
            attention_bias = None
@@ -421,6 +412,7 @@ class FlaxBartAttention(FlaxBartAttention):
            dtype=self.dtype,
            precision=None,
            sinkhorn_iters=self.config.sinkhorn_iters,
+           causal=self.causal,
        )
        if self.config.use_cosine_attention:
            # divide by tau
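
As a sanity check on the fast path added in the normalization above: a single log-space pass over the last axis is exactly a softmax, which is what the causal or sinkhorn_iters == 1 branch falls back to. A quick check with toy logits:

import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(1), (2, 5))

# Subtracting logsumexp over axis -1 and exponentiating is softmax in disguise.
one_pass = jnp.exp(logits - jax.nn.logsumexp(logits, axis=-1, keepdims=True))
print(jnp.allclose(one_pass, jax.nn.softmax(logits, axis=-1)))  # True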