jgpeters committed on
Commit
bfdf7ef
·
verified ·
1 Parent(s): 6faf039

Delete structure.py

Browse files
Files changed (1) hide show
  1. structure.py +0 -184
structure.py DELETED
@@ -1,184 +0,0 @@
1
- import os
2
- from os.path import isfile
3
- from enum import Enum, auto
4
-
5
- import numpy as np
6
- from scipy.spatial.distance import cdist
7
- import networkx as nx
8
- from biopandas.pdb import PandasPdb
9
-
10
-
11
class GraphType(Enum):
    """Identifiers for the kinds of structure graph that gen_graph can build."""

    # values are spelled out; they match what auto() assigns (1, 2, 3, ...)
    LINEAR = 1
    COMPLETE = 2
    DISCONNECTED = 3
    DIST_THRESH = 4
    DIST_THRESH_SHUFFLED = 5
17
-
18
-
19
def save_graph(g, fn):
    """Write the graph ``g`` to the file ``fn`` in GEXF format."""
    nx.write_gexf(g, fn)
22
-
23
-
24
def load_graph(fn):
    """Read a GEXF graph from the file ``fn``, converting node labels to ``int``."""
    return nx.read_gexf(fn, node_type=int)
28
-
29
-
30
def shuffle_nodes(g, seed=7):
    """Return a relabeled copy of ``g`` in which the node labels are permuted.

    The edge structure is unchanged; only the labels attached to the nodes are
    shuffled among themselves.

    Args:
        g: the graph to shuffle. It is not modified.
        seed: seed for the pseudo-random permutation, so that the same seed
            always produces the same shuffled graph.

    Returns:
        A new graph with permuted node labels.
    """
    # materialise the labels once so both sides of the mapping share one ordering
    # (and so this works whether nodes() hands back a list or a view object)
    nodes = list(g.nodes())

    # A private RandomState seeded with `seed` produces the same permutation as
    # seeding the global generator would, but does not reset the process-wide
    # numpy random state as a side effect of calling this function.
    rng = np.random.RandomState(seed)
    nodes_shuffled = rng.permutation(nodes)

    # old label -> new label
    mapping = dict(zip(nodes, nodes_shuffled))

    return nx.relabel_nodes(g, mapping, copy=True)
45
-
46
-
47
def linear_graph(num_residues):
    """Create a chain graph linking each node to its sequence neighbor.

    Args:
        num_residues: number of nodes; they are labeled 0 .. num_residues - 1.

    Returns:
        An undirected graph with an edge (i, i + 1) for every consecutive pair.
        With fewer than two residues the graph has no edges.
    """
    g = nx.Graph()
    # label nodes with plain Python ints rather than numpy scalars, matching
    # the int labels that load_graph produces when a graph is read back
    g.add_nodes_from(range(num_residues))
    g.add_edges_from((i, i + 1) for i in range(num_residues - 1))
    return g
54
-
55
-
56
def complete_graph(num_residues):
    """Build the graph on ``num_residues`` nodes in which every pair of nodes is connected."""
    return nx.complete_graph(num_residues)
60
-
61
-
62
def disconnected_graph(num_residues):
    """Create a graph with one node per residue and no edges.

    Args:
        num_residues: number of nodes; they are labeled 0 .. num_residues - 1.

    Returns:
        An undirected graph containing only isolated nodes.
    """
    g = nx.Graph()
    # plain Python int labels (not numpy scalars), matching the int labels
    # that load_graph produces when a graph is read back
    g.add_nodes_from(range(num_residues))
    return g
66
-
67
-
68
def dist_thresh_graph(dist_mtx, threshold):
    """Create an undirected graph connecting residues closer than a threshold.

    Args:
        dist_mtx: numpy array of pairwise distances, indexed [residue, residue]
            (expected to be square 2-D -- confirm against callers). One node is
            created per row.
        threshold: two different residues are joined by an edge when their
            distance is strictly less than this value.

    Returns:
        An undirected graph with nodes 0 .. dist_mtx.shape[0] - 1 and no self edges.
    """
    g = nx.Graph()
    # plain Python int labels (not numpy scalars), matching the int labels
    # that load_graph produces when a graph is read back
    g.add_nodes_from(range(dist_mtx.shape[0]))

    # all (row, column) index pairs below the threshold, in row-major order;
    # tolist() turns the indices into plain Python ints
    close_pairs = np.argwhere(dist_mtx < threshold).tolist()

    # the diagonal (distance of a residue to itself) would give self edges, so skip it
    g.add_edges_from((rn1, rn2) for rn1, rn2 in close_pairs if rn1 != rn2)
    return g
84
-
85
-
86
def ordered_adjacency_matrix(g):
    """Return the adjacency matrix of ``g`` as a float32 numpy array.

    Rows and columns are ordered by node label in increasing order.
    """
    node_order = sorted(g.nodes())
    # to_numpy_array returns a plain ndarray; the older to_numpy_matrix helper
    # is no longer available in NetworkX 3.x
    adj_mtx = nx.to_numpy_array(g, nodelist=node_order)
    return adj_mtx.astype(np.float32)
