File size: 6,344 Bytes
85bd48b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Datasets consisting of proteins."""
from typing import Dict, Mapping, Optional, Sequence
from alphafold.model.tf import protein_features
import numpy as np
import tensorflow.compat.v1 as tf

# Alias for a dictionary that maps string feature names to TensorFlow tensors.
TensorDict = Dict[str, tf.Tensor]


def parse_tfexample(
    raw_data: bytes,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Read a single TF Example proto and return a subset of its features.

  Args:
    raw_data: A serialized tf.Example proto.
    features: A dictionary of features, mapping string feature names to a tuple
      (dtype, shape). This dictionary should be a subset of
      protein_features.FEATURES (or the dictionary itself for all features).
    key: Optional string with the SSTable key of that tf.Example. This will be
      added into features as a 'key' but only if requested in features.

  Returns:
    A dictionary mapping feature names to tensors, as reshaped by
    parse_reshape_logic. Only the given features are returned, all other ones
    are filtered out.
  """
  # Every requested feature is parsed as a flat, variable-length sequence of
  # scalars of its declared dtype (the first element of the metadata tuple).
  # The real shape is restored afterwards by parse_reshape_logic.
  feature_map = {
      k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
      for k, v in features.items()
  }
  parsed_features = tf.io.parse_single_example(raw_data, feature_map)
  return parse_reshape_logic(parsed_features, features, key=key)


def _first(tensor: tf.Tensor) -> tf.Tensor:
  """Flattens the input (tensor or scalar) and returns its first element."""
  flattened = tf.reshape(tensor, shape=(-1,))
  return flattened[0]


def parse_reshape_logic(
    parsed_features: TensorDict,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Reshapes every parsed feature to the shape given by protein_features.

  The sizes used to resolve the target shapes are read from the features
  themselves: "seq_length" gives the number of residues, "num_alignments"
  (when present) gives the MSA length, and the leading dimension of
  "template_domain_names" (when present) gives the number of templates.

  Args:
    parsed_features: A dictionary mapping feature names to tensors. It must
      contain "seq_length". The dictionary is updated in place.
    features: Feature metadata forwarded to protein_features.shape to resolve
      the target shape of each feature.
    key: Optional key. It is stored under "key" only when it is not None and
      "key" is present in features.

  Returns:
    The same dictionary, with every value replaced by its reshaped tensor.
  """
  # Sizes needed to resolve the target shapes; absent counts default to 0.
  num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)

  num_msa = 0
  if "num_alignments" in parsed_features:
    num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)

  num_templates = 0
  if "template_domain_names" in parsed_features:
    num_templates = tf.cast(
        tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)

  if key is not None and "key" in features:
    parsed_features["key"] = [key]  # Expand dims from () to (1,).

  # Only existing keys are reassigned below, so iterating over the dictionary
  # while updating it is safe.
  for name, flat_value in parsed_features.items():
    target_shape = protein_features.shape(
        feature_name=name,
        num_residues=num_residues,
        msa_length=num_msa,
        num_templates=num_templates,
        features=features)

    # Number of elements the target shape holds.
    expected_size = tf.constant(1, dtype=tf.int32)
    for dim in target_shape:
      expected_size = expected_size * tf.cast(dim, tf.int32)

    # The reshape is only valid if the element counts agree.
    size_check = tf.assert_equal(
        tf.size(flat_value), expected_size,
        name="assert_%s_shape_correct" % name,
        message="The size of feature %s (%s) could not be reshaped "
        "into %s" % (name, tf.size(flat_value), target_shape))

    if "template" in name:
      # Features with "template" in their name are allowed to be empty.
      dependencies = [size_check]
    else:
      # Make sure the feature we are reshaping is not empty.
      non_empty_check = tf.assert_greater(
          tf.size(flat_value), 0, name="assert_%s_non_empty" % name,
          message="The feature %s is not set in the tf.Example. Either do not "
          "request the feature or use a tf.Example that has the "
          "feature set." % name)
      dependencies = [non_empty_check, size_check]

    with tf.control_dependencies(dependencies):
      parsed_features[name] = tf.reshape(
          flat_value, target_shape, name="reshape_%s" % name)

  return parsed_features


def _make_features_metadata(
    feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
  """Makes a feature name to type and shape mapping from a list of names.

  Args:
    feature_names: Names of the requested features. Duplicates are ignored.

  Returns:
    A dictionary mapping each requested name, plus the always-read names
    "aatype", "sequence" and "seq_length", to its entry in
    protein_features.FEATURES. Keys follow the order of feature_names, with
    any missing always-read names appended at the end.

  Raises:
    KeyError: If a name is not present in protein_features.FEATURES.
  """
  # Make sure these features are always read.
  required_features = ["aatype", "sequence", "seq_length"]

  # dict.fromkeys removes duplicates while keeping first-seen order. A set
  # union would give an order that depends on string hash randomisation, so
  # the resulting metadata (and anything built from it) would differ between
  # runs.
  ordered_names = dict.fromkeys([*feature_names, *required_features])

  return {name: protein_features.FEATURES[name] for name in ordered_names}


def create_tensor_dict(
    raw_data: bytes,
    features: Sequence[str],
    key: Optional[str] = None,
    ) -> TensorDict:
  """Builds a dictionary of tensor features from a serialized tf.Example.

  Args:
    raw_data: A serialized tf.Example proto.
    features: A list of strings of feature names to be returned in the dataset.
    key: Optional string with the SSTable key of that tf.Example. It is added
      to the result as 'key', but only if 'key' is requested in features.

  Returns:
    A dictionary mapping feature names to features. Only the given features
    are returned; all other ones are filtered out.
  """
  # Expand the names into metadata, then parse and reshape in a single step.
  return parse_tfexample(raw_data, _make_features_metadata(features), key)


def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
    ) -> TensorDict:
  """Converts a dict of NumPy arrays into a dict of reshaped tensors.

  Args:
    np_example: A dict of NumPy feature arrays.
    features: A list of strings of feature names to be returned in the dataset.

  Returns:
    A dictionary mapping feature names to tensors. Only entries of np_example
    whose name appears in the metadata built from features are kept; all other
    ones are filtered out.
  """
  metadata = _make_features_metadata(features)

  # Keep only the entries named in the metadata, as constant tensors.
  constants = {}
  for name, array in np_example.items():
    if name not in metadata:
      continue
    constants[name] = tf.constant(array)

  # Ensures shapes are as expected. Needed for setting size of empty features
  # e.g. when no template hits were found.
  return parse_reshape_logic(constants, metadata)