# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature pre-processing input pipeline for AlphaFold.""" | |
from alphafold.model.tf import data_transforms | |
from alphafold.model.tf import shape_placeholders | |
import tensorflow.compat.v1 as tf | |
import tree | |

# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter

NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
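# NOTE (added): the NUM_* names above are symbolic placeholders (defined in
# shape_placeholders) for the dynamic feature dimensions: residue count, MSA
# rows, extra MSA rows and template count.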


def nonensembled_map_fns(data_config):
  """Input pipeline functions which are not ensembled."""
  common_cfg = data_config.common

  map_fns = [
      data_transforms.correct_msa_restypes,
      data_transforms.add_distillation_flag(False),
      data_transforms.cast_64bit_ints,
      data_transforms.squeeze_features,
      # Keep to not disrupt RNG.
      data_transforms.randomly_replace_msa_with_unknown(0.0),
      data_transforms.make_seq_mask,
      data_transforms.make_msa_mask,
      # Compute the HHblits profile if it's not set. This has to be run before
      # sampling the MSA.
      data_transforms.make_hhblits_profile,
      data_transforms.make_random_crop_to_size_seed,
  ]
  if common_cfg.use_templates:
    map_fns.extend([
        data_transforms.fix_templates_aatype,
        data_transforms.make_template_mask,
        data_transforms.make_pseudo_beta('template_')
    ])
  map_fns.extend([
      data_transforms.make_atom14_masks,
  ])

  return map_fns
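

# Note (added for clarity): the non-ensembled transforms above run exactly once
# per example, while the ensembled transforms below are re-applied for each
# ensemble iteration (and, when resample_msa_in_recycling is set, for each
# recycling step) by process_tensors_from_config.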
def ensembled_map_fns(data_config):
  """Input pipeline functions that can be ensembled and averaged."""
  common_cfg = data_config.common
  eval_cfg = data_config.eval

  map_fns = []

  if common_cfg.reduce_msa_clusters_by_max_templates:
    pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
  else:
    pad_msa_clusters = eval_cfg.max_msa_clusters

  max_msa_clusters = pad_msa_clusters
  max_extra_msa = common_cfg.max_extra_msa

  map_fns.append(
      data_transforms.sample_msa(
          max_msa_clusters,
          keep_extra=True))

  if 'masked_msa' in common_cfg:
    # Masked MSA should come *before* MSA clustering so that
    # the clustering and full MSA profile do not leak information about
    # the masked locations and secret corrupted locations.
    map_fns.append(
        data_transforms.make_masked_msa(common_cfg.masked_msa,
                                        eval_cfg.masked_msa_replace_fraction))

  if common_cfg.msa_cluster_features:
    map_fns.append(data_transforms.nearest_neighbor_clusters())
    map_fns.append(data_transforms.summarize_clusters())

  # Crop after creating the cluster profiles.
  if max_extra_msa:
    map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
  else:
    map_fns.append(data_transforms.delete_extra_msa)

  map_fns.append(data_transforms.make_msa_feat())

  crop_feats = dict(eval_cfg.feat)

  if eval_cfg.fixed_size:
    map_fns.append(data_transforms.select_feat(list(crop_feats)))
    map_fns.append(data_transforms.random_crop_to_size(
        eval_cfg.crop_size,
        eval_cfg.max_templates,
        crop_feats,
        eval_cfg.subsample_templates))
    map_fns.append(data_transforms.make_fixed_size(
        crop_feats,
        pad_msa_clusters,
        common_cfg.max_extra_msa,
        eval_cfg.crop_size,
        eval_cfg.max_templates))
  else:
    map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))

  return map_fns


def process_tensors_from_config(tensors, data_config):
  """Apply filters and maps to an existing dataset, based on the config."""

  def wrap_ensemble_fn(data, i):
    """Function to be mapped over the ensemble dimension."""
    d = data.copy()
    fns = ensembled_map_fns(data_config)
    fn = compose(fns)
    d['ensemble_index'] = i
    return fn(d)

  eval_cfg = data_config.eval
  tensors = compose(nonensembled_map_fns(data_config))(tensors)

  tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0))
  num_ensemble = eval_cfg.num_ensemble
  if data_config.common.resample_msa_in_recycling:
    # Separate batch per ensembling & recycling step.
    num_ensemble *= data_config.common.num_recycle + 1

  if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
    fn_output_signature = tree.map_structure(
        tf.TensorSpec.from_tensor, tensors_0)
    tensors = tf.map_fn(
        lambda x: wrap_ensemble_fn(tensors, x),
        tf.range(num_ensemble),
        parallel_iterations=1,
        fn_output_signature=fn_output_signature)
  else:
    tensors = tree.map_structure(lambda x: x[None],
                                 tensors_0)
  return tensors


@data_transforms.curry1
def compose(x, fs):
  for f in fs:
    x = f(x)
  return x
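

# Illustration and usage sketch (added; not part of the original module).
# With the curry1 decorator, `compose(fs)` binds the list of transforms and
# returns a single callable that threads a feature dict through them, e.g.:
#
#   double_then_inc = compose([lambda v: v * 2, lambda v: v + 1])
#   double_then_inc(3)  # == 7
#
# In the AlphaFold code base this pipeline is typically driven from
# alphafold/model/features.py, roughly along these lines (exact names and call
# sites may differ between versions):
#
#   tensor_dict = proteins_dataset.np_to_tensor_dict(
#       np_example=np_example, features=feature_names)
#   processed = process_tensors_from_config(tensor_dict, data_config)
#
# where `np_example` is the raw NumPy feature dict produced by the data
# pipeline, `feature_names` lists the features requested by the model config,
# and `data_config` is the model's data config.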