91
-
92
-
93
def cbeta_distance_matrix(pdb_fn, start=0, end=None):
    """Compute pairwise Euclidean distances between the residues of a PDB file.

    Each residue is represented by the coordinates of its first "CB" atom
    record, or of its first "CA" atom record when it has no "CB".

    Args:
        pdb_fn: path of the PDB file to read.
        start: position of the first residue to include. Positions count the
            residues in increasing residue-number order starting at 0; they
            are NOT residue numbers.
        end: position one past the last residue to include. None means
            "through the last residue".

    Returns:
        A square numpy array whose [i, j] entry is the distance between the
        i-th and j-th selected residues.

    Raises:
        ValueError: if a selected residue has neither a CB nor a CA atom, or
            if the start/end range selects no residues at all.
    """
    # read the pdb file into a biopandas object
    ppdb = PandasPdb().read_pdb(pdb_fn)

    # Group the ATOM records by residue number. sort=True (the pandas default)
    # is spelled out because the loop below relies on the groups arriving in
    # increasing residue-number order.
    # NOTE(review): grouping uses residue_number only, so records that share a
    # residue number (e.g. across chains) land in one group -- confirm inputs.
    grouped = ppdb.df["ATOM"].groupby("residue_number", sort=True)

    # one xyz coordinate triple (cbeta or calpha) per selected residue
    coords = []

    for i, (residue_number, residue_group) in enumerate(grouped):
        # skip residues before the requested range
        if i < start:
            continue
        # nothing after `end` can be selected, so stop iterating
        if end is not None and i >= end:
            break

        atom_names = residue_group["atom_name"].values
        if "CB" in atom_names:
            atom_name = "CB"
        elif "CA" in atom_names:
            atom_name = "CA"
        else:
            raise ValueError("Couldn't find CB or CA for residue {}".format(residue_number))

        # coordinates of the first matching cbeta (or calpha) record
        matching = residue_group[residue_group["atom_name"] == atom_name]
        coords.append(matching[["x_coord", "y_coord", "z_coord"]].values[0])

    if not coords:
        # np.stack would fail on an empty list with a much less helpful message
        raise ValueError("No residues selected from {} with start={} and end={}".format(pdb_fn, start, end))

    # stack into an array where each row holds the x,y,z coords of one residue
    coords = np.stack(coords)

    # pairwise euclidean distance between all selected residues
    dist_mtx = cdist(coords, coords, metric="euclidean")

    return dist_mtx
140
-
141
-
142
def get_neighbors(g, nodes):
    """Return the distinct neighbors of the given nodes as a sorted list.

    Args:
        g: graph object providing ``neighbors(node)``.
        nodes: iterable of nodes whose neighbors are collected.

    Returns:
        A sorted list containing every node adjacent to at least one of
        ``nodes``, each listed once. A node from ``nodes`` appears in the
        result only if it is itself a neighbor of another given node.
    """
    neighbors = set()
    for n in nodes:
        neighbors.update(g.neighbors(n))
    # sorted() accepts any iterable, so no intermediate list is needed
    return sorted(neighbors)
148
-
149
-
150
def gen_graph(graph_type, res_dist_mtx, dist_thresh=7, shuffle_seed=7, graph_save_dir=None, save=False):
    """Generate the requested structure graph from a residue distance matrix.

    Args:
        graph_type: a GraphType member selecting which graph to build.
        res_dist_mtx: residue distance matrix. Its length gives the number of
            nodes; its values are only used by the distance-threshold types.
        dist_thresh: distance threshold for the DIST_THRESH graph types.
        shuffle_seed: random seed for the node shuffle in DIST_THRESH_SHUFFLED.
        graph_save_dir: directory to save the graph in. Required when ``save``
            is True; created if it does not exist.
        save: whether to write the graph to a file in ``graph_save_dir``. An
            existing file is never overwritten; a message is printed instead.

    Returns:
        The generated graph (also when saving was skipped).

    Raises:
        ValueError: if ``graph_type`` is not a supported GraphType, or if
            ``save`` is True but no ``graph_save_dir`` was given.
    """
    # fail early with a clear message instead of a TypeError from os.path.join
    if save and graph_save_dir is None:
        raise ValueError("graph_save_dir must be specified when save=True")

    # build the graph and pick the file name it would be saved under
    if graph_type is GraphType.LINEAR:
        g = linear_graph(len(res_dist_mtx))
        base_fn = "linear.graph"

    elif graph_type is GraphType.COMPLETE:
        g = complete_graph(len(res_dist_mtx))
        base_fn = "complete.graph"

    elif graph_type is GraphType.DISCONNECTED:
        g = disconnected_graph(len(res_dist_mtx))
        base_fn = "disconnected.graph"

    elif graph_type is GraphType.DIST_THRESH:
        g = dist_thresh_graph(res_dist_mtx, dist_thresh)
        base_fn = "dist_thresh_{}.graph".format(dist_thresh)

    elif graph_type is GraphType.DIST_THRESH_SHUFFLED:
        g = dist_thresh_graph(res_dist_mtx, dist_thresh)
        g = shuffle_nodes(g, seed=shuffle_seed)
        base_fn = "dist_thresh_{}_shuffled_r{}.graph".format(dist_thresh, shuffle_seed)

    else:
        raise ValueError("Graph type {} is not implemented".format(graph_type))

    if save:
        save_fn = os.path.join(graph_save_dir, base_fn)
        if isfile(save_fn):
            # never overwrite an existing graph file; report and carry on
            print("err: graph already exists: {}. to overwrite, delete the existing file first".format(save_fn))
        else:
            os.makedirs(graph_save_dir, exist_ok=True)
            save_graph(g, save_fn)

    return g