Upload 7 files

- cas12.py +220 -0
- cas12lstm.py +258 -0
- cas12lstmvcf.py +351 -0
- cas9att.py +296 -0
- cas9attvcf.py +393 -0
- cas9off.py +134 -0
- tiger.py +417 -0
cas12.py
ADDED
from keras import Model
from keras.layers import Input
from keras.layers import Multiply
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution1D, AveragePooling1D
import pandas as pd
import numpy as np
import keras
import requests
from functools import reduce
from operator import add
from Bio.SeqRecord import SeqRecord
from Bio.SeqFeature import SeqFeature, FeatureLocation
from Bio.Seq import Seq
from Bio import SeqIO

ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }

def get_seqcode(seq):
    return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))

def Seq_DeepCpf1_model(input_shape):
    Seq_deepCpf1_Input_SEQ = Input(shape=input_shape)
    Seq_deepCpf1_C1 = Convolution1D(80, 5, activation='relu')(Seq_deepCpf1_Input_SEQ)
    Seq_deepCpf1_P1 = AveragePooling1D(2)(Seq_deepCpf1_C1)
    Seq_deepCpf1_F = Flatten()(Seq_deepCpf1_P1)
    Seq_deepCpf1_DO1 = Dropout(0.3)(Seq_deepCpf1_F)
    Seq_deepCpf1_D1 = Dense(80, activation='relu')(Seq_deepCpf1_DO1)
    Seq_deepCpf1_DO2 = Dropout(0.3)(Seq_deepCpf1_D1)
    Seq_deepCpf1_D2 = Dense(40, activation='relu')(Seq_deepCpf1_DO2)
    Seq_deepCpf1_DO3 = Dropout(0.3)(Seq_deepCpf1_D2)
    Seq_deepCpf1_D3 = Dense(40, activation='relu')(Seq_deepCpf1_DO3)
    Seq_deepCpf1_DO4 = Dropout(0.3)(Seq_deepCpf1_D3)
    Seq_deepCpf1_Output = Dense(1, activation='linear')(Seq_deepCpf1_DO4)
    Seq_deepCpf1 = Model(inputs=[Seq_deepCpf1_Input_SEQ], outputs=[Seq_deepCpf1_Output])
    return Seq_deepCpf1

# seq-ca model (DeepCpf1)
def DeepCpf1_model(input_shape):
    DeepCpf1_Input_SEQ = Input(shape=input_shape)
    DeepCpf1_C1 = Convolution1D(80, 5, activation='relu')(DeepCpf1_Input_SEQ)
    DeepCpf1_P1 = AveragePooling1D(2)(DeepCpf1_C1)
    DeepCpf1_F = Flatten()(DeepCpf1_P1)
    DeepCpf1_DO1 = Dropout(0.3)(DeepCpf1_F)
    DeepCpf1_D1 = Dense(80, activation='relu')(DeepCpf1_DO1)
    DeepCpf1_DO2 = Dropout(0.3)(DeepCpf1_D1)
    DeepCpf1_D2 = Dense(40, activation='relu')(DeepCpf1_DO2)
    DeepCpf1_DO3 = Dropout(0.3)(DeepCpf1_D2)
    DeepCpf1_D3_SEQ = Dense(40, activation='relu')(DeepCpf1_DO3)
    DeepCpf1_Input_CA = Input(shape=(1,))
    DeepCpf1_D3_CA = Dense(40, activation='relu')(DeepCpf1_Input_CA)
    DeepCpf1_M = Multiply()([DeepCpf1_D3_SEQ, DeepCpf1_D3_CA])
    DeepCpf1_DO4 = Dropout(0.3)(DeepCpf1_M)
    DeepCpf1_Output = Dense(1, activation='linear')(DeepCpf1_DO4)
    DeepCpf1 = Model(inputs=[DeepCpf1_Input_SEQ, DeepCpf1_Input_CA], outputs=[DeepCpf1_Output])
    return DeepCpf1

def fetch_ensembl_transcripts(gene_symbol):
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        gene_data = response.json()
        if 'Transcript' in gene_data:
            return gene_data['Transcript']
        else:
            print("No transcripts found for gene:", gene_symbol)
            return None
    else:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

def fetch_ensembl_sequence(transcript_id):
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        sequence_data = response.json()
        if 'seq' in sequence_data:
            return sequence_data['seq']
        else:
            print("No sequence found for transcript:", transcript_id)
            return None
    else:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
    targets = []
    len_sequence = len(sequence)
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}

    for i in range(len_sequence - target_length + 1):
        target_seq = sequence[i:i + target_length]
        if target_seq[4:7] == 'TTT':
            if strand == -1:
                tar_start = end - i - target_length + 1
                tar_end = end - i
                #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
            else:
                tar_start = start + i
                tar_end = start + i + target_length - 1
                #seq_in_ref = target_seq
            gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
            targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
    return targets

def format_prediction_output(targets, model_path):
    # Loading weights for the model
    Seq_deepCpf1 = Seq_DeepCpf1_model(input_shape=(34, 4))
    Seq_deepCpf1.load_weights(model_path)

    formatted_data = []
    for target in targets:
        # Predict
        encoded_seq = get_seqcode(target[0])
        prediction = float(list(Seq_deepCpf1.predict(encoded_seq)[0])[0])
        if prediction > 100:
            prediction = 100

        # Format output
        gRNA = target[1]
        chr = target[2]
        start = target[3]
        end = target[4]
        strand = target[5]
        transcript_id = target[6]
        exon_id = target[7]
        formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])

    return formatted_data

def process_gene(gene_symbol, model_path):
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    all_exons = []
    all_gene_sequences = []

    if transcripts:
        for i in range(len(transcripts)):
            Exons = transcripts[i]['Exon']
            all_exons.append(Exons)
            transcript_id = transcripts[i]['id']
            for j in range(len(Exons)):
                exon_id = Exons[j]['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)
                if gene_sequence:
                    all_gene_sequences.append(gene_sequence)
                    start = Exons[j]['start']
                    end = Exons[j]['end']
                    strand = Exons[j]['strand']
                    chr = Exons[j]['seq_region_name']
                    targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
                    if targets:
                        formatted_data = format_prediction_output(targets, model_path)
                        results.append(formatted_data)
                        # for data in formatted_data:
                        #     print(f"Chr: {data[0]}, Start: {data[1]}, End: {data[2]}, Strand: {data[3]}, target: {data[4]}, gRNA: {data[5]}, pred_Score: {data[6]}")
                else:
                    print("Failed to retrieve gene sequence.")
    else:
        print("Failed to retrieve transcripts.")

    return results, all_gene_sequences, all_exons


# def create_genbank_features(formatted_data):
#     features = []
#     for data in formatted_data:
#         try:
#             # Attempt to convert start and end positions to integers
#             start = int(data[1])
#             end = int(data[2])
#         except ValueError as e:
#             # Log the error and skip this iteration if conversion fails
#             print(f"Error converting start/end to int: {data[1]}, {data[2]} - {e}")
#             continue  # Skip this iteration
#
#         # Proceed as normal if conversion is successful
#         strand = 1 if data[3] == '+' else -1
#         location = FeatureLocation(start=start, end=end, strand=strand)
#         feature = SeqFeature(location=location, type="misc_feature", qualifiers={
#             'label': data[5],  # gRNA as label
#             'note': f"Prediction: {data[6]}"  # Prediction score in note
#         })
#         features.append(feature)
#     return features
#
# def generate_genbank_file_from_data(formatted_data, gene_sequence, gene_symbol, output_path):
#     features = create_genbank_features(formatted_data)
#     record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
#                        description='CRISPR Cas12 predicted targets', features=features)
#     record.annotations["molecule_type"] = "DNA"
#     SeqIO.write(record, output_path, "genbank")
#
# def create_csv_from_df(df, output_path):
#     df.to_csv(output_path, index=False)
#
# def generate_bed_file_from_data(formatted_data, output_path):
#     with open(output_path, 'w') as bed_file:
#         for data in formatted_data:
#             try:
#                 # Ensure data has the expected number of elements
#                 if len(data) < 7:
#                     raise ValueError("Incomplete data item")
#
#                 chrom = data[0]
#                 start = data[1]
#                 end = data[2]
#                 strand = '+' if data[3] == '+' else '-'
#                 gRNA = data[5]
#                 score = data[6]  # Ensure this index exists
#
#                 bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
#             except ValueError as e:
#                 print(f"Skipping an item due to error: {e}")
#                 continue
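
A minimal usage sketch may help tie cas12.py together. It is not part of the upload: the weights path and gene symbol are placeholders, and the column labels are one reasonable naming for the row layout that format_prediction_output produces.

# Minimal usage sketch (not part of the upload; paths and gene symbol are placeholders).
import pandas as pd

results, sequences, exons = process_gene("FANCF", "models/Seq_deepCpf1_weights.h5")

# process_gene appends one list of rows per exon, so flatten before tabulating
rows = [row for exon_rows in results for row in exon_rows]
df = pd.DataFrame(rows, columns=["Chr", "Start Pos", "End Pos", "Strand",
                                 "Transcript", "Exon", "Target", "gRNA", "Prediction"])
df.to_csv("cas12_predictions.csv", index=False)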
cas12lstm.py
ADDED
import tensorflow as tf
from keras import regularizers
from keras.layers import Input, Dense, Dropout, Activation, Conv1D
from keras.layers import GlobalAveragePooling1D, AveragePooling1D
from keras.layers import Bidirectional, LSTM
from keras import Model
from keras.metrics import MeanSquaredError

import pandas as pd
import numpy as np

import requests
from functools import reduce
from operator import add
import tabulate
from difflib import SequenceMatcher
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.SeqFeature import SeqFeature, FeatureLocation
from Bio.Seq import Seq

import cyvcf2
import parasail

import re

ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }

def get_seqcode(seq):
    return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))

def BiLSTM_model(input_shape):
    input = Input(shape=input_shape)

    conv1 = Conv1D(128, 5, activation="relu")(input)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.1)(pool1)

    conv2 = Conv1D(128, 5, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.1)(pool2)

    lstm1 = Bidirectional(LSTM(128,
                               dropout=0.1,
                               activation='tanh',
                               return_sequences=True,
                               kernel_regularizer=regularizers.l2(1e-4)))(drop2)
    avgpool = GlobalAveragePooling1D()(lstm1)

    dense1 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(avgpool)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(32,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(32,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[input], outputs=[output])
    return model

def fetch_ensembl_transcripts(gene_symbol):
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        gene_data = response.json()
        if 'Transcript' in gene_data:
            return gene_data['Transcript']
        else:
            print("No transcripts found for gene:", gene_symbol)
            return None
    else:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

def fetch_ensembl_sequence(transcript_id):
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        sequence_data = response.json()
        if 'seq' in sequence_data:
            return sequence_data['seq']
        else:
            print("No sequence found for transcript:", transcript_id)
            return None
    else:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
    targets = []
    len_sequence = len(sequence)
    #complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}

    for i in range(len_sequence - target_length + 1):
        target_seq = sequence[i:i + target_length]
        if target_seq[4:7] == 'TTT':
            if strand == -1:
                tar_start = end - i - target_length + 1
                tar_end = end - i
                #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
            else:
                tar_start = start + i
                tar_end = start + i + target_length - 1
                #seq_in_ref = target_seq
            gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
            targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
            #targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
    return targets

def format_prediction_output(targets, model_path):
    # Loading weights for the model
    Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
    Crispr_BiLSTM.load_weights(model_path)

    formatted_data = []
    for target in targets:
        # Predict
        encoded_seq = get_seqcode(target[0])
        prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
        if prediction > 100:
            prediction = 100

        # Format output
        gRNA = target[1]
        chr = target[2]
        start = target[3]
        end = target[4]
        strand = target[5]
        transcript_id = target[6]
        exon_id = target[7]
        #seq_in_ref = target[8]
        #formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction])
        formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])

    return formatted_data


def process_gene(gene_symbol, model_path):
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    all_exons = []  # To accumulate all exons
    all_gene_sequences = []  # To accumulate all gene sequences

    if transcripts:
        for transcript in transcripts:
            Exons = transcript['Exon']
            all_exons.extend(Exons)  # Add all exons from this transcript to the list
            transcript_id = transcript['id']

            for Exon in Exons:
                exon_id = Exon['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)
                if gene_sequence:
                    all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
                    chr = Exon['seq_region_name']
                    start = Exon['start']
                    end = Exon['end']
                    strand = Exon['strand']

                    targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
                    if targets:
                        # Predict on-target efficiency for each gRNA site
                        formatted_data = format_prediction_output(targets, model_path)
                        results.extend(formatted_data)  # Flatten the results
                else:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
    else:
        print("Failed to retrieve transcripts.")

    # Return the flattened results, combined gene sequences, and all exons
    return results, all_gene_sequences, all_exons

def create_genbank_features(data):
    features = []

    # If the input data is a DataFrame, convert it to a list of lists
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    for row in formatted_data:
        try:
            start = int(row[1])
            end = int(row[2])
        except ValueError as e:
            print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
            continue

        strand = 1 if row[3] == '+' else -1
        location = FeatureLocation(start=start, end=end, strand=strand)
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use gRNA as the label
            'note': f"Prediction: {row[8]}"  # Include the prediction score
        })
        features.append(feature)

    return features


def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    # Ensure gene_sequence is a string before creating Seq object
    if not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)

    # Now gene_sequence is guaranteed to be a string, suitable for Seq
    seq_obj = Seq(gene_sequence)
    record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas12 predicted targets for {gene_symbol}', features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")


def create_bed_file_from_df(df, output_path):
    with open(output_path, 'w') as bed_file:
        for index, row in df.iterrows():
            chrom = row["Chr"]
            start = int(row["Start Pos"])
            end = int(row["End Pos"])
            strand = '+' if row["Strand"] == '1' else '-'
            gRNA = row["gRNA"]
            score = str(row["Prediction"])
            # transcript_id is not typically part of the standard BED columns but added here for completeness
            transcript_id = row["Transcript"]

            # Writing only standard BED columns; additional columns can be appended as needed
            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")


def create_csv_from_df(df, output_path):
    df.to_csv(output_path, index=False)
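
A minimal usage sketch, not part of the upload: the output writers above index the DataFrame by column name, so the flat rows returned by process_gene must be labeled as below. The weights path and gene symbol are placeholders.

# Minimal usage sketch (paths and gene symbol are placeholders).
import pandas as pd

results, sequences, exons = process_gene("FANCF", "models/Cas12_BiLSTM_weights.h5")

df = pd.DataFrame(results, columns=["Chr", "Start Pos", "End Pos", "Strand",
                                    "Transcript", "Exon", "Target", "gRNA", "Prediction"])
create_csv_from_df(df, "cas12lstm_predictions.csv")
create_bed_file_from_df(df, "cas12lstm_predictions.bed")  # expects Strand stored as '1' or '-1'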
cas12lstmvcf.py
ADDED
import tensorflow as tf
from keras import regularizers
from keras.layers import Input, Dense, Dropout, Activation, Conv1D
from keras.layers import GlobalAveragePooling1D, AveragePooling1D
from keras.layers import Bidirectional, LSTM
from keras import Model
from keras.metrics import MeanSquaredError

import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.SeqFeature import SeqFeature, FeatureLocation
from Bio.Seq import Seq

import requests
from functools import reduce
from operator import add
import tabulate
from difflib import SequenceMatcher

import cyvcf2
import parasail

import re

ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }

def get_seqcode(seq):
    return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))

def BiLSTM_model(input_shape):
    input = Input(shape=input_shape)

    conv1 = Conv1D(128, 5, activation="relu")(input)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.1)(pool1)

    conv2 = Conv1D(128, 5, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.1)(pool2)

    lstm1 = Bidirectional(LSTM(128,
                               dropout=0.1,
                               activation='tanh',
                               return_sequences=True,
                               kernel_regularizer=regularizers.l2(1e-4)))(drop2)
    avgpool = GlobalAveragePooling1D()(lstm1)

    dense1 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(avgpool)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(32,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(32,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[input], outputs=[output])
    return model

def fetch_ensembl_transcripts(gene_symbol):
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        gene_data = response.json()
        if 'Transcript' in gene_data:
            return gene_data['Transcript']
        else:
            print("No transcripts found for gene:", gene_symbol)
            return None
    else:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

def fetch_ensembl_sequence(transcript_id):
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        sequence_data = response.json()
        if 'seq' in sequence_data:
            return sequence_data['seq']
        else:
            print("No sequence found for transcript:", transcript_id)
            return None
    else:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

def apply_mutation(ref_sequence, offset, ref, alt):
    """
    Apply a single mutation to the sequence.
    """
    if len(ref) == len(alt) and alt != "*":  # SNP
        mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(alt):]

    elif len(ref) < len(alt):  # Insertion
        mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+1:]

    elif len(ref) == len(alt) and alt == "*":  # Deletion
        mutated_seq = ref_sequence[:offset] + ref_sequence[offset+1:]

    elif len(ref) > len(alt) and alt != "*":  # Deletion
        mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(ref):]

    elif len(ref) > len(alt) and alt == "*":  # Deletion
        mutated_seq = ref_sequence[:offset] + ref_sequence[offset+len(ref):]

    return mutated_seq


def construct_combinations(sequence, mutations):
    """
    Construct all combinations of mutations.
    mutations is a list of tuples (position, ref, [alts])
    """
    if not mutations:
        return [sequence]

    # Take the first mutation and recursively construct combinations for the rest
    first_mutation = mutations[0]
    rest_mutations = mutations[1:]
    offset, ref, alts = first_mutation

    sequences = []
    for alt in alts:
        mutated_sequence = apply_mutation(sequence, offset, ref, alt)
        sequences.extend(construct_combinations(mutated_sequence, rest_mutations))

    return sequences

def needleman_wunsch_alignment(query_seq, ref_seq):
    """
    Use Needleman-Wunsch alignment to find the maximum alignment position in ref_seq
    Use this position to represent the position of target sequence with mutations
    """
    # Needleman-Wunsch alignment
    alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)

    # extract CIGAR object
    cigar = alignment.cigar
    cigar_string = cigar.decode.decode("utf-8")

    # record ref_pos
    ref_pos = 0

    matches = re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)
    max_num_before_equal = 0
    max_equal_index = -1
    total_before_max_equal = 0

    for i, (num_str, op) in enumerate(matches):
        num = int(num_str)
        if op == '=':
            if num > max_num_before_equal:
                max_num_before_equal = num
                max_equal_index = i
    total_before_max_equal = sum(int(matches[j][0]) for j in range(max_equal_index))

    ref_pos = total_before_max_equal

    return ref_pos

def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="TTTN", target_length=34):
    # initialization
    mutated_sequences = [ref_sequence]

    # find mutations within interested region
    mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
    if mutations:
        # find mutations
        mutation_list = []
        for mutation in mutations:
            offset = mutation.POS - start
            ref = mutation.REF
            alts = mutation.ALT[:-1]
            mutation_list.append((offset, ref, alts))

        # replace reference sequence of mutation
        mutated_sequences = construct_combinations(ref_sequence, mutation_list)

    # find gRNA in ref_sequence or all mutated_sequences
    targets = []
    for seq in mutated_sequences:
        len_sequence = len(seq)
        dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
        for i in range(len_sequence - target_length + 1):
            target_seq = seq[i:i + target_length]
            if target_seq[4:7] == 'TTT':
                pos = ref_sequence.find(target_seq)
                if pos != -1:
                    is_mut = False
                    if strand == -1:
                        tar_start = end - pos - target_length + 1
                    else:
                        tar_start = start + pos
                else:
                    is_mut = True
                    nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
                    if strand == -1:
                        tar_start = str(end - nw_pos - target_length + 1) + '*'
                    else:
                        tar_start = str(start + nw_pos) + '*'
                gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
                targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start), transcript_id, exon_id, gene_symbol, is_mut])

    # filter duplicated targets
    unique_targets_set = set(tuple(element) for element in targets)
    unique_targets = [list(element) for element in unique_targets_set]

    return unique_targets

def format_prediction_output_with_mutation(targets, model_path):
    Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
    Crispr_BiLSTM.load_weights(model_path)

    formatted_data = []
    for target in targets:
        # Predict
        encoded_seq = get_seqcode(target[0])
        prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
        if prediction > 100:
            prediction = 100

        # Format output
        gRNA = target[1]
        exon_chr = target[2]
        strand = target[3]
        tar_start = target[4]
        transcript_id = target[5]
        exon_id = target[6]
        gene_symbol = target[7]
        is_mut = target[8]
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id, exon_id, target[0], gRNA, prediction, is_mut])

    return formatted_data

def process_gene(gene_symbol, vcf_reader, model_path):
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    all_exons = []  # To accumulate all exons
    all_gene_sequences = []  # To accumulate all gene sequences

    if transcripts:
        for transcript in transcripts:
            Exons = transcript['Exon']
            all_exons.extend(Exons)  # Add all exons from this transcript to the list
            transcript_id = transcript['id']

            for Exon in Exons:
                exon_id = Exon['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)  # Reference exon sequence
                if gene_sequence:
                    all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
                    exon_chr = Exon['seq_region_name']
                    start = Exon['start']
                    end = Exon['end']
                    strand = Exon['strand']

                    targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand, transcript_id, exon_id, gene_symbol, vcf_reader)
                    if targets:
                        # Predict on-target efficiency for each gRNA site
                        formatted_data = format_prediction_output_with_mutation(targets, model_path)
                        results.extend(formatted_data)  # Flatten the results
                else:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
    else:
        print("Failed to retrieve transcripts.")

    # Return the collected results, combined gene sequences, and all exons
    return results, all_gene_sequences, all_exons

def create_genbank_features(data):
    features = []

    # If the input data is a DataFrame, convert it to a list of lists
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    for row in formatted_data:
        try:
            start = int(row[1])
            end = start + len(row[6])  # Calculate the end position based on the target sequence length
        except ValueError as e:
            print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
            continue

        strand = 1 if row[3] == '1' else -1
        location = FeatureLocation(start=start, end=end, strand=strand)
        is_mutation = 'Yes' if row[9] else 'No'
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use gRNA as the label
            'note': f"Prediction: {row[8]}, Mutation: {is_mutation}"  # Include the prediction score and mutation status
        })
        features.append(feature)

    return features

def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    # Ensure gene_sequence is a string before creating Seq object
    if not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)

    # Now gene_sequence is guaranteed to be a string, suitable for Seq
    seq_obj = Seq(gene_sequence)
    record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas12 predicted targets for {gene_symbol}', features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")

def create_bed_file_from_df(df, output_path):
    with open(output_path, 'w') as bed_file:
        for index, row in df.iterrows():
            chrom = row["Chr"]
            start = int(row["Target Start"])
            end = start + len(row["Target"])  # Calculate the end position based on the target sequence length
            strand = '+' if row["Strand"] == '1' else '-'
            gRNA = row["gRNA"]
            score = str(row["Prediction"])
            is_mutation = 'Yes' if row["Is Mutation"] else 'No'
            # transcript_id is not typically part of the standard BED columns but added here for completeness
            transcript_id = row["Transcript"]

            # Writing only standard BED columns; additional columns can be appended as needed
            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\t{is_mutation}\n")

def create_csv_from_df(df, output_path):
    df.to_csv(output_path, index=False)
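
Two short sketches may make the mutation handling above concrete. Both are illustrations, not part of the upload: the toy sequence and variant tuples are invented, the VCF path is a placeholder, and the '*' allele is read as a one-base deletion per the apply_mutation branches above.

# Worked example (illustration only, toy data): construct_combinations enumerates the
# cartesian product of alternate alleles, so one variant with a single ALT and one
# with two ALTs expand into 1 x 2 = 2 mutated sequences. Offsets are 0-based
# positions into the exon sequence (mutation.POS - start above).
seq = "AAAAATAAA"
muts = [(2, "A", ["G"]),        # SNP at offset 2: A -> G
        (5, "T", ["C", "*"])]   # offset 5: T -> C, or '*' for a 1-bp deletion
print(construct_combinations(seq, muts))
# -> ['AAGAACAAA', 'AAGAAAAA']

# Usage sketch: vcf_reader is the cyvcf2.VCF object itself; find_gRNA_with_mutation
# calls it with a "chr:start-end" region string, which requires a bgzipped,
# tabix-indexed file. Paths are placeholders.
vcf_reader = cyvcf2.VCF("variants.vcf.gz")
results, sequences, exons = process_gene("FANCF", vcf_reader, "models/Cas12_BiLSTM_weights.h5")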
cas9att.py
ADDED
import requests
import tensorflow as tf
import pandas as pd
import numpy as np
from operator import add
from functools import reduce
import random
import tabulate

from keras import Model
from keras import regularizers
from keras.optimizers import Adam
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
from keras.models import load_model
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.SeqFeature import SeqFeature, FeatureLocation
from Bio.Seq import Seq

import cyvcf2
import parasail

import re

ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }

def get_seqcode(seq):
    return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))

class PositionalEncoding(Layer):
    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        super(PositionalEncoding, self).__init__()
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        position_embedding = np.array([
            [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
            for pos in range(self.sequence_len)])

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # dim 2i
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # dim 2i+1
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)

        return position_embedding + x

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config

def MultiHeadAttention_model(input_shape):
    input = Input(shape=input_shape)

    conv1 = Conv1D(256, 3, activation="relu")(input)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)

    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)

    lstm = Bidirectional(LSTM(128,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)

    pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)

    flat = Flatten()(atten)

    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[input], outputs=[output])
    return model

def fetch_ensembl_transcripts(gene_symbol):
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        gene_data = response.json()
        if 'Transcript' in gene_data:
            return gene_data['Transcript']
        else:
            print("No transcripts found for gene:", gene_symbol)
            return None
    else:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

def fetch_ensembl_sequence(transcript_id):
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        sequence_data = response.json()
        if 'seq' in sequence_data:
            return sequence_data['seq']
        else:
            print("No sequence found for transcript:", transcript_id)
            return None
    else:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
    targets = []
    len_sequence = len(sequence)
    #complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}

    for i in range(len_sequence - len(pam) + 1):
        if sequence[i + 1:i + 3] == pam[1:]:
            if i >= target_length:
                target_seq = sequence[i - target_length:i + 3]
                if strand == -1:
                    tar_start = end - (i + 2)
                    tar_end = end - (i - target_length)
                    #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
                else:
                    tar_start = start + i - target_length
                    tar_end = start + i + 3 - 1
                    #seq_in_ref = target_seq
                gRNA = ''.join([dnatorna[base] for base in sequence[i - target_length:i]])
                #targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
                targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])

    return targets

# Function to predict on-target efficiency and format output
def format_prediction_output(targets, model_path):
    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    formatted_data = []

    for target in targets:
        # Encode the gRNA sequence
        encoded_seq = get_seqcode(target[0])

        # Predict on-target efficiency using the model
        prediction = float(list(model.predict(encoded_seq, verbose=0)[0])[0])
        if prediction > 100:
            prediction = 100

        # Format output
        gRNA = target[1]
        chr = target[2]
        start = target[3]
        end = target[4]
        strand = target[5]
        transcript_id = target[6]
        exon_id = target[7]
        #seq_in_ref = target[8]
        #formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction[0]])
        formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])

    return formatted_data

def process_gene(gene_symbol, model_path):
    # Fetch transcripts for the given gene symbol
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    all_exons = []  # To accumulate all exons
    all_gene_sequences = []  # To accumulate all gene sequences

    if transcripts:
        for transcript in transcripts:
            Exons = transcript['Exon']
            all_exons.extend(Exons)  # Add all exons from this transcript to the list
            transcript_id = transcript['id']

            for exon in Exons:
                exon_id = exon['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)
                if gene_sequence:
                    all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
                    start = exon['start']
                    end = exon['end']
                    strand = exon['strand']
                    chr = exon['seq_region_name']
                    # Find potential CRISPR targets within the exon
                    targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
                    if targets:
                        # Format the prediction output for the targets found
                        formatted_data = format_prediction_output(targets, model_path)
                        results.extend(formatted_data)  # Append results
                else:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
    else:
        print("Failed to retrieve transcripts.")

    # Return the collected results, combined gene sequences, and all exons
    return results, all_gene_sequences, all_exons


def create_genbank_features(data):
    features = []

    # If the input data is a DataFrame, convert it to a list of lists
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    for row in formatted_data:
        try:
            start = int(row[1])
            end = int(row[2])
        except ValueError as e:
            print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
            continue

        strand = 1 if row[3] == '+' else -1
        location = FeatureLocation(start=start, end=end, strand=strand)
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use gRNA as the label
            'note': f"Prediction: {row[8]}"  # Include the prediction score
        })
        features.append(feature)

    return features


def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    # Ensure gene_sequence is a string before creating Seq object
    if not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)

    # Now gene_sequence is guaranteed to be a string, suitable for Seq
    seq_obj = Seq(gene_sequence)
    record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")


def create_bed_file_from_df(df, output_path):
    with open(output_path, 'w') as bed_file:
        for index, row in df.iterrows():
            chrom = row["Chr"]
            start = int(row["Start Pos"])
            end = int(row["End Pos"])
            strand = '+' if row["Strand"] == '1' else '-'
            gRNA = row["gRNA"]
            score = str(row["Prediction"])
            # transcript_id is not typically part of the standard BED columns but added here for completeness
            transcript_id = row["Transcript"]

            # Writing only standard BED columns; additional columns can be appended as needed
            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")


def create_csv_from_df(df, output_path):
    df.to_csv(output_path, index=False)
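
A minimal scoring sketch may clarify the expected input: find_crispr_targets emits 23-nt sites (20-nt protospacer plus the NGG PAM), matching the (23, 4) input shape of the attention model. The weights path and the example site are placeholders.

# Minimal usage sketch (illustration only): score a single 23-mer directly.
model = MultiHeadAttention_model(input_shape=(23, 4))
model.load_weights("models/Cas9_MultiHeadAttention_weights.h5")  # hypothetical weights file

site = "GTCACCTCCAATGACTAGGGTGG"  # 20-nt protospacer + 'TGG' PAM (23 nt total)
x = get_seqcode(site)            # one-hot tensor of shape (1, 23, 4)
score = float(model.predict(x, verbose=0)[0][0])
print(f"predicted on-target efficiency: {score:.2f}")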
cas9attvcf.py
ADDED
1 |
+
import requests
|
2 |
+
import tensorflow as tf
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from operator import add
|
6 |
+
from functools import reduce
|
7 |
+
import random
|
8 |
+
import tabulate
|
9 |
+
|
10 |
+
from keras import Model
|
11 |
+
from keras import regularizers
|
12 |
+
from keras.optimizers import Adam
|
13 |
+
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
|
14 |
+
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
|
15 |
+
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
|
16 |
+
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
|
17 |
+
from keras.models import load_model
|
18 |
+
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
19 |
+
from Bio import SeqIO
|
20 |
+
from Bio.SeqRecord import SeqRecord
|
21 |
+
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
22 |
+
from Bio.Seq import Seq
|
23 |
+
|
24 |
+
import cyvcf2
|
25 |
+
import parasail
|
26 |
+
|
27 |
+
import re
|
28 |
+
|
29 |
+
ntmap = {'A': (1, 0, 0, 0),
|
30 |
+
'C': (0, 1, 0, 0),
|
31 |
+
'G': (0, 0, 1, 0),
|
32 |
+
'T': (0, 0, 0, 1)
|
33 |
+
}
|
34 |
+
|
35 |
+
def get_seqcode(seq):
|
36 |
+
return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
|
37 |
+
|
38 |
+
class PositionalEncoding(Layer):
|
39 |
+
def __init__(self, sequence_len=None, embedding_dim=None,**kwargs):
|
40 |
+
super(PositionalEncoding, self).__init__()
|
41 |
+
self.sequence_len = sequence_len
|
42 |
+
self.embedding_dim = embedding_dim
|
43 |
+
|
44 |
+
def call(self, x):
|
45 |
+
|
46 |
+
position_embedding = np.array([
|
47 |
+
[pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
|
48 |
+
for pos in range(self.sequence_len)])
|
49 |
+
|
50 |
+
position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2]) # dim 2i
|
51 |
+
position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2]) # dim 2i+1
|
52 |
+
position_embedding = tf.cast(position_embedding, dtype=tf.float32)
|
53 |
+
|
54 |
+
return position_embedding+x
|
55 |
+
|
56 |
+
def get_config(self):
|
57 |
+
config = super().get_config().copy()
|
58 |
+
config.update({
|
59 |
+
'sequence_len' : self.sequence_len,
|
60 |
+
'embedding_dim' : self.embedding_dim,
|
61 |
+
})
|
62 |
+
return config
|
63 |
+

def MultiHeadAttention_model(input_shape):
    input = Input(shape=input_shape)

    conv1 = Conv1D(256, 3, activation="relu")(input)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)

    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)

    lstm = Bidirectional(LSTM(128,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)

    pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)

    flat = Flatten()(atten)

    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[input], outputs=[output])
    return model
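
# Shape check (derived from the layers above, for a 23-nt input): conv1
# ('valid', kernel 3) leaves 23-3+1 = 21 positions, pooled to 10; conv2 leaves
# 10-3+1 = 8, pooled to 4 -- hence sequence_len = int(((23-3+1)/2-3+1)/2) = 4.
# A hypothetical build step (not part of the upload) might look like:
#   model = MultiHeadAttention_model(input_shape=(23, 4))
#   model.compile(optimizer='adam', loss='mse')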

def fetch_ensembl_transcripts(gene_symbol):
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        gene_data = response.json()
        if 'Transcript' in gene_data:
            return gene_data['Transcript']
        else:
            print("No transcripts found for gene:", gene_symbol)
            return None
    else:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

def fetch_ensembl_sequence(transcript_id):
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    response = requests.get(url)
    if response.status_code == 200:
        sequence_data = response.json()
        if 'seq' in sequence_data:
            return sequence_data['seq']
        else:
            print("No sequence found for transcript:", transcript_id)
            return None
    else:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None
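
# Usage sketch (the gene symbol and ID below are examples, not from this repo):
#   fetch_ensembl_transcripts("TP53")          # -> list of transcript dicts, each with an 'Exon' list
#   fetch_ensembl_sequence("ENST00000269305")  # -> plain sequence string
# Both helpers return None (and print a message) on any API failure.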

def apply_mutation(ref_sequence, offset, ref, alt):
    """
    Apply a single mutation to the sequence.
    """
    if len(ref) == len(alt) and alt != "*":  # SNP
        mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(alt):]

    elif len(ref) < len(alt):  # Insertion (replaces one base, assuming the standard single-base VCF anchor in ref)
        mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+1:]

    elif len(ref) == len(alt) and alt == "*":  # Deletion signalled by the VCF '*' (overlapping-deletion) allele
        mutated_seq = ref_sequence[:offset] + ref_sequence[offset+1:]

    elif len(ref) > len(alt) and alt != "*":  # Deletion
        mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(ref):]

    elif len(ref) > len(alt) and alt == "*":  # Deletion
        mutated_seq = ref_sequence[:offset] + ref_sequence[offset+len(ref):]

    return mutated_seq
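
# Worked examples (illustrative):
#   apply_mutation("AAACCC", 3, "C", "G")    # SNP       -> "AAAGCC"
#   apply_mutation("AAACCC", 2, "A", "AT")   # insertion -> "AAATCCC"
#   apply_mutation("AAACCC", 1, "AAC", "A")  # deletion  -> "AACC"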

def construct_combinations(sequence, mutations):
    """
    Construct all combinations of mutations.
    mutations is a list of tuples (position, ref, [alts])
    """
    if not mutations:
        return [sequence]

    # Take the first mutation and recursively construct combinations for the rest
    first_mutation = mutations[0]
    rest_mutations = mutations[1:]
    offset, ref, alts = first_mutation

    sequences = []
    for alt in alts:
        mutated_sequence = apply_mutation(sequence, offset, ref, alt)
        sequences.extend(construct_combinations(mutated_sequence, rest_mutations))

    return sequences
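
# Illustrative example: two variant sites expand into the cross product of
# their ALT alleles, e.g.
#   construct_combinations("AAACCC", [(1, "A", ["T", "G"]), (4, "C", ["G"])])
#   # -> ["ATACGC", "AGACGC"]
# Caveat (observation, not a fix): every listed site is always mutated (the
# reference allele is not kept as an option), and offsets are not re-anchored
# after an indel shifts the downstream sequence.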

def needleman_wunsch_alignment(query_seq, ref_seq):
    """
    Use Needleman-Wunsch alignment to find the best-matching position of query_seq in ref_seq.
    That position stands in for the position of a target sequence carrying mutations.
    """
    # Needleman-Wunsch (global) alignment; gap open 10, gap extend 1.
    # parasail.blosum62 is a protein matrix, but parasail also accepts it for ACGT strings.
    alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)

    # extract the CIGAR string
    cigar = alignment.cigar
    cigar_string = cigar.decode.decode("utf-8")

    # record ref_pos
    ref_pos = 0

    # locate the longest exact-match run ('=' operation) in the CIGAR
    matches = re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)
    max_num_before_equal = 0
    max_equal_index = -1
    total_before_max_equal = 0

    for i, (num_str, op) in enumerate(matches):
        num = int(num_str)
        if op == '=':
            if num > max_num_before_equal:
                max_num_before_equal = num
                max_equal_index = i

    # the reference position is the summed length of all CIGAR operations
    # preceding that longest exact-match run
    total_before_max_equal = sum(int(matches[j][0]) for j in range(max_equal_index))

    ref_pos = total_before_max_equal

    return ref_pos

def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
    # initialization
    mutated_sequences = [ref_sequence]

    # find mutations within the region of interest
    mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
    if mutations:
        # collect mutations
        mutation_list = []
        for mutation in mutations:
            offset = mutation.POS - start
            ref = mutation.REF
            alts = mutation.ALT[:-1]  # note: drops the final ALT allele; assumes a trailing placeholder entry in the VCF
            mutation_list.append((offset, ref, alts))

        # apply the mutations to the reference sequence
        mutated_sequences = construct_combinations(ref_sequence, mutation_list)

    # find gRNAs in the reference sequence and in all mutated sequences
    targets = []
    for seq in mutated_sequences:
        len_sequence = len(seq)
        dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
        for i in range(len_sequence - len(pam) + 1):
            if seq[i + 1:i + 3] == pam[1:]:
                if i >= target_length:
                    target_seq = seq[i - target_length:i + 3]
                    pos = ref_sequence.find(target_seq)
                    if pos != -1:
                        is_mut = False
                        if strand == -1:
                            tar_start = end - pos - target_length - 2
                        else:
                            tar_start = start + pos
                    else:
                        is_mut = True
                        nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
                        if strand == -1:
                            tar_start = str(end - nw_pos - target_length - 2) + '*'
                        else:
                            tar_start = str(start + nw_pos) + '*'
                    gRNA = ''.join([dnatorna[base] for base in seq[i - target_length:i]])
                    targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start), transcript_id, exon_id, gene_symbol, is_mut])

    # filter duplicated targets
    unique_targets_set = set(tuple(element) for element in targets)
    unique_targets = [list(element) for element in unique_targets_set]

    return unique_targets

def format_prediction_output_with_mutation(targets, model_path):
    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    formatted_data = []

    for target in targets:
        # Encode the 23-nt target sequence (20-nt protospacer + PAM)
        encoded_seq = get_seqcode(target[0])

        # Predict on-target efficiency using the model
        prediction = float(list(model.predict(encoded_seq, verbose=0)[0])[0])
        if prediction > 100:
            prediction = 100

        # Format output
        gRNA = target[1]
        exon_chr = target[2]
        strand = target[3]
        tar_start = target[4]
        transcript_id = target[5]
        exon_id = target[6]
        gene_symbol = target[7]
        is_mut = target[8]
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
                               exon_id, target[0], gRNA, prediction, is_mut])

    return formatted_data


def process_gene(gene_symbol, vcf_reader, model_path):
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    all_exons = []  # To accumulate all exons
    all_gene_sequences = []  # To accumulate all gene sequences

    if transcripts:
        for transcript in transcripts:
            Exons = transcript['Exon']
            all_exons.extend(Exons)  # Add all exons from this transcript to the list
            transcript_id = transcript['id']

            for Exon in Exons:
                exon_id = Exon['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)  # Reference exon sequence
                if gene_sequence:
                    all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
                    exon_chr = Exon['seq_region_name']
                    start = Exon['start']
                    end = Exon['end']
                    strand = Exon['strand']

                    targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand,
                                                      transcript_id, exon_id, gene_symbol, vcf_reader)
                    if targets:
                        # Predict on-target efficiency for each gRNA site, including mutated ones
                        formatted_data = format_prediction_output_with_mutation(targets, model_path)
                        results.extend(formatted_data)
                else:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
    else:
        print("Failed to retrieve transcripts.")

    # Return the predictions, the collected exon sequences, and all exons
    return results, all_gene_sequences, all_exons
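
# Usage sketch (the VCF path, weights path, and the "Gene" column name are
# assumptions; the remaining column names are the ones consumed by
# create_bed_file_from_df below):
#   vcf_reader = cyvcf2.VCF("variants.vcf.gz")  # indexed VCF with matching contig names
#   results, seqs, exons = process_gene("TP53", vcf_reader, "cas9_model/weights.h5")
#   df = pd.DataFrame(results, columns=["Gene", "Chr", "Strand", "Target Start",
#                                       "Transcript", "Exon", "Target", "gRNA",
#                                       "Prediction", "Is Mutation"])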

def create_genbank_features(data):
    features = []

    # If the input data is a DataFrame, convert it to a list of lists
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    # Row layout (see format_prediction_output_with_mutation):
    # [gene, chr, strand, target_start, transcript, exon, target_seq, gRNA, prediction, is_mutation]
    for row in formatted_data:
        try:
            start = int(row[3])  # target start lives at index 3 (index 1 is the chromosome name)
            end = start + len(row[6])  # end position follows from the target sequence length
        except ValueError as e:
            # mutated targets carry a '*' suffix on their start coordinate and are skipped here
            print(f"Error converting start to int: {row[3]} - {e}")
            continue

        strand = 1 if row[2] == '1' else -1  # strand lives at index 2
        location = FeatureLocation(start=start, end=end, strand=strand)
        is_mutation = 'Yes' if row[9] else 'No'
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use gRNA as the label
            'note': f"Prediction: {row[8]}, Mutation: {is_mutation}"  # Include the prediction score and mutation status
        })
        features.append(feature)

    return features

def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    # Ensure gene_sequence is a string before creating the Seq object
    if not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)

    seq_obj = Seq(gene_sequence)
    record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")

def create_bed_file_from_df(df, output_path):
    with open(output_path, 'w') as bed_file:
        for index, row in df.iterrows():
            chrom = row["Chr"]
            # strip the '*' marker that flags coordinates estimated for mutated targets
            start = int(str(row["Target Start"]).rstrip('*'))
            end = start + len(row["Target"])  # end position follows from the target sequence length
            strand = '+' if row["Strand"] == '1' else '-'
            gRNA = row["gRNA"]
            score = str(row["Prediction"])
            is_mutation = 'Yes' if row["Is Mutation"] else 'No'
            # transcript_id is not part of the standard BED columns; kept here for completeness
            transcript_id = row["Transcript"]

            # The first six columns are standard BED; the mutation flag is a nonstandard seventh column
            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\t{is_mutation}\n")

def create_csv_from_df(df, output_path):
    df.to_csv(output_path, index=False)
cas9off.py
ADDED
@@ -0,0 +1,134 @@
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import argparse

# configure GPUs
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')

class Encoder:
    def __init__(self, on_seq, off_seq, with_category=False, label=None, with_reg_val=False, value=None):
        tlen = 24
        self.on_seq = "-" * (tlen - len(on_seq)) + on_seq
        self.off_seq = "-" * (tlen - len(off_seq)) + off_seq
        self.encoded_dict_indel = {'A': [1, 0, 0, 0, 0], 'T': [0, 1, 0, 0, 0],
                                   'G': [0, 0, 1, 0, 0], 'C': [0, 0, 0, 1, 0], '_': [0, 0, 0, 0, 1], '-': [0, 0, 0, 0, 0]}
        self.direction_dict = {'A': 5, 'G': 4, 'C': 3, 'T': 2, '_': 1}
        if with_category:
            self.label = label
        if with_reg_val:
            self.value = value
        self.encode_on_off_dim7()

    def encode_sgRNA(self):
        code_list = []
        encoded_dict = self.encoded_dict_indel
        sgRNA_bases = list(self.on_seq)
        for i in range(len(sgRNA_bases)):
            if sgRNA_bases[i] == "N":
                sgRNA_bases[i] = list(self.off_seq)[i]
            code_list.append(encoded_dict[sgRNA_bases[i]])
        self.sgRNA_code = np.array(code_list)

    def encode_off(self):
        code_list = []
        encoded_dict = self.encoded_dict_indel
        off_bases = list(self.off_seq)
        for i in range(len(off_bases)):
            code_list.append(encoded_dict[off_bases[i]])
        self.off_code = np.array(code_list)

    def encode_on_off_dim7(self):
        self.encode_sgRNA()
        self.encode_off()
        on_bases = list(self.on_seq)
        off_bases = list(self.off_seq)
        on_off_dim7_codes = []
        for i in range(len(on_bases)):
            diff_code = np.bitwise_or(self.sgRNA_code[i], self.off_code[i])
            on_b = on_bases[i]
            off_b = off_bases[i]
            if on_b == "N":
                on_b = off_b
            dir_code = np.zeros(2)
            if on_b == "-" or off_b == "-" or self.direction_dict[on_b] == self.direction_dict[off_b]:
                pass
            else:
                if self.direction_dict[on_b] > self.direction_dict[off_b]:
                    dir_code[0] = 1
                else:
                    dir_code[1] = 1
            on_off_dim7_codes.append(np.concatenate((diff_code, dir_code)))
        self.on_off_code = np.array(on_off_dim7_codes)
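
# Encoding note: each of the 24 aligned positions becomes a 7-dim vector. The
# first 5 channels are the bitwise OR of the on- and off-target one-hot codes
# (a mismatch therefore lights up two channels at once), and the last 2 channels
# mark the mismatch direction via the A>G>C>T>_ ranking in direction_dict.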

def encode_on_off_seq_pairs(input_file):
    inputs = pd.read_csv(input_file, delimiter=",", header=None, names=['on_seq', 'off_seq'])
    input_codes = []
    for idx, row in inputs.iterrows():
        on_seq = row['on_seq']
        off_seq = row['off_seq']
        en = Encoder(on_seq=on_seq, off_seq=off_seq)
        input_codes.append(en.on_off_code)
    input_codes = np.array(input_codes)
    input_codes = input_codes.reshape((len(input_codes), 1, 24, 7))
    y_pred = CRISPR_net_predict(input_codes)
    inputs['CRISPR_Net_score'] = y_pred
    inputs.to_csv("CRISPR_net_results.csv", index=False)

def CRISPR_net_predict(X_test):
    with open("cas9_model/CRISPR_Net_CIRCLE_elevation_SITE_structure.json", 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = tf.keras.models.model_from_json(loaded_model_json)  # Updated for TensorFlow 2
    loaded_model.load_weights("cas9_model/CRISPR_Net_CIRCLE_elevation_SITE_weights.h5")
    y_pred = loaded_model.predict(X_test).flatten()
    return y_pred

def process_input_and_predict(input_data, input_type='manual'):
    if input_type == 'manual':
        sequences = [seq.split(',') for seq in input_data.split('\n')]
        inputs = pd.DataFrame(sequences, columns=['on_seq', 'off_seq'])
    elif input_type == 'file':
        inputs = pd.read_csv(input_data, delimiter=",", header=None, names=['on_seq', 'off_seq'])

    valid_inputs = []
    input_codes = []
    for idx, row in inputs.iterrows():
        on_seq = row['on_seq']
        off_seq = row['off_seq']
        if not on_seq or not off_seq:
            continue

        en = Encoder(on_seq=on_seq, off_seq=off_seq)
        input_codes.append(en.on_off_code)
        valid_inputs.append((on_seq, off_seq))

    input_codes = np.array(input_codes)
    input_codes = input_codes.reshape((len(input_codes), 1, 24, 7))

    y_pred = CRISPR_net_predict(input_codes)

    # Create a new DataFrame from the valid inputs and predictions
    result_df = pd.DataFrame(valid_inputs, columns=['on_seq', 'off_seq'])
    result_df['CRISPR_Net_score'] = y_pred

    return result_df
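
# Usage sketch (the sequence pair is taken from the argparse help text below):
#   df = process_input_and_predict(
#       "GAGT_CCGAGCAGAAGAAGAATGG,GAGTACCAAGTAGAAGAAAAATTT", input_type='manual')
#   # -> DataFrame with columns ['on_seq', 'off_seq', 'CRISPR_Net_score']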

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="CRISPR-Net v1.0 (Aug 10 2019)")
    parser.add_argument("input_file",
                        help="input_file example (on-target seq, off-target seq):\n GAGT_CCGAGCAGAAGAAGAATGG,GAGTACCAAGTAGAAGAAAAATTT\n"
                             "GTTGCCCCACAGGGCAGTAAAGG,GTGGACACCCCGGGCAGGAAAGG\n"
                             "GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG")
    args = parser.parse_args()
    file = args.input_file
    if not os.path.exists(args.input_file):
        print("File doesn't exist!")
    else:
        encode_on_off_seq_pairs(file)
    tf.keras.backend.clear_session()
tiger.py
ADDED
@@ -0,0 +1,417 @@
import argparse
import os
import gzip
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from Bio import SeqIO

# column names
ID_COL = 'Transcript ID'
SEQ_COL = 'Transcript Sequence'
TARGET_COL = 'Target Sequence'
GUIDE_COL = 'Guide Sequence'
MM_COL = 'Number of Mismatches'
SCORE_COL = 'Guide Score'

# nucleotide tokens
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))

# model hyper-parameters
GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
UNIT_INTERVAL_MAP = 'sigmoid'

# reference transcript files
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')

# application configuration
BATCH_SIZE_COMPUTE = 500
BATCH_SIZE_SCAN = 20
BATCH_SIZE_TRANSCRIPTS = 50
NUM_TOP_GUIDES = 10
NUM_MISMATCHES = 3
RUN_MODES = dict(
    all='All on-target guides per transcript',
    top_guides='Top {:d} guides per transcript'.format(NUM_TOP_GUIDES),
    titration='Top {:d} guides per transcript & their titration candidates'.format(NUM_TOP_GUIDES)
)


# configure GPUs
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')


def load_transcripts(fasta_files: list, enforce_unique_ids: bool = True):

    # load all transcripts from fasta files into a DataFrame
    transcripts = pd.DataFrame()
    for file in fasta_files:
        try:
            if os.path.splitext(file)[1] == '.gz':
                with gzip.open(file, 'rt') as f:
                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=[ID_COL, SEQ_COL])
            else:
                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=[ID_COL, SEQ_COL])
        except Exception as e:
            print(e, 'while loading', file)
            continue
        transcripts = pd.concat([transcripts, df])

    # set index
    transcripts[ID_COL] = transcripts[ID_COL].apply(lambda s: s.split('|')[0])
    transcripts.set_index(ID_COL, inplace=True)
    if enforce_unique_ids:
        assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected in fasta file"

    return transcripts


def sequence_complement(sequence: list):
    return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]
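
# Illustrative example: sequence_complement complements without reversing,
#   sequence_complement(['ACGT'])  # -> ['TGCA']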


def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):

    # stack list of sequences into a tensor
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)

    # tokenize sequence
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)

    # add context padding if requested
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)

    # one-hot encode
    sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)

    return sequence


def process_data(transcript_seq: str):

    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # get all target sites
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # prepare guide sequences
    guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])

    # model inputs
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
        tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
    ], axis=-1)

    return target_seq, guide_seq, model_inputs
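
# Shape note (from the constants above): TARGET_LEN = 3 + 23 + 0 = 26, so each
# target flattens to 26*4 = 104 features; each 23-nt guide is padded back to 26
# positions and contributes another 104, giving model_inputs of shape
# (num_sites, 208).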


def calibrate_predictions(predictions: np.array, num_mismatches: np.array, params: pd.DataFrame = None):
    if params is None:
        params = pd.read_pickle('calibration_params.pkl')
    correction = np.squeeze(params.set_index('num_mismatches').loc[num_mismatches, 'slope'].to_numpy())
    return correction * predictions


def score_predictions(predictions: np.array, params: pd.DataFrame = None):
    if params is None:
        params = pd.read_pickle('scoring_params.pkl')

    if UNIT_INTERVAL_MAP == 'sigmoid':
        params = params.iloc[0]
        return 1 - 1 / (1 + np.exp(params['a'] * predictions + params['b']))

    elif UNIT_INTERVAL_MAP == 'min-max':
        return 1 - (predictions - params['a']) / (params['b'] - params['a'])

    elif UNIT_INTERVAL_MAP == 'exp-lin-exp':
        # regime indices
        active_saturation = predictions < params['a']
        linear_regime = (params['a'] <= predictions) & (predictions <= params['c'])
        inactive_saturation = params['c'] < predictions

        # linear regime
        slope = (params['d'] - params['b']) / (params['c'] - params['a'])
        intercept = -params['a'] * slope + params['b']
        predictions[linear_regime] = slope * predictions[linear_regime] + intercept

        # active saturation regime
        alpha = slope / params['b']
        beta = alpha * params['a'] - np.log(params['b'])
        predictions[active_saturation] = np.exp(alpha * predictions[active_saturation] - beta)

        # inactive saturation regime
        alpha = slope / (1 - params['d'])
        beta = -alpha * params['c'] - np.log(1 - params['d'])
        predictions[inactive_saturation] = 1 - np.exp(-alpha * predictions[inactive_saturation] - beta)

        return 1 - predictions

    else:
        raise NotImplementedError
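
# Note on the default 'sigmoid' map: a raw log-fold-change estimate x becomes
#   score = 1 - 1 / (1 + exp(a*x + b)),
# a logistic squashing of the regression output onto the unit interval, with
# the fitted a and b taken from scoring_params.pkl.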


def get_on_target_predictions(transcripts: pd.DataFrame, model: tf.keras.Model, status_update_fn=None):

    # loop over transcripts
    predictions = pd.DataFrame()
    for i, (index, row) in enumerate(transcripts.iterrows()):

        # parse transcript sequence
        target_seq, guide_seq, model_inputs = process_data(row[SEQ_COL])

        # get predictions
        lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
        lfc_estimate = calibrate_predictions(lfc_estimate, num_mismatches=np.zeros_like(lfc_estimate))
        scores = score_predictions(lfc_estimate)
        predictions = pd.concat([predictions, pd.DataFrame({
            ID_COL: [index] * len(scores),
            TARGET_COL: target_seq,
            GUIDE_COL: guide_seq,
            SCORE_COL: scores})])

        # progress update
        percent_complete = 100 * min((i + 1) / len(transcripts), 1)
        update_text = 'Evaluating on-target guides for each transcript: {:.2f}%'.format(percent_complete)
        print('\r' + update_text, end='')
        if status_update_fn is not None:
            status_update_fn(update_text, percent_complete)
    print('')

    return predictions


def top_guides_per_transcript(predictions: pd.DataFrame):

    # select and sort the top guides for each transcript
    top_guides = pd.DataFrame()
    for transcript in predictions[ID_COL].unique():
        df = predictions.loc[predictions[ID_COL] == transcript]
        df = df.sort_values(SCORE_COL, ascending=False).reset_index(drop=True).iloc[:NUM_TOP_GUIDES]
        top_guides = pd.concat([top_guides, df])

    return top_guides.reset_index(drop=True)
def get_titration_candidates(top_guide_predictions: pd.DataFrame):

    # generate a table of all titration candidates
    titration_candidates = pd.DataFrame()
    for _, row in top_guide_predictions.iterrows():
        for i in range(len(row[GUIDE_COL])):
            nt = row[GUIDE_COL][i]
            for mutation in set(NUCLEOTIDE_TOKENS.keys()) - {nt, 'N'}:
                sm_guide = list(row[GUIDE_COL])
                sm_guide[i] = mutation
                sm_guide = ''.join(sm_guide)
                assert row[GUIDE_COL] != sm_guide
                titration_candidates = pd.concat([titration_candidates, pd.DataFrame({
                    ID_COL: [row[ID_COL]],
                    TARGET_COL: [row[TARGET_COL]],
                    GUIDE_COL: [sm_guide],
                    MM_COL: [1]
                })])

    return titration_candidates
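
# Count note: each guide has GUIDE_LEN = 23 positions and 3 alternative bases
# per position, so every top guide yields 23 * 3 = 69 single-mismatch titration
# candidates (MM_COL is fixed at 1 for all of them).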
def find_off_targets(top_guides: pd.DataFrame, status_update_fn=None):

    # load reference transcripts
    reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])

    # one-hot encode guides to form a filter
    guide_filter = one_hot_encode_sequence(sequence_complement(top_guides[GUIDE_COL]), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])

    # loop over transcripts in batches
    i = 0
    off_targets = pd.DataFrame()
    while i < len(reference_transcripts):
        # select batch
        df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
        i += BATCH_SIZE_SCAN

        # find locations of off-targets
        transcripts = one_hot_encode_sequence(df_batch[SEQ_COL].values.tolist(), add_context_padding=False)
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()

        # off-targets discovered
        if len(loc_off_targets) > 0:

            # log off-targets
            dict_off_targets = pd.DataFrame({
                'On-target ' + ID_COL: top_guides.iloc[loc_off_targets[:, 2]][ID_COL],
                GUIDE_COL: top_guides.iloc[loc_off_targets[:, 2]][GUIDE_COL],
                'Off-target ' + ID_COL: df_batch.index.values[loc_off_targets[:, 0]],
                'Guide Midpoint': loc_off_targets[:, 1],
                SEQ_COL: df_batch[SEQ_COL].values[loc_off_targets[:, 0]],
                MM_COL: tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
            }).to_dict('records')

            # trim transcripts to targets
            for row in dict_off_targets:
                start_location = row['Guide Midpoint'] - (GUIDE_LEN // 2)
                del row['Guide Midpoint']
                target = row[SEQ_COL]
                del row[SEQ_COL]
                if start_location < CONTEXT_5P:
                    target = target[0:GUIDE_LEN + CONTEXT_3P]
                    target = 'N' * (TARGET_LEN - len(target)) + target
                elif start_location + GUIDE_LEN + CONTEXT_3P > len(target):
                    target = target[start_location - CONTEXT_5P:]
                    target = target + 'N' * (TARGET_LEN - len(target))
                else:
                    target = target[start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
                if row[MM_COL] == 0 and 'N' not in target:
                    assert row[GUIDE_COL] == sequence_complement([target[CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]
                row[TARGET_COL] = target

            # append new off-targets
            off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])

        # progress update
        percent_complete = 100 * min((i + 1) / len(reference_transcripts), 1)
        update_text = 'Scanning for off-targets: {:.2f}%'.format(percent_complete)
        print('\r' + update_text, end='')
        if status_update_fn is not None:
            status_update_fn(update_text, percent_complete)
    print('')

    return off_targets
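
# How the scan works: the complemented guides are stacked as conv1d filters, so
# the convolution at each position counts bases that match between guide and
# transcript; GUIDE_LEN minus that correlation gives the mismatch count, and
# any window with at most NUM_MISMATCHES (3) mismatches is kept as a candidate.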
def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
    if len(off_targets) == 0:
        return pd.DataFrame()

    # compute off-target predictions
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets[TARGET_COL], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets[GUIDE_COL], add_context_padding=True), [len(off_targets), -1]),
    ], axis=-1)
    lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
    lfc_estimate = calibrate_predictions(lfc_estimate, off_targets[MM_COL].to_numpy())
    off_targets[SCORE_COL] = score_predictions(lfc_estimate)

    return off_targets.reset_index(drop=True)
def tiger_exhibit(transcripts: pd.DataFrame, mode: str, check_off_targets: bool, status_update_fn=None):

    # load model
    if os.path.exists('cas13_model'):
        tiger = tf.keras.models.load_model('cas13_model')
    else:
        print('no saved model!')
        exit()

    # evaluate all on-target guides per transcript
    on_target_predictions = get_on_target_predictions(transcripts, tiger, status_update_fn)

    # initialize other outputs
    titration_predictions = off_target_predictions = None

    if mode == 'all' and not check_off_targets:
        off_target_candidates = None

    elif mode == 'top_guides':
        on_target_predictions = top_guides_per_transcript(on_target_predictions)
        off_target_candidates = on_target_predictions

    elif mode == 'titration':
        on_target_predictions = top_guides_per_transcript(on_target_predictions)
        titration_candidates = get_titration_candidates(on_target_predictions)
        titration_predictions = predict_off_target(titration_candidates, model=tiger)
        off_target_candidates = pd.concat([on_target_predictions, titration_predictions])

    else:
        raise NotImplementedError

    # check off-target effects for top guides
    if check_off_targets and off_target_candidates is not None:
        off_target_candidates = find_off_targets(off_target_candidates, status_update_fn)
        off_target_predictions = predict_off_target(off_target_candidates, model=tiger)
        if len(off_target_predictions) > 0:
            off_target_predictions = off_target_predictions.sort_values(SCORE_COL, ascending=False)
            off_target_predictions = off_target_predictions.reset_index(drop=True)

    # finalize tables
    for df in [on_target_predictions, titration_predictions, off_target_predictions]:
        if df is not None and len(df) > 0:
            for col in df.columns:
                if ID_COL in col and set(df[col].unique()) == {'ManualEntry'}:
                    del df[col]
            df[GUIDE_COL] = df[GUIDE_COL].apply(lambda s: s[::-1])  # reverse guide sequences
            df[TARGET_COL] = df[TARGET_COL].apply(lambda seq: seq[CONTEXT_5P:len(seq) - CONTEXT_3P])  # remove context

    return on_target_predictions, titration_predictions, off_target_predictions
if __name__ == '__main__':

    # common arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='titration')
    parser.add_argument('--check_off_targets', action='store_true', default=False)
    parser.add_argument('--fasta_path', type=str, default=None)
    args = parser.parse_args()

    # check for any existing results
    if os.path.exists('on_target.csv') or os.path.exists('titration.csv') or os.path.exists('off_target.csv'):
        raise FileExistsError('please rename or delete existing results')

    # load transcripts from a directory of fasta files
    if args.fasta_path is not None and os.path.exists(args.fasta_path):
        df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])

    # otherwise fall back to a simple test case: the first 50 nucleotides of EIF3B-003's CDS
    else:
        df_transcripts = pd.DataFrame({
            ID_COL: ['ManualEntry'],
            SEQ_COL: ['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']})
        df_transcripts.set_index(ID_COL, inplace=True)

    # process in batches
    batch = 0
    num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
    num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
    for idx in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
        batch += 1
        print('Batch {:d} of {:d}'.format(batch, num_batches))

        # run batch
        idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
        df_on_target, df_titration, df_off_target = tiger_exhibit(
            transcripts=df_transcripts[idx:idx_stop],
            mode=args.mode,
            check_off_targets=args.check_off_targets
        )

        # save batch results
        df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
        if df_titration is not None:
            df_titration.to_csv('titration.csv', header=batch == 1, index=False, mode='a')
        if df_off_target is not None:
            df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')

        # clear the session to keep memory bounded across batches
        tf.keras.backend.clear_session()
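
# CLI sketch (the fasta directory name is hypothetical; the flags are defined above):
#   python tiger.py --mode top_guides --check_off_targets --fasta_path ./fastas
# Valid --mode values are the RUN_MODES keys: 'all', 'top_guides', 'titration'.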