# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import math import os from typing import Callable, List, Optional, Tuple import numpy as np # deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None # ds4s_is_installed = ( # deepspeed_is_installed # and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None # ) deepspeed_is_installed = False ds4s_is_installed = False if deepspeed_is_installed: import deepspeed if ds4s_is_installed: from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention fa_is_installed = importlib.util.find_spec("flash_attn") is not None if fa_is_installed: from flash_attn.bert_padding import unpad_input from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func fastln_is_installed = os.getenv("LAYERNORM_TYPE", None) == "fast_layernorm" if fastln_is_installed: # LayerNorm is a time bottomneck, so we use a custom implementation. from protenix.model.layer_norm.layer_norm import FusedLayerNorm import torch import torch.nn as nn from scipy.stats import truncnorm from protenix.openfold_local.utils.checkpointing import get_checkpoint_fn from protenix.openfold_local.utils.precision_utils import is_fp16_enabled from protenix.openfold_local.utils.tensor_utils import ( flatten_final_dims, permute_final_dims, ) DEFAULT_LMA_Q_CHUNK_SIZE = 1024 DEFAULT_LMA_KV_CHUNK_SIZE = 4096 def _prod(nums): out = 1 for n in nums: out = out * n return out def _calculate_fan(linear_weight_shape, fan="fan_in"): fan_out, fan_in = linear_weight_shape if fan == "fan_in": f = fan_in elif fan == "fan_out": f = fan_out elif fan == "fan_avg": f = (fan_in + fan_out) / 2 else: raise ValueError("Invalid fan option") return f def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): shape = weights.shape f = _calculate_fan(shape, fan) scale = scale / max(1, f) a = -2 b = 2 std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) size = _prod(shape) samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) samples = np.reshape(samples, shape) with torch.no_grad(): weights.copy_(torch.tensor(samples, device=weights.device)) def lecun_normal_init_(weights): trunc_normal_init_(weights, scale=1.0) def he_normal_init_(weights): trunc_normal_init_(weights, scale=2.0) def glorot_uniform_init_(weights): nn.init.xavier_uniform_(weights, gain=1) def final_init_(weights): with torch.no_grad(): weights.fill_(0.0) def gating_init_(weights): with torch.no_grad(): weights.fill_(0.0) def normal_init_(weights): torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") def ipa_point_weights_init_(weights): with torch.no_grad(): softplus_inverse_1 = 0.541324854612918 weights.fill_(softplus_inverse_1) class Linear(nn.Linear): """ A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. Implements the initializers in 1.11.4, plus some additional ones found in the code. """ def __init__( self, in_dim: int, out_dim: int, bias: bool = True, init: str = "default", init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, precision=None, ): """ Args: in_dim: The final dimension of inputs to the layer out_dim: The final dimension of layer outputs bias: Whether to learn an additive bias. True by default init: The initializer to use. Choose from: "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal": Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0 Overridden by init_fn if the latter is not None. init_fn: A custom initializer taking weight and bias as inputs. Overrides init if not None. """ super(Linear, self).__init__(in_dim, out_dim, bias=bias) if bias: with torch.no_grad(): self.bias.fill_(0) with torch.no_grad(): if init_fn is not None: init_fn(self.weight, self.bias) else: if init == "default": lecun_normal_init_(self.weight) elif init == "relu": he_normal_init_(self.weight) elif init == "glorot": glorot_uniform_init_(self.weight) elif init == "gating": gating_init_(self.weight) if bias: self.bias.fill_(1.0) elif init == "normal": normal_init_(self.weight) elif init == "final": final_init_(self.weight) else: raise ValueError("Invalid init string.") self.precision = precision def forward(self, input: torch.Tensor) -> torch.Tensor: d = input.dtype deepspeed_is_initialized = ( deepspeed_is_installed and deepspeed.comm.comm.is_initialized() ) if self.precision is not None: with torch.cuda.amp.autocast(enabled=False): bias = ( self.bias.to(dtype=self.precision) if self.bias is not None else None ) return nn.functional.linear( input.to(dtype=self.precision), self.weight.to(dtype=self.precision), bias, ).to(dtype=d) if d is torch.bfloat16 and not deepspeed_is_initialized: with torch.cuda.amp.autocast(enabled=False): bias = self.bias.to(dtype=d) if self.bias is not None else None return nn.functional.linear(input, self.weight.to(dtype=d), bias) return nn.functional.linear(input, self.weight, self.bias) class OpenFoldLayerNorm(nn.Module): def __init__(self, c_in, eps=1e-5): super(OpenFoldLayerNorm, self).__init__() self.c_in = (c_in,) self.eps = eps self.weight = nn.Parameter(torch.ones(c_in)) self.bias = nn.Parameter(torch.zeros(c_in)) def forward(self, x): d = x.dtype deepspeed_is_initialized = ( deepspeed_is_installed and deepspeed.comm.comm.is_initialized() ) if d is torch.bfloat16 and not deepspeed_is_initialized: with torch.cuda.amp.autocast(enabled=False): out = nn.functional.layer_norm( x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps, ) else: out = nn.functional.layer_norm( x, self.c_in, self.weight, self.bias, self.eps, ) return out # Keep the function name for code simplicity def LayerNorm(c_in, eps: float = 1e-5): # if specify "fast_layernorm" and fastln_is_installed, use the FusedLayerNorm, # Otherwise, OpenFoldLayerNorm is used! if fastln_is_installed: # print("use fast layernorm") return FusedLayerNorm(c_in, eps) # print("use openfold layernorm") return OpenFoldLayerNorm(c_in, eps) @torch.jit.ignore def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: """ Softmax, but without automatic casting to fp32 when the input is of type bfloat16 """ d = t.dtype deepspeed_is_initialized = ( deepspeed_is_installed and deepspeed.comm.comm.is_initialized() ) if d is torch.bfloat16 and not deepspeed_is_initialized: with torch.cuda.amp.autocast(enabled=False): s = torch.nn.functional.softmax(t, dim=dim) else: s = torch.nn.functional.softmax(t, dim=dim) return s # @torch.jit.script def _attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor], ) -> torch.Tensor: # [*, H, C_hidden, K] key = permute_final_dims(key, (1, 0)) # [*, H, Q, K] a = torch.matmul(query, key) for b in biases: a += b a = softmax_no_cast(a, -1) # [*, H, Q, C_hidden] a = torch.matmul(a, value) return a @torch.jit.ignore def _attention_chunked_trainable( query, key, value, biases, chunk_size, chunk_dim, checkpoint, ): if checkpoint and len(biases) > 2: raise ValueError("Checkpointed version permits only permits two bias terms") def _checkpointable_attention(q, k, v, b1, b2): bs = [b for b in [b1, b2] if b is not None] a = _attention(q, k, v, bs) return a o_chunks = [] checkpoint_fn = get_checkpoint_fn() count = query.shape[chunk_dim] for start in range(0, count, chunk_size): end = start + chunk_size idx = [slice(None)] * len(query.shape) idx[chunk_dim] = slice(start, end) idx_tup = tuple(idx) q_chunk = query[idx_tup] k_chunk = key[idx_tup] v_chunk = value[idx_tup] def _slice_bias(b): idx[chunk_dim] = ( slice(start, end) if b.shape[chunk_dim] != 1 else slice(None) ) return b[tuple(idx)] if checkpoint: bias_1_chunk, bias_2_chunk = [ _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2] ] o_chunk = checkpoint_fn( _checkpointable_attention, q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk, ) else: bias_chunks = [_slice_bias(b) for b in biases] o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) o_chunk = o_chunk.transpose(-2, -3) o_chunks.append(o_chunk) o = torch.cat(o_chunks, dim=chunk_dim) return o class Attention(nn.Module): """ Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors. """ def __init__( self, c_q: int, c_k: int, c_v: int, c_hidden: int, no_heads: int, gating: bool = True, ): """ Args: c_q: Input dimension of query data c_k: Input dimension of key data c_v: Input dimension of value data c_hidden: Per-head hidden dimension no_heads: Number of attention heads gating: Whether the output should be gated using query data """ super(Attention, self).__init__() self.c_q = c_q self.c_k = c_k self.c_v = c_v self.c_hidden = c_hidden self.no_heads = no_heads self.gating = gating # DISCREPANCY: c_hidden is not the per-head channel dimension, as # stated in the supplement, but the overall channel dimension. self.linear_q = Linear( self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot" ) self.linear_k = Linear( self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot" ) self.linear_v = Linear( self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot" ) self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") self.linear_g = None if self.gating: self.linear_g = Linear( self.c_q, self.c_hidden * self.no_heads, init="gating" ) self.sigmoid = nn.Sigmoid() def _prep_qkv( self, q_x: torch.Tensor, kv_x: torch.Tensor, apply_scale: bool = True ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # [*, Q/K/V, H * C_hidden] q = self.linear_q(q_x) k = self.linear_k(kv_x) v = self.linear_v(kv_x) # [*, Q/K, H, C_hidden] q = q.view(q.shape[:-1] + (self.no_heads, -1)) k = k.view(k.shape[:-1] + (self.no_heads, -1)) v = v.view(v.shape[:-1] + (self.no_heads, -1)) # [*, H, Q/K, C_hidden] q = q.transpose(-2, -3) k = k.transpose(-2, -3) v = v.transpose(-2, -3) if apply_scale: q /= math.sqrt(self.c_hidden) return q, k, v def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: if self.linear_g is not None: g = self.sigmoid(self.linear_g(q_x)) # [*, Q, H, C_hidden] g = g.view(g.shape[:-1] + (self.no_heads, -1)) o = o * g # [*, Q, H * C_hidden] o = flatten_final_dims(o, 2) # [*, Q, C_q] o = self.linear_o(o) return o def forward( self, q_x: torch.Tensor, kv_x: torch.Tensor, biases: Optional[List[torch.Tensor]] = None, use_memory_efficient_kernel: bool = False, use_deepspeed_evo_attention: bool = False, use_lma: bool = False, lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE, lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE, use_flash: bool = False, flash_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: q_x: [*, Q, C_q] query data kv_x: [*, K, C_k] key data biases: List of biases that broadcast to [*, H, Q, K] use_memory_efficient_kernel: Whether to use a custom memory-efficient attention kernel. This should be the default choice for most. If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead use_deepspeed_evo_attention: Whether to use DeepSpeed memory-efficient attention kernel. If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead use_lma: Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead lma_q_chunk_size: Query chunk size (for LMA) lma_kv_chunk_size: Key/Value chunk size (for LMA) Returns [*, Q, C_q] attention update """ if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None): raise ValueError( "If use_lma is specified, lma_q_chunk_size and " "lma_kv_chunk_size must be provided" ) if use_flash and biases is not None: raise ValueError( "use_flash is incompatible with the bias option. For masking, " "use flash_mask instead" ) attn_options = [ use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash, ] if sum(attn_options) > 1: raise ValueError("Choose at most one alternative attention algorithm") if biases is None: biases = [] # DeepSpeed attention kernel applies scaling internally q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=not use_deepspeed_evo_attention) if is_fp16_enabled(): use_memory_efficient_kernel = False if use_memory_efficient_kernel: raise Exception(f"use_memory_efficient_kernel=True not supported!!!") if len(biases) > 2: raise ValueError( "If use_memory_efficient_kernel is True, you may only " "provide up to two bias terms" ) o = attention_core(q, k, v, *((biases + [None] * 2)[:2])) o = o.transpose(-2, -3) elif use_deepspeed_evo_attention: if len(biases) > 2: raise ValueError( "If use_deepspeed_evo_attention is True, you may only " "provide up to two bias terms" ) o = _deepspeed_evo_attn(q, k, v, biases) elif use_lma: biases = [ b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases ] o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size) o = o.transpose(-2, -3) elif use_flash: o = _flash_attn(q, k, v, flash_mask) else: o = _attention(q, k, v, biases) o = o.transpose(-2, -3) o = self._wrap_up(o, q_x) return o class GlobalAttention(nn.Module): def __init__(self, c_in, c_hidden, no_heads, inf, eps): super(GlobalAttention, self).__init__() self.c_in = c_in self.c_hidden = c_hidden self.no_heads = no_heads self.inf = inf self.eps = eps self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") self.linear_k = Linear( c_in, c_hidden, bias=False, init="glorot", ) self.linear_v = Linear( c_in, c_hidden, bias=False, init="glorot", ) self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") self.sigmoid = nn.Sigmoid() def forward( self, m: torch.Tensor, mask: torch.Tensor, use_lma: bool = False, ) -> torch.Tensor: # [*, N_res, C_in] q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / ( torch.sum(mask, dim=-1)[..., None] + self.eps ) # [*, N_res, H * C_hidden] q = self.linear_q(q) q *= self.c_hidden ** (-0.5) # [*, N_res, H, C_hidden] q = q.view(q.shape[:-1] + (self.no_heads, -1)) # [*, N_res, N_seq, C_hidden] k = self.linear_k(m) v = self.linear_v(m) bias = (self.inf * (mask - 1))[..., :, None, :] if not use_lma: # [*, N_res, H, N_seq] a = torch.matmul( q, k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] ) a += bias a = softmax_no_cast(a) # [*, N_res, H, C_hidden] o = torch.matmul( a, v, ) else: o = _lma( q, k, v, [bias], DEFAULT_LMA_Q_CHUNK_SIZE, DEFAULT_LMA_KV_CHUNK_SIZE ) # [*, N_res, N_seq, C_hidden] g = self.sigmoid(self.linear_g(m)) # [*, N_res, N_seq, H, C_hidden] g = g.view(g.shape[:-1] + (self.no_heads, -1)) # [*, N_res, N_seq, H, C_hidden] o = o.unsqueeze(-3) * g # [*, N_res, N_seq, H * C_hidden] o = o.reshape(o.shape[:-2] + (-1,)) # [*, N_res, N_seq, C_in] m = self.linear_o(o) return m @torch.jit.ignore def _deepspeed_evo_attn( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, biases: List[torch.Tensor], ): """ "" Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel. Args: q: [*, H, Q, C_hidden] query data k: [*, H, K, C_hidden] key data v: [*, H, V, C_hidden] value data biases: List of biases that broadcast to [*, H, Q, K] """ if not ds4s_is_installed: raise ValueError( "_deepspeed_evo_attn requires that DeepSpeed be installed " "and that the deepspeed.ops.deepspeed4science package exists" ) def reshape_dims(x): no_batch_dims = len(x.shape[:-3]) if no_batch_dims < 2: return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape)) if no_batch_dims > 2: return x.reshape(*((x.shape[0], -1) + x.shape[-3:])) return x # [*, Q/K, H, C_hidden] q = q.transpose(-2, -3) k = k.transpose(-2, -3) v = v.transpose(-2, -3) # Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden] # for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed. orig_shape = q.shape if len(orig_shape[:-3]) != 2: q = reshape_dims(q) k = reshape_dims(k) v = reshape_dims(v) biases = [reshape_dims(b) for b in biases] # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16 # Cast to bf16 so kernel can be used during inference orig_dtype = q.dtype if orig_dtype not in [torch.bfloat16, torch.float16]: o = DS4Sci_EvoformerAttention( q.to(dtype=torch.bfloat16), k.to(dtype=torch.bfloat16), v.to(dtype=torch.bfloat16), [b.to(dtype=torch.bfloat16) for b in biases], ) o = o.to(dtype=orig_dtype) else: o = DS4Sci_EvoformerAttention(q, k, v, biases) o = o.reshape(orig_shape) return o def _lma( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, biases: List[torch.Tensor], q_chunk_size: int, kv_chunk_size: int, ): no_q, no_kv = q.shape[-2], k.shape[-2] # [*, H, Q, C_hidden] o = q.new_zeros(q.shape) for q_s in range(0, no_q, q_chunk_size): q_chunk = q[..., q_s : q_s + q_chunk_size, :] large_bias_chunks = [b[..., q_s : q_s + q_chunk_size, :] for b in biases] maxes = [] weights = [] values = [] for kv_s in range(0, no_kv, kv_chunk_size): k_chunk = k[..., kv_s : kv_s + kv_chunk_size, :] v_chunk = v[..., kv_s : kv_s + kv_chunk_size, :] small_bias_chunks = [ b[..., kv_s : kv_s + kv_chunk_size] for b in large_bias_chunks ] a = torch.einsum( "...hqd,...hkd->...hqk", q_chunk, k_chunk, ) for b in small_bias_chunks: a += b max_a = torch.max(a, dim=-1, keepdim=True)[0] exp_a = torch.exp(a - max_a) exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a) maxes.append(max_a.detach().squeeze(-1)) weights.append(torch.sum(exp_a, dim=-1)) values.append(exp_v) chunk_max = torch.stack(maxes, dim=-3) chunk_weights = torch.stack(weights, dim=-3) chunk_values = torch.stack(values, dim=-4) global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] max_diffs = torch.exp(chunk_max - global_max) chunk_values = chunk_values * max_diffs.unsqueeze(-1) chunk_weights = chunk_weights * max_diffs all_values = torch.sum(chunk_values, dim=-4) all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) q_chunk_out = all_values / all_weights o[..., q_s : q_s + q_chunk_size, :] = q_chunk_out return o @torch.jit.ignore def _flash_attn(q, k, v, kv_mask): if not fa_is_installed: raise ValueError("_flash_attn requires that FlashAttention be installed") batch_dims = q.shape[:-3] no_heads, n, c = q.shape[-3:] dtype = q.dtype q = q.half() k = k.half() v = v.half() kv_mask = kv_mask.half() # [*, B, N, H, C] q = q.transpose(-2, -3) k = k.transpose(-2, -3) v = v.transpose(-2, -3) # [B_flat, N, H, C] q = q.reshape(-1, *q.shape[-3:]) k = k.reshape(-1, *k.shape[-3:]) v = v.reshape(-1, *v.shape[-3:]) # Flattened batch size batch_size = q.shape[0] # [B_flat * N, H, C] q = q.reshape(-1, *q.shape[-2:]) q_max_s = n q_cu_seqlens = torch.arange( 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device ) # [B_flat, N, 2, H, C] kv = torch.stack([k, v], dim=-3) kv_shape = kv.shape # [B_flat, N, 2 * H * C] kv = kv.reshape(*kv.shape[:-3], -1) kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) out = flash_attn_unpadded_kvpacked_func( q, kv_unpad, q_cu_seqlens, kv_cu_seqlens, q_max_s, kv_max_s, dropout_p=0.0, softmax_scale=1.0, # q has been scaled already ) # [*, B, N, H, C] out = out.reshape(*batch_dims, n, no_heads, c) out = out.to(dtype=dtype) return out