LfOreVEr committed
Commit 4bb4454 · verified · 1 Parent(s): 566a054

Upload 7 files

Files changed (7):
  1. cas12.py +220 -0
  2. cas12lstm.py +258 -0
  3. cas12lstmvcf.py +351 -0
  4. cas9att.py +296 -0
  5. cas9attvcf.py +393 -0
  6. cas9off.py +134 -0
  7. tiger.py +417 -0
cas12.py ADDED
@@ -0,0 +1,220 @@
+ from keras import Model
+ from keras.layers import Input
+ from keras.layers import Multiply
+ from keras.layers import Dense, Dropout, Activation, Flatten
+ from keras.layers import Convolution1D, AveragePooling1D
+ import pandas as pd
+ import numpy as np
+ import keras
+ import requests
+ from functools import reduce
+ from operator import add
+ from Bio.SeqRecord import SeqRecord
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
+ from Bio.Seq import Seq
+ from Bio import SeqIO
+
+ ntmap = {'A': (1, 0, 0, 0),
+          'C': (0, 1, 0, 0),
+          'G': (0, 0, 1, 0),
+          'T': (0, 0, 0, 1)
+          }
+
+ def get_seqcode(seq):
+     return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
+
+ def Seq_DeepCpf1_model(input_shape):
+     Seq_deepCpf1_Input_SEQ = Input(shape=input_shape)
+     Seq_deepCpf1_C1 = Convolution1D(80, 5, activation='relu')(Seq_deepCpf1_Input_SEQ)
+     Seq_deepCpf1_P1 = AveragePooling1D(2)(Seq_deepCpf1_C1)
+     Seq_deepCpf1_F = Flatten()(Seq_deepCpf1_P1)
+     Seq_deepCpf1_DO1 = Dropout(0.3)(Seq_deepCpf1_F)
+     Seq_deepCpf1_D1 = Dense(80, activation='relu')(Seq_deepCpf1_DO1)
+     Seq_deepCpf1_DO2 = Dropout(0.3)(Seq_deepCpf1_D1)
+     Seq_deepCpf1_D2 = Dense(40, activation='relu')(Seq_deepCpf1_DO2)
+     Seq_deepCpf1_DO3 = Dropout(0.3)(Seq_deepCpf1_D2)
+     Seq_deepCpf1_D3 = Dense(40, activation='relu')(Seq_deepCpf1_DO3)
+     Seq_deepCpf1_DO4 = Dropout(0.3)(Seq_deepCpf1_D3)
+     Seq_deepCpf1_Output = Dense(1, activation='linear')(Seq_deepCpf1_DO4)
+     Seq_deepCpf1 = Model(inputs=[Seq_deepCpf1_Input_SEQ], outputs=[Seq_deepCpf1_Output])
+     return Seq_deepCpf1
+
+ # seq-ca model (DeepCpf1)
+ def DeepCpf1_model(input_shape):
+     DeepCpf1_Input_SEQ = Input(shape=input_shape)
+     DeepCpf1_C1 = Convolution1D(80, 5, activation='relu')(DeepCpf1_Input_SEQ)
+     DeepCpf1_P1 = AveragePooling1D(2)(DeepCpf1_C1)
+     DeepCpf1_F = Flatten()(DeepCpf1_P1)
+     DeepCpf1_DO1 = Dropout(0.3)(DeepCpf1_F)
+     DeepCpf1_D1 = Dense(80, activation='relu')(DeepCpf1_DO1)
+     DeepCpf1_DO2 = Dropout(0.3)(DeepCpf1_D1)
+     DeepCpf1_D2 = Dense(40, activation='relu')(DeepCpf1_DO2)
+     DeepCpf1_DO3 = Dropout(0.3)(DeepCpf1_D2)
+     DeepCpf1_D3_SEQ = Dense(40, activation='relu')(DeepCpf1_DO3)
+     DeepCpf1_Input_CA = Input(shape=(1,))
+     DeepCpf1_D3_CA = Dense(40, activation='relu')(DeepCpf1_Input_CA)
+     DeepCpf1_M = Multiply()([DeepCpf1_D3_SEQ, DeepCpf1_D3_CA])
+     DeepCpf1_DO4 = Dropout(0.3)(DeepCpf1_M)
+     DeepCpf1_Output = Dense(1, activation='linear')(DeepCpf1_DO4)
+     DeepCpf1 = Model(inputs=[DeepCpf1_Input_SEQ, DeepCpf1_Input_CA], outputs=[DeepCpf1_Output])
+     return DeepCpf1
+
+ def fetch_ensembl_transcripts(gene_symbol):
+     url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         gene_data = response.json()
+         if 'Transcript' in gene_data:
+             return gene_data['Transcript']
+         else:
+             print("No transcripts found for gene:", gene_symbol)
+             return None
+     else:
+         print(f"Error fetching gene data from Ensembl: {response.text}")
+         return None
+
+ def fetch_ensembl_sequence(transcript_id):
+     url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         sequence_data = response.json()
+         if 'seq' in sequence_data:
+             return sequence_data['seq']
+         else:
+             print("No sequence found for transcript:", transcript_id)
+             return None
+     else:
+         print(f"Error fetching sequence data from Ensembl: {response.text}")
+         return None
+
+ def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
+     targets = []
+     len_sequence = len(sequence)
+     complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
+     dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
+
+     for i in range(len_sequence - target_length + 1):
+         target_seq = sequence[i:i + target_length]
+         if target_seq[4:7] == 'TTT':
+             if strand == -1:
+                 tar_start = end - i - target_length + 1
+                 tar_end = end - i
+                 #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
+             else:
+                 tar_start = start + i
+                 tar_end = start + i + target_length - 1
+                 #seq_in_ref = target_seq
+             gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
+             targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
+     return targets
+
+ def format_prediction_output(targets, model_path):
+     # Loading weights for the model
+     Seq_deepCpf1 = Seq_DeepCpf1_model(input_shape=(34, 4))
+     Seq_deepCpf1.load_weights(model_path)
+
+     formatted_data = []
+     for target in targets:
+         # Predict
+         encoded_seq = get_seqcode(target[0])
+         prediction = float(list(Seq_deepCpf1.predict(encoded_seq)[0])[0])
+         if prediction > 100:
+             prediction = 100
+
+         # Format output
+         gRNA = target[1]
+         chr = target[2]
+         start = target[3]
+         end = target[4]
+         strand = target[5]
+         transcript_id = target[6]
+         exon_id = target[7]
+         formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])
+
+     return formatted_data
+
+ def process_gene(gene_symbol, model_path):
+     transcripts = fetch_ensembl_transcripts(gene_symbol)
+     results = []
+     all_exons = []
+     all_gene_sequences = []
+
+     if transcripts:
+         for i in range(len(transcripts)):
+             Exons = transcripts[i]['Exon']
+             all_exons.append(Exons)
+             transcript_id = transcripts[i]['id']
+             for j in range(len(Exons)):
+                 exon_id = Exons[j]['id']
+                 gene_sequence = fetch_ensembl_sequence(exon_id)
+                 if gene_sequence:
+                     all_gene_sequences.append(gene_sequence)
+                     start = Exons[j]['start']
+                     end = Exons[j]['end']
+                     strand = Exons[j]['strand']
+                     chr = Exons[j]['seq_region_name']
+                     targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
+                     if targets:
+                         formatted_data = format_prediction_output(targets, model_path)
+                         results.append(formatted_data)
+                         # for data in formatted_data:
+                         #     print(f"Chr: {data[0]}, Start: {data[1]}, End: {data[2]}, Strand: {data[3]}, target: {data[4]}, gRNA: {data[5]}, pred_Score: {data[6]}")
+                 else:
+                     print("Failed to retrieve gene sequence.")
+     else:
+         print("Failed to retrieve transcripts.")
+
+     return results, all_gene_sequences, all_exons
+
+
+ # def create_genbank_features(formatted_data):
+ #     features = []
+ #     for data in formatted_data:
+ #         try:
+ #             # Attempt to convert start and end positions to integers
+ #             start = int(data[1])
+ #             end = int(data[2])
+ #         except ValueError as e:
+ #             # Log the error and skip this iteration if conversion fails
+ #             print(f"Error converting start/end to int: {data[1]}, {data[2]} - {e}")
+ #             continue  # Skip this iteration
+ #
+ #         # Proceed as normal if conversion is successful
+ #         strand = 1 if data[3] == '+' else -1
+ #         location = FeatureLocation(start=start, end=end, strand=strand)
+ #         feature = SeqFeature(location=location, type="misc_feature", qualifiers={
+ #             'label': data[5],  # gRNA as label
+ #             'note': f"Prediction: {data[6]}"  # Prediction score in note
+ #         })
+ #         features.append(feature)
+ #     return features
+ #
+ # def generate_genbank_file_from_data(formatted_data, gene_sequence, gene_symbol, output_path):
+ #     features = create_genbank_features(formatted_data)
+ #     record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
+ #                        description='CRISPR Cas12 predicted targets', features=features)
+ #     record.annotations["molecule_type"] = "DNA"
+ #     SeqIO.write(record, output_path, "genbank")
+ #
+ # def create_csv_from_df(df, output_path):
+ #     df.to_csv(output_path, index=False)
+ #
+ # def generate_bed_file_from_data(formatted_data, output_path):
+ #     with open(output_path, 'w') as bed_file:
+ #         for data in formatted_data:
+ #             try:
+ #                 # Ensure data has the expected number of elements
+ #                 if len(data) < 7:
+ #                     raise ValueError("Incomplete data item")
+ #
+ #                 chrom = data[0]
+ #                 start = data[1]
+ #                 end = data[2]
+ #                 strand = '+' if data[3] == '+' else '-'
+ #                 gRNA = data[5]
+ #                 score = data[6]  # Ensure this index exists
+ #
+ #                 bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
+ #             except ValueError as e:
+ #                 print(f"Skipping an item due to error: {e}")
+ #                 continue
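
A minimal usage sketch for cas12.py, assuming a trained Seq-DeepCpf1 weights file; the gene symbol, weights path, and column labels below are placeholders rather than anything pinned by this upload. Note that this module's process_gene appends one list of rows per exon, so the nested results are flattened before tabulation:

    import pandas as pd
    from cas12 import process_gene

    # Hypothetical weights path and gene symbol.
    results, sequences, exons = process_gene("EGFR", "models/seq_deepcpf1_weights.h5")

    # process_gene returns one list of rows per exon; flatten before building a table.
    rows = [row for exon_rows in results for row in exon_rows]
    df = pd.DataFrame(rows, columns=["Chr", "Start Pos", "End Pos", "Strand",
                                     "Transcript", "Exon", "Target", "gRNA", "Prediction"])
    print(df.head())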
cas12lstm.py ADDED
@@ -0,0 +1,258 @@
+ import tensorflow as tf
+ from keras import regularizers
+ from keras.layers import Input, Dense, Dropout, Activation, Conv1D
+ from keras.layers import GlobalAveragePooling1D, AveragePooling1D
+ from keras.layers import Bidirectional, LSTM
+ from keras import Model
+ from keras.metrics import MeanSquaredError
+
+ import pandas as pd
+ import numpy as np
+
+ import requests
+ from functools import reduce
+ from operator import add
+ import tabulate
+ from difflib import SequenceMatcher
+ from Bio import SeqIO
+ from Bio.SeqRecord import SeqRecord
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
+ from Bio.Seq import Seq
+
+ import cyvcf2
+ import parasail
+
+ import re
+
+ ntmap = {'A': (1, 0, 0, 0),
+          'C': (0, 1, 0, 0),
+          'G': (0, 0, 1, 0),
+          'T': (0, 0, 0, 1)
+          }
+
+ def get_seqcode(seq):
+     return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
+
+ def BiLSTM_model(input_shape):
+     input = Input(shape=input_shape)
+
+     conv1 = Conv1D(128, 5, activation="relu")(input)
+     pool1 = AveragePooling1D(2)(conv1)
+     drop1 = Dropout(0.1)(pool1)
+
+     conv2 = Conv1D(128, 5, activation="relu")(drop1)
+     pool2 = AveragePooling1D(2)(conv2)
+     drop2 = Dropout(0.1)(pool2)
+
+     lstm1 = Bidirectional(LSTM(128,
+                                dropout=0.1,
+                                activation='tanh',
+                                return_sequences=True,
+                                kernel_regularizer=regularizers.l2(1e-4)))(drop2)
+     avgpool = GlobalAveragePooling1D()(lstm1)
+
+     dense1 = Dense(128,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(avgpool)
+     drop3 = Dropout(0.1)(dense1)
+
+     dense2 = Dense(32,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(drop3)
+     drop4 = Dropout(0.1)(dense2)
+
+     dense3 = Dense(32,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(drop4)
+     drop5 = Dropout(0.1)(dense3)
+
+     output = Dense(1, activation="linear")(drop5)
+
+     model = Model(inputs=[input], outputs=[output])
+     return model
+
+ def fetch_ensembl_transcripts(gene_symbol):
+     url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         gene_data = response.json()
+         if 'Transcript' in gene_data:
+             return gene_data['Transcript']
+         else:
+             print("No transcripts found for gene:", gene_symbol)
+             return None
+     else:
+         print(f"Error fetching gene data from Ensembl: {response.text}")
+         return None
+
+ def fetch_ensembl_sequence(transcript_id):
+     url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         sequence_data = response.json()
+         if 'seq' in sequence_data:
+             return sequence_data['seq']
+         else:
+             print("No sequence found for transcript:", transcript_id)
+             return None
+     else:
+         print(f"Error fetching sequence data from Ensembl: {response.text}")
+         return None
+
+ def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
+     targets = []
+     len_sequence = len(sequence)
+     #complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
+     dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
+
+     for i in range(len_sequence - target_length + 1):
+         target_seq = sequence[i:i + target_length]
+         if target_seq[4:7] == 'TTT':
+             if strand == -1:
+                 tar_start = end - i - target_length + 1
+                 tar_end = end - i
+                 #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
+             else:
+                 tar_start = start + i
+                 tar_end = start + i + target_length - 1
+                 #seq_in_ref = target_seq
+             gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
+             targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
+             #targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
+     return targets
+
+ def format_prediction_output(targets, model_path):
+     # Loading weights for the model
+     Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
+     Crispr_BiLSTM.load_weights(model_path)
+
+     formatted_data = []
+     for target in targets:
+         # Predict
+         encoded_seq = get_seqcode(target[0])
+         prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
+         if prediction > 100:
+             prediction = 100
+
+         # Format output
+         gRNA = target[1]
+         chr = target[2]
+         start = target[3]
+         end = target[4]
+         strand = target[5]
+         transcript_id = target[6]
+         exon_id = target[7]
+         #seq_in_ref = target[8]
+         #formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction])
+         formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])
+
+     return formatted_data
+
+
+ def process_gene(gene_symbol, model_path):
+     transcripts = fetch_ensembl_transcripts(gene_symbol)
+     results = []
+     all_exons = []  # To accumulate all exons
+     all_gene_sequences = []  # To accumulate all gene sequences
+
+     if transcripts:
+         for transcript in transcripts:
+             Exons = transcript['Exon']
+             all_exons.extend(Exons)  # Add all exons from this transcript to the list
+             transcript_id = transcript['id']
+
+             for Exon in Exons:
+                 exon_id = Exon['id']
+                 gene_sequence = fetch_ensembl_sequence(exon_id)
+                 if gene_sequence:
+                     all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
+                     chr = Exon['seq_region_name']
+                     start = Exon['start']
+                     end = Exon['end']
+                     strand = Exon['strand']
+
+                     targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
+                     if targets:
+                         # Predict on-target efficiency for each gRNA site
+                         formatted_data = format_prediction_output(targets, model_path)
+                         results.extend(formatted_data)  # Flatten the results
+                 else:
+                     print(f"Failed to retrieve gene sequence for exon {exon_id}.")
+     else:
+         print("Failed to retrieve transcripts.")
+
+     # results is already a flat list of rows; return it with the combined gene sequences and all exons
+     return results, all_gene_sequences, all_exons
+
+ def create_genbank_features(data):
+     features = []
+
+     # If the input data is a DataFrame, convert it to a list of lists
+     if isinstance(data, pd.DataFrame):
+         formatted_data = data.values.tolist()
+     elif isinstance(data, list):
+         formatted_data = data
+     else:
+         raise TypeError("Data should be either a list or a pandas DataFrame.")
+
+     for row in formatted_data:
+         try:
+             start = int(row[1])
+             end = int(row[2])
+         except ValueError as e:
+             print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
+             continue
+
+         strand = 1 if row[3] == '+' else -1
+         location = FeatureLocation(start=start, end=end, strand=strand)
+         feature = SeqFeature(location=location, type="misc_feature", qualifiers={
+             'label': row[7],  # Use gRNA as the label
+             'note': f"Prediction: {row[8]}"  # Include the prediction score
+         })
+         features.append(feature)
+
+     return features
+
+
+ def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
+     # Ensure gene_sequence is a string before creating Seq object
+     if not isinstance(gene_sequence, str):
+         gene_sequence = str(gene_sequence)
+
+     features = create_genbank_features(df)
+
+     # Now gene_sequence is guaranteed to be a string, suitable for Seq
+     seq_obj = Seq(gene_sequence)
+     record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
+                        description=f'CRISPR Cas12 predicted targets for {gene_symbol}', features=features)
+     record.annotations["molecule_type"] = "DNA"
+     SeqIO.write(record, output_path, "genbank")
+
+
+ def create_bed_file_from_df(df, output_path):
+     with open(output_path, 'w') as bed_file:
+         for index, row in df.iterrows():
+             chrom = row["Chr"]
+             start = int(row["Start Pos"])
+             end = int(row["End Pos"])
+             strand = '+' if row["Strand"] == '1' else '-'
+             gRNA = row["gRNA"]
+             score = str(row["Prediction"])
+             # transcript_id is not typically part of the standard BED columns but added here for completeness
+             transcript_id = row["Transcript"]
+
+             # Writing only standard BED columns; additional columns can be appended as needed
+             bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
+
+
+ def create_csv_from_df(df, output_path):
+     df.to_csv(output_path, index=False)
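
The export helpers above index their DataFrame by column name, so a round-trip sketch looks roughly like this; the column labels are inferred from create_bed_file_from_df (which also expects Strand values as the string '1'), and the gene symbol and weights path are placeholders:

    import pandas as pd
    from cas12lstm import (process_gene, generate_genbank_file_from_df,
                           create_bed_file_from_df, create_csv_from_df)

    # Hypothetical gene and weights path.
    results, sequences, exons = process_gene("TP53", "models/cas12_bilstm_weights.h5")
    df = pd.DataFrame(results, columns=["Chr", "Start Pos", "End Pos", "Strand",
                                        "Transcript", "Exon", "Target", "gRNA", "Prediction"])

    create_csv_from_df(df, "tp53_cas12_targets.csv")
    create_bed_file_from_df(df, "tp53_cas12_targets.bed")
    # The GenBank record needs a backbone sequence; the first exon sequence is used here.
    generate_genbank_file_from_df(df, sequences[0], "TP53", "tp53_cas12_targets.gb")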
cas12lstmvcf.py ADDED
@@ -0,0 +1,351 @@
+ import tensorflow as tf
+ from keras import regularizers
+ from keras.layers import Input, Dense, Dropout, Activation, Conv1D
+ from keras.layers import GlobalAveragePooling1D, AveragePooling1D
+ from keras.layers import Bidirectional, LSTM
+ from keras import Model
+ from keras.metrics import MeanSquaredError
+
+ import pandas as pd
+ import numpy as np
+ from Bio import SeqIO
+ from Bio.SeqRecord import SeqRecord
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
+ from Bio.Seq import Seq
+
+ import requests
+ from functools import reduce
+ from operator import add
+ import tabulate
+ from difflib import SequenceMatcher
+
+ import cyvcf2
+ import parasail
+
+ import re
+
+ ntmap = {'A': (1, 0, 0, 0),
+          'C': (0, 1, 0, 0),
+          'G': (0, 0, 1, 0),
+          'T': (0, 0, 0, 1)
+          }
+
+ def get_seqcode(seq):
+     return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
+
+ def BiLSTM_model(input_shape):
+     input = Input(shape=input_shape)
+
+     conv1 = Conv1D(128, 5, activation="relu")(input)
+     pool1 = AveragePooling1D(2)(conv1)
+     drop1 = Dropout(0.1)(pool1)
+
+     conv2 = Conv1D(128, 5, activation="relu")(drop1)
+     pool2 = AveragePooling1D(2)(conv2)
+     drop2 = Dropout(0.1)(pool2)
+
+     lstm1 = Bidirectional(LSTM(128,
+                                dropout=0.1,
+                                activation='tanh',
+                                return_sequences=True,
+                                kernel_regularizer=regularizers.l2(1e-4)))(drop2)
+     avgpool = GlobalAveragePooling1D()(lstm1)
+
+     dense1 = Dense(128,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(avgpool)
+     drop3 = Dropout(0.1)(dense1)
+
+     dense2 = Dense(32,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(drop3)
+     drop4 = Dropout(0.1)(dense2)
+
+     dense3 = Dense(32,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(drop4)
+     drop5 = Dropout(0.1)(dense3)
+
+     output = Dense(1, activation="linear")(drop5)
+
+     model = Model(inputs=[input], outputs=[output])
+     return model
+
+ def fetch_ensembl_transcripts(gene_symbol):
+     url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         gene_data = response.json()
+         if 'Transcript' in gene_data:
+             return gene_data['Transcript']
+         else:
+             print("No transcripts found for gene:", gene_symbol)
+             return None
+     else:
+         print(f"Error fetching gene data from Ensembl: {response.text}")
+         return None
+
+ def fetch_ensembl_sequence(transcript_id):
+     url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         sequence_data = response.json()
+         if 'seq' in sequence_data:
+             return sequence_data['seq']
+         else:
+             print("No sequence found for transcript:", transcript_id)
+             return None
+     else:
+         print(f"Error fetching sequence data from Ensembl: {response.text}")
+         return None
+
+ def apply_mutation(ref_sequence, offset, ref, alt):
+     """
+     Apply a single mutation to the sequence.
+     """
+     if len(ref) == len(alt) and alt != "*":  # SNP
+         mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(alt):]
+
+     elif len(ref) < len(alt):  # Insertion
+         mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+1:]
+
+     elif len(ref) == len(alt) and alt == "*":  # Deletion
+         mutated_seq = ref_sequence[:offset] + ref_sequence[offset+1:]
+
+     elif len(ref) > len(alt) and alt != "*":  # Deletion
+         mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(ref):]
+
+     elif len(ref) > len(alt) and alt == "*":  # Deletion
+         mutated_seq = ref_sequence[:offset] + ref_sequence[offset+len(ref):]
+
+     return mutated_seq
+
+
+ def construct_combinations(sequence, mutations):
+     """
+     Construct all combinations of mutations.
+     mutations is a list of tuples (position, ref, [alts])
+     """
+     if not mutations:
+         return [sequence]
+
+     # Take the first mutation and recursively construct combinations for the rest
+     first_mutation = mutations[0]
+     rest_mutations = mutations[1:]
+     offset, ref, alts = first_mutation
+
+     sequences = []
+     for alt in alts:
+         mutated_sequence = apply_mutation(sequence, offset, ref, alt)
+         sequences.extend(construct_combinations(mutated_sequence, rest_mutations))
+
+     return sequences
+
+ def needleman_wunsch_alignment(query_seq, ref_seq):
+     """
+     Use Needleman-Wunsch alignment to find the maximum alignment position in ref_seq
+     Use this position to represent the position of target sequence with mutations
+     """
+     # Needleman-Wunsch alignment
+     alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
+
+     # extract CIGAR object
+     cigar = alignment.cigar
+     cigar_string = cigar.decode.decode("utf-8")
+
+     # record ref_pos
+     ref_pos = 0
+
+     matches = re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)
+     max_num_before_equal = 0
+     max_equal_index = -1
+     total_before_max_equal = 0
+
+     for i, (num_str, op) in enumerate(matches):
+         num = int(num_str)
+         if op == '=':
+             if num > max_num_before_equal:
+                 max_num_before_equal = num
+                 max_equal_index = i
+     total_before_max_equal = sum(int(matches[j][0]) for j in range(max_equal_index))
+
+     ref_pos = total_before_max_equal
+
+     return ref_pos
+
+ def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
+                             exon_id, gene_symbol, vcf_reader, pam="TTTN", target_length=34):
+     # initialization
+     mutated_sequences = [ref_sequence]
+
+     # find mutations within interested region
+     mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
+     if mutations:
+         # find mutations
+         mutation_list = []
+         for mutation in mutations:
+             offset = mutation.POS - start
+             ref = mutation.REF
+             alts = mutation.ALT[:-1]
+             mutation_list.append((offset, ref, alts))
+
+         # replace reference sequence of mutation
+         mutated_sequences = construct_combinations(ref_sequence, mutation_list)
+
+     # find gRNA in ref_sequence or all mutated_sequences
+     targets = []
+     for seq in mutated_sequences:
+         len_sequence = len(seq)
+         dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
+         for i in range(len_sequence - target_length + 1):
+             target_seq = seq[i:i + target_length]
+             if target_seq[4:7] == 'TTT':
+                 pos = ref_sequence.find(target_seq)
+                 if pos != -1:
+                     is_mut = False
+                     if strand == -1:
+                         tar_start = end - pos - target_length + 1
+                     else:
+                         tar_start = start + pos
+                 else:
+                     is_mut = True
+                     nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
+                     if strand == -1:
+                         tar_start = str(end - nw_pos - target_length + 1) + '*'
+                     else:
+                         tar_start = str(start + nw_pos) + '*'
+                 gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
+                 targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start), transcript_id, exon_id, gene_symbol, is_mut])
+
+     # filter duplicated targets
+     unique_targets_set = set(tuple(element) for element in targets)
+     unique_targets = [list(element) for element in unique_targets_set]
+
+     return unique_targets
+
+ def format_prediction_output_with_mutation(targets, model_path):
+     Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
+     Crispr_BiLSTM.load_weights(model_path)
+
+     formatted_data = []
+     for target in targets:
+         # Predict
+         encoded_seq = get_seqcode(target[0])
+         prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
+         if prediction > 100:
+             prediction = 100
+
+         # Format output
+         gRNA = target[1]
+         exon_chr = target[2]
+         strand = target[3]
+         tar_start = target[4]
+         transcript_id = target[5]
+         exon_id = target[6]
+         gene_symbol = target[7]
+         is_mut = target[8]
+         formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id, exon_id, target[0], gRNA, prediction, is_mut])
+
+     return formatted_data
+
+ def process_gene(gene_symbol, vcf_reader, model_path):
+     transcripts = fetch_ensembl_transcripts(gene_symbol)
+     results = []
+     all_exons = []  # To accumulate all exons
+     all_gene_sequences = []  # To accumulate all gene sequences
+
+     if transcripts:
+         for transcript in transcripts:
+             Exons = transcript['Exon']
+             all_exons.extend(Exons)  # Add all exons from this transcript to the list
+             transcript_id = transcript['id']
+
+             for Exon in Exons:
+                 exon_id = Exon['id']
+                 gene_sequence = fetch_ensembl_sequence(exon_id)  # Reference exon sequence
+                 if gene_sequence:
+                     all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
+                     exon_chr = Exon['seq_region_name']
+                     start = Exon['start']
+                     end = Exon['end']
+                     strand = Exon['strand']
+
+                     targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand, transcript_id, exon_id, gene_symbol, vcf_reader)
+                     if targets:
+                         # Predict on-target efficiency for each gRNA site
+                         formatted_data = format_prediction_output_with_mutation(targets, model_path)
+                         results.extend(formatted_data)  # Flatten the results
+                 else:
+                     print(f"Failed to retrieve gene sequence for exon {exon_id}.")
+     else:
+         print("Failed to retrieve transcripts.")
+
+     # Return the results, combined gene sequences, and all exons
+     return results, all_gene_sequences, all_exons
+
+ def create_genbank_features(data):
+     features = []
+
+     # If the input data is a DataFrame, convert it to a list of lists
+     if isinstance(data, pd.DataFrame):
+         formatted_data = data.values.tolist()
+     elif isinstance(data, list):
+         formatted_data = data
+     else:
+         raise TypeError("Data should be either a list or a pandas DataFrame.")
+
+     for row in formatted_data:
+         try:
+             start = int(row[1])
+             end = start + len(row[6])  # Calculate the end position based on the target sequence length
+         except ValueError as e:
+             print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
+             continue
+
+         strand = 1 if row[3] == '1' else -1
+         location = FeatureLocation(start=start, end=end, strand=strand)
+         is_mutation = 'Yes' if row[9] else 'No'
+         feature = SeqFeature(location=location, type="misc_feature", qualifiers={
+             'label': row[7],  # Use gRNA as the label
+             'note': f"Prediction: {row[8]}, Mutation: {is_mutation}"  # Include the prediction score and mutation status
+         })
+         features.append(feature)
+
+     return features
+
+ def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
+     # Ensure gene_sequence is a string before creating Seq object
+     if not isinstance(gene_sequence, str):
+         gene_sequence = str(gene_sequence)
+
+     features = create_genbank_features(df)
+
+     # Now gene_sequence is guaranteed to be a string, suitable for Seq
+     seq_obj = Seq(gene_sequence)
+     record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
+                        description=f'CRISPR Cas12 predicted targets for {gene_symbol}', features=features)
+     record.annotations["molecule_type"] = "DNA"
+     SeqIO.write(record, output_path, "genbank")
+
+ def create_bed_file_from_df(df, output_path):
+     with open(output_path, 'w') as bed_file:
+         for index, row in df.iterrows():
+             chrom = row["Chr"]
+             start = int(row["Target Start"])
+             end = start + len(row["Target"])  # Calculate the end position based on the target sequence length
+             strand = '+' if row["Strand"] == '1' else '-'
+             gRNA = row["gRNA"]
+             score = str(row["Prediction"])
+             is_mutation = 'Yes' if row["Is Mutation"] else 'No'
+             # transcript_id is not typically part of the standard BED columns but added here for completeness
+             transcript_id = row["Transcript"]
+
+             # Writing only standard BED columns; additional columns can be appended as needed
+             bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\t{is_mutation}\n")
+
+ def create_csv_from_df(df, output_path):
+     df.to_csv(output_path, index=False)
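
process_gene in this module expects vcf_reader to be region-queryable: the call vcf_reader(f"{exon_chr}:{start}-{end}") must yield records with POS, REF, and ALT attributes, which matches a cyvcf2.VCF handle opened on an indexed VCF. A sketch with placeholder file names:

    import cyvcf2
    from cas12lstmvcf import process_gene

    # Hypothetical VCF; region queries need a bgzipped, tabix-indexed file.
    vcf_reader = cyvcf2.VCF("patient_variants.vcf.gz")

    results, sequences, exons = process_gene("BRCA1", vcf_reader, "models/cas12_bilstm_weights.h5")

    # Rows flagged is_mut=True carry a '*'-suffixed start coordinate estimated by
    # the Needleman-Wunsch realignment rather than by an exact reference match.
    for row in results[:5]:
        print(row)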
cas9att.py ADDED
@@ -0,0 +1,296 @@
+ import requests
+ import tensorflow as tf
+ import pandas as pd
+ import numpy as np
+ from operator import add
+ from functools import reduce
+ import random
+ import tabulate
+
+ from keras import Model
+ from keras import regularizers
+ from keras.optimizers import Adam
+ from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
+ from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
+ from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
+ from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
+ from keras.models import load_model
+ from keras.callbacks import EarlyStopping, ReduceLROnPlateau
+ from Bio import SeqIO
+ from Bio.SeqRecord import SeqRecord
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
+ from Bio.Seq import Seq
+
+ import cyvcf2
+ import parasail
+
+ import re
+
+ ntmap = {'A': (1, 0, 0, 0),
+          'C': (0, 1, 0, 0),
+          'G': (0, 0, 1, 0),
+          'T': (0, 0, 0, 1)
+          }
+
+ def get_seqcode(seq):
+     return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
+
+ class PositionalEncoding(Layer):
+     def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
+         super().__init__(**kwargs)  # forward base-layer kwargs (e.g. name) so the layer round-trips through get_config
+         self.sequence_len = sequence_len
+         self.embedding_dim = embedding_dim
+
+     def call(self, x):
+         position_embedding = np.array([
+             [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
+             for pos in range(self.sequence_len)])
+
+         position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # dim 2i
+         position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # dim 2i+1
+         position_embedding = tf.cast(position_embedding, dtype=tf.float32)
+
+         return position_embedding + x
+
+     def get_config(self):
+         config = super().get_config().copy()
+         config.update({
+             'sequence_len': self.sequence_len,
+             'embedding_dim': self.embedding_dim,
+         })
+         return config
+
+ def MultiHeadAttention_model(input_shape):
+     input = Input(shape=input_shape)
+
+     conv1 = Conv1D(256, 3, activation="relu")(input)
+     pool1 = AveragePooling1D(2)(conv1)
+     drop1 = Dropout(0.4)(pool1)
+
+     conv2 = Conv1D(256, 3, activation="relu")(drop1)
+     pool2 = AveragePooling1D(2)(conv2)
+     drop2 = Dropout(0.4)(pool2)
+
+     lstm = Bidirectional(LSTM(128,
+                               dropout=0.5,
+                               activation='tanh',
+                               return_sequences=True,
+                               kernel_regularizer=regularizers.l2(0.01)))(drop2)
+
+     pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
+     atten = MultiHeadAttention(num_heads=2,
+                                key_dim=64,
+                                dropout=0.2,
+                                kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)
+
+     flat = Flatten()(atten)
+
+     dense1 = Dense(512,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(flat)
+     drop3 = Dropout(0.1)(dense1)
+
+     dense2 = Dense(128,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(drop3)
+     drop4 = Dropout(0.1)(dense2)
+
+     dense3 = Dense(256,
+                    kernel_regularizer=regularizers.l2(1e-4),
+                    bias_regularizer=regularizers.l2(1e-4),
+                    activation="relu")(drop4)
+     drop5 = Dropout(0.1)(dense3)
+
+     output = Dense(1, activation="linear")(drop5)
+
+     model = Model(inputs=[input], outputs=[output])
+     return model
+
+ def fetch_ensembl_transcripts(gene_symbol):
+     url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         gene_data = response.json()
+         if 'Transcript' in gene_data:
+             return gene_data['Transcript']
+         else:
+             print("No transcripts found for gene:", gene_symbol)
+             return None
+     else:
+         print(f"Error fetching gene data from Ensembl: {response.text}")
+         return None
+
+ def fetch_ensembl_sequence(transcript_id):
+     url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         sequence_data = response.json()
+         if 'seq' in sequence_data:
+             return sequence_data['seq']
+         else:
+             print("No sequence found for transcript:", transcript_id)
+             return None
+     else:
+         print(f"Error fetching sequence data from Ensembl: {response.text}")
+         return None
+
+ def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
+     targets = []
+     len_sequence = len(sequence)
+     #complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
+     dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
+
+     for i in range(len_sequence - len(pam) + 1):
+         if sequence[i + 1:i + 3] == pam[1:]:
+             if i >= target_length:
+                 target_seq = sequence[i - target_length:i + 3]
+                 if strand == -1:
+                     tar_start = end - (i + 2)
+                     tar_end = end - (i - target_length)
+                     #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
+                 else:
+                     tar_start = start + i - target_length
+                     tar_end = start + i + 3 - 1
+                     #seq_in_ref = target_seq
+                 gRNA = ''.join([dnatorna[base] for base in sequence[i - target_length:i]])
+                 #targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
+                 targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
+
+     return targets
+
+ # Function to predict on-target efficiency and format output
+ def format_prediction_output(targets, model_path):
+     model = MultiHeadAttention_model(input_shape=(23, 4))
+     model.load_weights(model_path)
+
+     formatted_data = []
+
+     for target in targets:
+         # Encode the gRNA sequence
+         encoded_seq = get_seqcode(target[0])
+
+         # Predict on-target efficiency using the model
+         prediction = float(list(model.predict(encoded_seq, verbose=0)[0])[0])
+         if prediction > 100:
+             prediction = 100
+
+         # Format output
+         gRNA = target[1]
+         chr = target[2]
+         start = target[3]
+         end = target[4]
+         strand = target[5]
+         transcript_id = target[6]
+         exon_id = target[7]
+         #seq_in_ref = target[8]
+         #formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction[0]])
+         formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])
+
+     return formatted_data
+
+ def process_gene(gene_symbol, model_path):
+     # Fetch transcripts for the given gene symbol
+     transcripts = fetch_ensembl_transcripts(gene_symbol)
+     results = []
+     all_exons = []  # To accumulate all exons
+     all_gene_sequences = []  # To accumulate all gene sequences
+
+     if transcripts:
+         for transcript in transcripts:
+             Exons = transcript['Exon']
+             all_exons.extend(Exons)  # Add all exons from this transcript to the list
+             transcript_id = transcript['id']
+
+             for exon in Exons:
+                 exon_id = exon['id']
+                 gene_sequence = fetch_ensembl_sequence(exon_id)
+                 if gene_sequence:
+                     all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
+                     start = exon['start']
+                     end = exon['end']
+                     strand = exon['strand']
+                     chr = exon['seq_region_name']
+                     # Find potential CRISPR targets within the exon
+                     targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
+                     if targets:
+                         # Format the prediction output for the targets found
+                         formatted_data = format_prediction_output(targets, model_path)
+                         results.extend(formatted_data)  # Append results
+                 else:
+                     print(f"Failed to retrieve gene sequence for exon {exon_id}.")
+     else:
+         print("Failed to retrieve transcripts.")
+
+     # Return the results, combined gene sequences, and all exons
+     return results, all_gene_sequences, all_exons
+
+
+ def create_genbank_features(data):
+     features = []
+
+     # If the input data is a DataFrame, convert it to a list of lists
+     if isinstance(data, pd.DataFrame):
+         formatted_data = data.values.tolist()
+     elif isinstance(data, list):
+         formatted_data = data
+     else:
+         raise TypeError("Data should be either a list or a pandas DataFrame.")
+
+     for row in formatted_data:
+         try:
+             start = int(row[1])
+             end = int(row[2])
+         except ValueError as e:
+             print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
+             continue
+
+         strand = 1 if row[3] == '+' else -1
+         location = FeatureLocation(start=start, end=end, strand=strand)
+         feature = SeqFeature(location=location, type="misc_feature", qualifiers={
+             'label': row[7],  # Use gRNA as the label
+             'note': f"Prediction: {row[8]}"  # Include the prediction score
+         })
+         features.append(feature)
+
+     return features
+
+
+ def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
+     # Ensure gene_sequence is a string before creating Seq object
+     if not isinstance(gene_sequence, str):
+         gene_sequence = str(gene_sequence)
+
+     features = create_genbank_features(df)
+
+     # Now gene_sequence is guaranteed to be a string, suitable for Seq
+     seq_obj = Seq(gene_sequence)
+     record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
+                        description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
+     record.annotations["molecule_type"] = "DNA"
+     SeqIO.write(record, output_path, "genbank")
+
+
+ def create_bed_file_from_df(df, output_path):
+     with open(output_path, 'w') as bed_file:
+         for index, row in df.iterrows():
+             chrom = row["Chr"]
+             start = int(row["Start Pos"])
+             end = int(row["End Pos"])
+             strand = '+' if row["Strand"] == '1' else '-'
+             gRNA = row["gRNA"]
+             score = str(row["Prediction"])
+             # transcript_id is not typically part of the standard BED columns but added here for completeness
+             transcript_id = row["Transcript"]
+
+             # Writing only standard BED columns; additional columns can be appended as needed
+             bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
+
+
+ def create_csv_from_df(df, output_path):
+     df.to_csv(output_path, index=False)
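
For a single site, get_seqcode turns a 23-mer (20-nt protospacer plus NGG PAM) into the (1, 23, 4) one-hot tensor the attention model consumes; a scoring sketch with a placeholder weights path:

    from cas9att import MultiHeadAttention_model, get_seqcode

    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights("models/cas9_attention_weights.h5")  # hypothetical path

    site = "ACGTACGTACGTACGTACGTTGG"  # 20-nt target followed by an NGG PAM
    score = float(model.predict(get_seqcode(site), verbose=0)[0][0])
    print(min(score, 100))  # the module clamps predictions at 100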
cas9attvcf.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import tensorflow as tf
3
+ import pandas as pd
4
+ import numpy as np
5
+ from operator import add
6
+ from functools import reduce
7
+ import random
8
+ import tabulate
9
+
10
+ from keras import Model
11
+ from keras import regularizers
12
+ from keras.optimizers import Adam
13
+ from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
14
+ from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
15
+ from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
16
+ from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
17
+ from keras.models import load_model
18
+ from keras.callbacks import EarlyStopping, ReduceLROnPlateau
19
+ from Bio import SeqIO
20
+ from Bio.SeqRecord import SeqRecord
21
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
22
+ from Bio.Seq import Seq
23
+
24
+ import cyvcf2
25
+ import parasail
26
+
27
+ import re
28
+
29
+ ntmap = {'A': (1, 0, 0, 0),
30
+ 'C': (0, 1, 0, 0),
31
+ 'G': (0, 0, 1, 0),
32
+ 'T': (0, 0, 0, 1)
33
+ }
34
+
35
+ def get_seqcode(seq):
36
+ return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
37
+
38
+ class PositionalEncoding(Layer):
39
+ def __init__(self, sequence_len=None, embedding_dim=None,**kwargs):
40
+ super(PositionalEncoding, self).__init__()
41
+ self.sequence_len = sequence_len
42
+ self.embedding_dim = embedding_dim
43
+
44
+ def call(self, x):
45
+
46
+ position_embedding = np.array([
47
+ [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
48
+ for pos in range(self.sequence_len)])
49
+
50
+ position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2]) # dim 2i
51
+ position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2]) # dim 2i+1
52
+ position_embedding = tf.cast(position_embedding, dtype=tf.float32)
53
+
54
+ return position_embedding+x
55
+
56
+ def get_config(self):
57
+ config = super().get_config().copy()
58
+ config.update({
59
+ 'sequence_len' : self.sequence_len,
60
+ 'embedding_dim' : self.embedding_dim,
61
+ })
62
+ return config
63
+
64
+ def MultiHeadAttention_model(input_shape):
65
+ input = Input(shape=input_shape)
66
+
67
+ conv1 = Conv1D(256, 3, activation="relu")(input)
68
+ pool1 = AveragePooling1D(2)(conv1)
69
+ drop1 = Dropout(0.4)(pool1)
70
+
71
+ conv2 = Conv1D(256, 3, activation="relu")(drop1)
72
+ pool2 = AveragePooling1D(2)(conv2)
73
+ drop2 = Dropout(0.4)(pool2)
74
+
75
+ lstm = Bidirectional(LSTM(128,
76
+ dropout=0.5,
77
+ activation='tanh',
78
+ return_sequences=True,
79
+ kernel_regularizer=regularizers.l2(0.01)))(drop2)
80
+
81
+ pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
82
+ atten = MultiHeadAttention(num_heads=2,
83
+ key_dim=64,
84
+ dropout=0.2,
85
+ kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)
86
+
87
+ flat = Flatten()(atten)
88
+
89
+ dense1 = Dense(512,
90
+ kernel_regularizer=regularizers.l2(1e-4),
91
+ bias_regularizer=regularizers.l2(1e-4),
92
+ activation="relu")(flat)
93
+ drop3 = Dropout(0.1)(dense1)
94
+
95
+ dense2 = Dense(128,
96
+ kernel_regularizer=regularizers.l2(1e-4),
97
+ bias_regularizer=regularizers.l2(1e-4),
98
+ activation="relu")(drop3)
99
+ drop4 = Dropout(0.1)(dense2)
100
+
101
+ dense3 = Dense(256,
102
+ kernel_regularizer=regularizers.l2(1e-4),
103
+ bias_regularizer=regularizers.l2(1e-4),
104
+ activation="relu")(drop4)
105
+ drop5 = Dropout(0.1)(dense3)
106
+
107
+ output = Dense(1, activation="linear")(drop5)
108
+
109
+ model = Model(inputs=[input], outputs=[output])
110
+ return model
111
+
112
+ def fetch_ensembl_transcripts(gene_symbol):
113
+ url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
114
+ response = requests.get(url)
115
+ if response.status_code == 200:
116
+ gene_data = response.json()
117
+ if 'Transcript' in gene_data:
118
+ return gene_data['Transcript']
119
+ else:
120
+ print("No transcripts found for gene:", gene_symbol)
121
+ return None
122
+ else:
123
+ print(f"Error fetching gene data from Ensembl: {response.text}")
124
+ return None
125
+
126
+ def fetch_ensembl_sequence(transcript_id):
127
+ url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
128
+ response = requests.get(url)
129
+ if response.status_code == 200:
130
+ sequence_data = response.json()
131
+ if 'seq' in sequence_data:
132
+ return sequence_data['seq']
133
+ else:
134
+ print("No sequence found for transcript:", transcript_id)
135
+ return None
136
+ else:
137
+ print(f"Error fetching sequence data from Ensembl: {response.text}")
138
+ return None
139
+
140
+ def apply_mutation(ref_sequence, offset, ref, alt):
141
+ """
142
+ Apply a single mutation to the sequence.
143
+ """
144
+ if len(ref) == len(alt) and alt != "*": # SNP
145
+ mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(alt):]
146
+
147
+ elif len(ref) < len(alt): # Insertion
148
+ mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+1:]
149
+
150
+ elif len(ref) == len(alt) and alt == "*": # Deletion
151
+ mutated_seq = ref_sequence[:offset] + ref_sequence[offset+1:]
152
+
153
+ elif len(ref) > len(alt) and alt != "*": # Deletion
154
+ mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(ref):]
155
+
156
+ elif len(ref) > len(alt) and alt == "*": # Deletion
157
+ mutated_seq = ref_sequence[:offset] + ref_sequence[offset+len(ref):]
158
+
159
+ return mutated_seq
160
+
161
+ def construct_combinations(sequence, mutations):
162
+ """
163
+ Construct all combinations of mutations.
164
+ mutations is a list of tuples (position, ref, [alts])
165
+ """
166
+ if not mutations:
167
+ return [sequence]
168
+
169
+ # Take the first mutation and recursively construct combinations for the rest
170
+ first_mutation = mutations[0]
171
+ rest_mutations = mutations[1:]
172
+ offset, ref, alts = first_mutation
173
+
174
+ sequences = []
175
+ for alt in alts:
176
+ mutated_sequence = apply_mutation(sequence, offset, ref, alt)
177
+ sequences.extend(construct_combinations(mutated_sequence, rest_mutations))
178
+
179
+ return sequences
180
+
181
+ def needleman_wunsch_alignment(query_seq, ref_seq):
182
+ """
183
+ Use Needleman-Wunsch alignment to find the maximum alignment position in ref_seq
184
+ Use this position to represent the position of target sequence with mutations
185
+ """
186
+ # Needleman-Wunsch alignment
187
+ alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
188
+
189
+ # extract CIGAR object
190
+ cigar = alignment.cigar
191
+ cigar_string = cigar.decode.decode("utf-8")
192
+
193
+ # record ref_pos
194
+ ref_pos = 0
195
+
196
+ matches = re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)
197
+ max_num_before_equal = 0
198
+ max_equal_index = -1
199
+ total_before_max_equal = 0
200
+
201
+ for i, (num_str, op) in enumerate(matches):
202
+ num = int(num_str)
203
+ if op == '=':
204
+ if num > max_num_before_equal:
205
+ max_num_before_equal = num
206
+ max_equal_index = i
207
+ total_before_max_equal = sum(int(matches[j][0]) for j in range(max_equal_index))
208
+
209
+ ref_pos = total_before_max_equal
210
+
211
+ return ref_pos
212
+
213
+ def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
214
+ exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
215
+ # initialization
216
+ mutated_sequences = [ref_sequence]
217
+
218
+ # find mutations within interested region
219
+ mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
220
+ if mutations:
221
+ # find mutations
222
+ mutation_list = []
223
+ for mutation in mutations:
224
+ offset = mutation.POS - start
225
+ ref = mutation.REF
226
+ alts = mutation.ALT[:-1]
227
+ mutation_list.append((offset, ref, alts))
228
+
229
+ # replace reference sequence of mutation
230
+ mutated_sequences = construct_combinations(ref_sequence, mutation_list)
231
+
232
+ # find gRNA in ref_sequence or all mutated_sequences
233
+ targets = []
234
+ for seq in mutated_sequences:
235
+ len_sequence = len(seq)
236
+ dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
237
+ for i in range(len_sequence - len(pam) + 1):
238
+ if seq[i + 1:i + 3] == pam[1:]:
239
+ if i >= target_length:
240
+ target_seq = seq[i - target_length:i + 3]
241
+ pos = ref_sequence.find(target_seq)
242
+ if pos != -1:
243
+ is_mut = False
244
+ if strand == -1:
245
+ tar_start = end - pos - target_length - 2
246
+ else:
247
+ tar_start = start + pos
248
+ else:
249
+ is_mut = True
250
+ nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
251
+ if strand == -1:
252
+ tar_start = str(end - nw_pos - target_length - 2) + '*'
253
+ else:
254
+ tar_start = str(start + nw_pos) + '*'
255
+ gRNA = ''.join([dnatorna[base] for base in seq[i - target_length:i]])
256
+ targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start), transcript_id, exon_id, gene_symbol, is_mut])
257
+
258
+ # filter duplicated targets
259
+ unique_targets_set = set(tuple(element) for element in targets)
260
+ unique_targets = [list(element) for element in unique_targets_set]
261
+
262
+ return unique_targets
263
+
264
+ def format_prediction_output_with_mutation(targets, model_path):
265
+ model = MultiHeadAttention_model(input_shape=(23, 4))
266
+ model.load_weights(model_path)
267
+
268
+ formatted_data = []
269
+
270
+ for target in targets:
271
+ # Encode the gRNA sequence
272
+ encoded_seq = get_seqcode(target[0])
273
+
274
+
275
+ # Predict on-target efficiency using the model
276
+ prediction = float(list(model.predict(encoded_seq, verbose=0)[0])[0])
277
+ if prediction > 100:
278
+ prediction = 100
279
+
280
+ # Format output
281
+ gRNA = target[1]
282
+ exon_chr = target[2]
283
+ strand = target[3]
284
+ tar_start = target[4]
285
+ transcript_id = target[5]
286
+ exon_id = target[6]
287
+ gene_symbol = target[7]
288
+ is_mut = target[8]
289
+ formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
290
+ exon_id, target[0], gRNA, prediction, is_mut])
291
+
292
+ return formatted_data
293
+
294
+
def process_gene(gene_symbol, vcf_reader, model_path):
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    results = []
    all_exons = []  # To accumulate all exons
    all_gene_sequences = []  # To accumulate all exon sequences

    if transcripts:
        for transcript in transcripts:
            Exons = transcript['Exon']
            all_exons.extend(Exons)  # Add all exons from this transcript to the list
            transcript_id = transcript['id']

            for Exon in Exons:
                exon_id = Exon['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)  # Reference exon sequence
                if gene_sequence:
                    all_gene_sequences.append(gene_sequence)
                    exon_chr = Exon['seq_region_name']
                    start = Exon['start']
                    end = Exon['end']
                    strand = Exon['strand']

                    targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand,
                                                      transcript_id, exon_id, gene_symbol, vcf_reader)
                    if targets:
                        # Predict on-target efficiency for each gRNA site, including mutated sites
                        formatted_data = format_prediction_output_with_mutation(targets, model_path)
                        results.extend(formatted_data)
                else:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
    else:
        print("Failed to retrieve transcripts.")

    # Return the collected results, all exon sequences, and all exons
    return results, all_gene_sequences, all_exons

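# Minimal usage sketch. The VCF backend is not pinned down in this file, so
# the cyvcf2.VCF reader and the weights path below are assumptions for
# illustration only; column names follow create_bed_file_from_df further down.
#
#   import cyvcf2
#   results, sequences, exons = process_gene('FANCF',
#                                            cyvcf2.VCF('sample.vcf.gz'),
#                                            'cas9_model/on_target.h5')
#   df = pd.DataFrame(results, columns=['Gene', 'Chr', 'Strand', 'Target Start',
#                                       'Transcript', 'Exon', 'Target', 'gRNA',
#                                       'Prediction', 'Is Mutation'])
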
def create_genbank_features(data):
    features = []

    # If the input data is a DataFrame, convert it to a list of lists
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    for row in formatted_data:
        try:
            # row layout: [gene, chr, strand, target start, transcript, exon,
            #              target, gRNA, prediction, is_mutation]
            start = int(str(row[3]).rstrip('*'))  # strip the '*' marker used for mutated positions
            end = start + len(row[6])  # end position derived from the target sequence length
        except ValueError as e:
            print(f"Error converting start to int: {row[3]} - {e}")
            continue

        strand = 1 if str(row[2]) == '1' else -1
        location = FeatureLocation(start=start, end=end, strand=strand)
        is_mutation = 'Yes' if row[9] else 'No'
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use the gRNA as the label
            'note': f"Prediction: {row[8]}, Mutation: {is_mutation}"  # Prediction score and mutation status
        })
        features.append(feature)

    return features

def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    # Ensure gene_sequence is a string before creating the Seq object
    if not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)

    seq_obj = Seq(gene_sequence)
    record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")

def create_bed_file_from_df(df, output_path):
    with open(output_path, 'w') as bed_file:
        for index, row in df.iterrows():
            chrom = row["Chr"]
            start = int(str(row["Target Start"]).rstrip('*'))  # strip the '*' marker used for mutated positions
            end = start + len(row["Target"])  # end position derived from the target sequence length
            strand = '+' if str(row["Strand"]) == '1' else '-'
            gRNA = row["gRNA"]
            score = str(row["Prediction"])
            is_mutation = 'Yes' if row["Is Mutation"] else 'No'
            # transcript_id is not part of the standard BED columns; kept here for reference
            transcript_id = row["Transcript"]

            # Write the standard BED columns plus a mutation flag; more columns can be appended as needed
            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\t{is_mutation}\n")

def create_csv_from_df(df, output_path):
    df.to_csv(output_path, index=False)
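
# End-to-end output sketch under the same assumptions as above: write the
# prediction table with the three exporters defined in this file.
#
#   generate_genbank_file_from_df(df, ''.join(sequences), 'FANCF', 'FANCF_targets.gb')
#   create_bed_file_from_df(df, 'FANCF_targets.bed')
#   create_csv_from_df(df, 'FANCF_targets.csv')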
cas9off.py ADDED
@@ -0,0 +1,134 @@
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import argparse

# configure GPUs
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')

class Encoder:
    def __init__(self, on_seq, off_seq, with_category=False, label=None, with_reg_val=False, value=None):
        tlen = 24
        # left-pad both sequences to 24 positions with '-'
        self.on_seq = "-" * (tlen - len(on_seq)) + on_seq
        self.off_seq = "-" * (tlen - len(off_seq)) + off_seq
        self.encoded_dict_indel = {'A': [1, 0, 0, 0, 0], 'T': [0, 1, 0, 0, 0],
                                   'G': [0, 0, 1, 0, 0], 'C': [0, 0, 0, 1, 0],
                                   '_': [0, 0, 0, 0, 1], '-': [0, 0, 0, 0, 0]}
        self.direction_dict = {'A': 5, 'G': 4, 'C': 3, 'T': 2, '_': 1}
        if with_category:
            self.label = label
        if with_reg_val:
            self.value = value
        self.encode_on_off_dim7()

    def encode_sgRNA(self):
        code_list = []
        encoded_dict = self.encoded_dict_indel
        sgRNA_bases = list(self.on_seq)
        for i in range(len(sgRNA_bases)):
            if sgRNA_bases[i] == "N":
                sgRNA_bases[i] = list(self.off_seq)[i]
            code_list.append(encoded_dict[sgRNA_bases[i]])
        self.sgRNA_code = np.array(code_list)

    def encode_off(self):
        code_list = []
        encoded_dict = self.encoded_dict_indel
        off_bases = list(self.off_seq)
        for i in range(len(off_bases)):
            code_list.append(encoded_dict[off_bases[i]])
        self.off_code = np.array(code_list)

    def encode_on_off_dim7(self):
        self.encode_sgRNA()
        self.encode_off()
        on_bases = list(self.on_seq)
        off_bases = list(self.off_seq)
        on_off_dim7_codes = []
        for i in range(len(on_bases)):
            # first 5 channels: union of the on- and off-target base encodings
            diff_code = np.bitwise_or(self.sgRNA_code[i], self.off_code[i])
            on_b = on_bases[i]
            off_b = off_bases[i]
            if on_b == "N":
                on_b = off_b
            # last 2 channels: mismatch direction
            dir_code = np.zeros(2)
            if on_b == "-" or off_b == "-" or self.direction_dict[on_b] == self.direction_dict[off_b]:
                pass
            else:
                if self.direction_dict[on_b] > self.direction_dict[off_b]:
                    dir_code[0] = 1
                else:
                    dir_code[1] = 1
            on_off_dim7_codes.append(np.concatenate((diff_code, dir_code)))
        self.on_off_code = np.array(on_off_dim7_codes)

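# Illustrative sketch: encode one on-/off-target pair (taken from the argparse
# example below) and inspect the 24x7 encoding (5 base/indel channels plus
# 2 mismatch-direction channels).
def _demo_encoder():
    en = Encoder(on_seq="GAGT_CCGAGCAGAAGAAGAATGG",
                 off_seq="GAGTACCAAGTAGAAGAAAAATTT")
    print(en.on_off_code.shape)  # (24, 7); shorter pairs are left-padded with '-'
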
def encode_on_off_seq_pairs(input_file):
    inputs = pd.read_csv(input_file, delimiter=",", header=None, names=['on_seq', 'off_seq'])
    input_codes = []
    for idx, row in inputs.iterrows():
        on_seq = row['on_seq']
        off_seq = row['off_seq']
        en = Encoder(on_seq=on_seq, off_seq=off_seq)
        input_codes.append(en.on_off_code)
    input_codes = np.array(input_codes)
    input_codes = input_codes.reshape((len(input_codes), 1, 24, 7))
    y_pred = CRISPR_net_predict(input_codes)
    inputs['CRISPR_Net_score'] = y_pred
    inputs.to_csv("CRISPR_net_results.csv", index=False)

def CRISPR_net_predict(X_test):
    with open("cas9_model/CRISPR_Net_CIRCLE_elevation_SITE_structure.json", 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = tf.keras.models.model_from_json(loaded_model_json)  # TensorFlow 2 API
    loaded_model.load_weights("cas9_model/CRISPR_Net_CIRCLE_elevation_SITE_weights.h5")
    y_pred = loaded_model.predict(X_test).flatten()
    return y_pred

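# Shape-check sketch: CRISPR_net_predict expects a batch shaped (N, 1, 24, 7),
# matching the reshape performed in encode_on_off_seq_pairs above.
def _demo_predict_one(on_seq, off_seq):
    code = Encoder(on_seq=on_seq, off_seq=off_seq).on_off_code
    return CRISPR_net_predict(code.reshape((1, 1, 24, 7)))[0]
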
def process_input_and_predict(input_data, input_type='manual'):
    if input_type == 'manual':
        sequences = [seq.split(',') for seq in input_data.split('\n')]
        inputs = pd.DataFrame(sequences, columns=['on_seq', 'off_seq'])
    elif input_type == 'file':
        inputs = pd.read_csv(input_data, delimiter=",", header=None, names=['on_seq', 'off_seq'])
    else:
        raise ValueError("input_type must be 'manual' or 'file'")

    valid_inputs = []
    input_codes = []
    for idx, row in inputs.iterrows():
        on_seq = row['on_seq']
        off_seq = row['off_seq']
        if not on_seq or not off_seq:
            continue

        en = Encoder(on_seq=on_seq, off_seq=off_seq)
        input_codes.append(en.on_off_code)
        valid_inputs.append((on_seq, off_seq))

    input_codes = np.array(input_codes)
    input_codes = input_codes.reshape((len(input_codes), 1, 24, 7))

    y_pred = CRISPR_net_predict(input_codes)

    # Create a new DataFrame from the valid inputs and their predictions
    result_df = pd.DataFrame(valid_inputs, columns=['on_seq', 'off_seq'])
    result_df['CRISPR_Net_score'] = y_pred

    return result_df

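# Usage sketch for manual input: newline-separated "on_seq,off_seq" pairs,
# reusing an example pair from the argparse help text below.
#
#   pairs = "GTTGCCCCACAGGGCAGTAAAGG,GTGGACACCCCGGGCAGGAAAGG"
#   print(process_input_and_predict(pairs, input_type='manual'))
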
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="CRISPR-Net v1.0 (Aug 10 2019)")
    parser.add_argument("input_file",
                        help="input file with one 'on-target seq,off-target seq' pair per line, e.g.\n"
                             "GAGT_CCGAGCAGAAGAAGAATGG,GAGTACCAAGTAGAAGAAAAATTT\n"
                             "GTTGCCCCACAGGGCAGTAAAGG,GTGGACACCCCGGGCAGGAAAGG\n"
                             "GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG")
    args = parser.parse_args()
    file = args.input_file
    if not os.path.exists(args.input_file):
        print("File doesn't exist!")
    else:
        encode_on_off_seq_pairs(file)
    tf.keras.backend.clear_session()
tiger.py ADDED
@@ -0,0 +1,417 @@
import argparse
import os
import gzip
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from Bio import SeqIO

# column names
ID_COL = 'Transcript ID'
SEQ_COL = 'Transcript Sequence'
TARGET_COL = 'Target Sequence'
GUIDE_COL = 'Guide Sequence'
MM_COL = 'Number of Mismatches'
SCORE_COL = 'Guide Score'

# nucleotide tokens
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))

# model hyper-parameters
GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
UNIT_INTERVAL_MAP = 'sigmoid'

# reference transcript files
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')

# application configuration
BATCH_SIZE_COMPUTE = 500
BATCH_SIZE_SCAN = 20
BATCH_SIZE_TRANSCRIPTS = 50
NUM_TOP_GUIDES = 10
NUM_MISMATCHES = 3
RUN_MODES = dict(
    all='All on-target guides per transcript',
    top_guides='Top {:d} guides per transcript'.format(NUM_TOP_GUIDES),
    titration='Top {:d} guides per transcript & their titration candidates'.format(NUM_TOP_GUIDES)
)


# configure GPUs
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')

def load_transcripts(fasta_files: list, enforce_unique_ids: bool = True):

    # load all transcripts from fasta files into a DataFrame
    transcripts = pd.DataFrame()
    for file in fasta_files:
        try:
            if os.path.splitext(file)[1] == '.gz':
                with gzip.open(file, 'rt') as f:
                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=[ID_COL, SEQ_COL])
            else:
                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=[ID_COL, SEQ_COL])
        except Exception as e:
            print(e, 'while loading', file)
            continue
        transcripts = pd.concat([transcripts, df])

    # set index
    transcripts[ID_COL] = transcripts[ID_COL].apply(lambda s: s.split('|')[0])
    transcripts.set_index(ID_COL, inplace=True)
    if enforce_unique_ids:
        assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected in fasta file"

    return transcripts

def sequence_complement(sequence: list):
    return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]

def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):

    # stack the list of sequences into a ragged tensor
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)

    # tokenize sequence
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)

    # add context padding if requested
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)

    # one-hot encode (token 255 falls outside depth 4 and becomes an all-zero row)
    sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)

    return sequence

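# Shape sketch: a 23-nt guide with add_context_padding=True becomes a
# (1, CONTEXT_5P + GUIDE_LEN + CONTEXT_3P, 4) one-hot tensor; padding positions
# map to token 255 and therefore to all-zero one-hot rows.
def _demo_one_hot():
    guide = ['G' * GUIDE_LEN]
    encoded = one_hot_encode_sequence(guide, add_context_padding=True)
    print(encoded.shape)  # (1, TARGET_LEN, 4) = (1, 26, 4)
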
def process_data(transcript_seq: str):

    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # get all target sites
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # prepare guide sequences
    guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])

    # model inputs
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
        tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
    ], axis=-1)

    return target_seq, guide_seq, model_inputs

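# Worked sketch: process_data slides a TARGET_LEN window over the transcript
# and emits one complemented guide per window (guides are reversed later, in
# tiger_exhibit's finalization step).
def _demo_process_data():
    target_seq, guide_seq, model_inputs = process_data('ACGT' * 10)  # 40-nt toy transcript
    print(len(target_seq))     # 40 - TARGET_LEN + 1 = 15 windows
    print(model_inputs.shape)  # (15, 208): two flattened (TARGET_LEN, 4) one-hot blocks
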
def calibrate_predictions(predictions: np.array, num_mismatches: np.array, params: pd.DataFrame = None):
    if params is None:
        params = pd.read_pickle('calibration_params.pkl')
    correction = np.squeeze(params.set_index('num_mismatches').loc[num_mismatches, 'slope'].to_numpy())
    return correction * predictions

def score_predictions(predictions: np.array, params: pd.DataFrame = None):
    if params is None:
        params = pd.read_pickle('scoring_params.pkl')

    if UNIT_INTERVAL_MAP == 'sigmoid':
        params = params.iloc[0]
        return 1 - 1 / (1 + np.exp(params['a'] * predictions + params['b']))

    elif UNIT_INTERVAL_MAP == 'min-max':
        return 1 - (predictions - params['a']) / (params['b'] - params['a'])

    elif UNIT_INTERVAL_MAP == 'exp-lin-exp':
        # regime indices
        active_saturation = predictions < params['a']
        linear_regime = (params['a'] <= predictions) & (predictions <= params['c'])
        inactive_saturation = params['c'] < predictions

        # linear regime
        slope = (params['d'] - params['b']) / (params['c'] - params['a'])
        intercept = -params['a'] * slope + params['b']
        predictions[linear_regime] = slope * predictions[linear_regime] + intercept

        # active saturation regime
        alpha = slope / params['b']
        beta = alpha * params['a'] - np.log(params['b'])
        predictions[active_saturation] = np.exp(alpha * predictions[active_saturation] - beta)

        # inactive saturation regime
        alpha = slope / (1 - params['d'])
        beta = -alpha * params['c'] - np.log(1 - params['d'])
        predictions[inactive_saturation] = 1 - np.exp(-alpha * predictions[inactive_saturation] - beta)

        return 1 - predictions

    else:
        raise NotImplementedError

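# Worked example of the default 'sigmoid' map: a raw LFC estimate x is scored
# as 1 - sigmoid(a*x + b), so stronger predicted knockdown (more negative LFC)
# yields a score closer to 1. The a=1, b=0 below are placeholders; the real
# values live in scoring_params.pkl.
def _demo_sigmoid_score():
    a, b = 1.0, 0.0  # placeholder parameters
    lfc = np.array([-4.0, 0.0, 2.0])
    print(1 - 1 / (1 + np.exp(a * lfc + b)))  # ~[0.982, 0.500, 0.119]
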
def get_on_target_predictions(transcripts: pd.DataFrame, model: tf.keras.Model, status_update_fn=None):

    # loop over transcripts
    predictions = pd.DataFrame()
    for i, (index, row) in enumerate(transcripts.iterrows()):

        # parse transcript sequence
        target_seq, guide_seq, model_inputs = process_data(row[SEQ_COL])

        # get predictions
        lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
        lfc_estimate = calibrate_predictions(lfc_estimate, num_mismatches=np.zeros_like(lfc_estimate))
        scores = score_predictions(lfc_estimate)
        predictions = pd.concat([predictions, pd.DataFrame({
            ID_COL: [index] * len(scores),
            TARGET_COL: target_seq,
            GUIDE_COL: guide_seq,
            SCORE_COL: scores})])

        # progress update
        percent_complete = 100 * min((i + 1) / len(transcripts), 1)
        update_text = 'Evaluating on-target guides for each transcript: {:.2f}%'.format(percent_complete)
        print('\r' + update_text, end='')
        if status_update_fn is not None:
            status_update_fn(update_text, percent_complete)
    print('')

    return predictions

def top_guides_per_transcript(predictions: pd.DataFrame):

    # select and sort the top guides for each transcript
    top_guides = pd.DataFrame()
    for transcript in predictions[ID_COL].unique():
        df = predictions.loc[predictions[ID_COL] == transcript]
        df = df.sort_values(SCORE_COL, ascending=False).reset_index(drop=True).iloc[:NUM_TOP_GUIDES]
        top_guides = pd.concat([top_guides, df])

    return top_guides.reset_index(drop=True)

def get_titration_candidates(top_guide_predictions: pd.DataFrame):

    # generate a table of all single-mismatch titration candidates
    titration_candidates = pd.DataFrame()
    for _, row in top_guide_predictions.iterrows():
        for i in range(len(row[GUIDE_COL])):
            nt = row[GUIDE_COL][i]
            for mutation in set(NUCLEOTIDE_TOKENS.keys()) - {nt, 'N'}:
                sm_guide = list(row[GUIDE_COL])
                sm_guide[i] = mutation
                sm_guide = ''.join(sm_guide)
                assert row[GUIDE_COL] != sm_guide
                titration_candidates = pd.concat([titration_candidates, pd.DataFrame({
                    ID_COL: [row[ID_COL]],
                    TARGET_COL: [row[TARGET_COL]],
                    GUIDE_COL: [sm_guide],
                    MM_COL: [1]
                })])

    return titration_candidates

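# Sanity sketch: each 23-nt guide yields 3 single-mismatch variants per
# position, i.e. 23 * 3 = 69 titration candidates per top guide.
def _demo_titration_count():
    df = pd.DataFrame({ID_COL: ['tx'], TARGET_COL: ['A' * TARGET_LEN], GUIDE_COL: ['T' * GUIDE_LEN]})
    print(len(get_titration_candidates(df)))  # 69
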
def find_off_targets(top_guides: pd.DataFrame, status_update_fn=None):

    # load reference transcripts
    reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])

    # one-hot encode guides to form a convolution filter
    guide_filter = one_hot_encode_sequence(sequence_complement(top_guides[GUIDE_COL]), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])

    # loop over transcripts in batches
    i = 0
    off_targets = pd.DataFrame()
    while i < len(reference_transcripts):
        # select batch
        df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
        i += BATCH_SIZE_SCAN

        # find locations of off-targets
        transcripts = one_hot_encode_sequence(df_batch[SEQ_COL].values.tolist(), add_context_padding=False)
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()

        # off-targets discovered
        if len(loc_off_targets) > 0:

            # log off-targets
            dict_off_targets = pd.DataFrame({
                'On-target ' + ID_COL: top_guides.iloc[loc_off_targets[:, 2]][ID_COL],
                GUIDE_COL: top_guides.iloc[loc_off_targets[:, 2]][GUIDE_COL],
                'Off-target ' + ID_COL: df_batch.index.values[loc_off_targets[:, 0]],
                'Guide Midpoint': loc_off_targets[:, 1],
                SEQ_COL: df_batch[SEQ_COL].values[loc_off_targets[:, 0]],
                MM_COL: tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
            }).to_dict('records')

            # trim transcripts to target windows, padding with 'N' at the edges
            for row in dict_off_targets:
                start_location = row['Guide Midpoint'] - (GUIDE_LEN // 2)
                del row['Guide Midpoint']
                target = row[SEQ_COL]
                del row[SEQ_COL]
                if start_location < CONTEXT_5P:
                    target = target[0:GUIDE_LEN + CONTEXT_3P]
                    target = 'N' * (TARGET_LEN - len(target)) + target
                elif start_location + GUIDE_LEN + CONTEXT_3P > len(target):
                    target = target[start_location - CONTEXT_5P:]
                    target = target + 'N' * (TARGET_LEN - len(target))
                else:
                    target = target[start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
                if row[MM_COL] == 0 and 'N' not in target:
                    assert row[GUIDE_COL] == sequence_complement([target[CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]
                row[TARGET_COL] = target

            # append new off-targets
            off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])

        # progress update
        percent_complete = 100 * min((i + 1) / len(reference_transcripts), 1)
        update_text = 'Scanning for off-targets: {:.2f}%'.format(percent_complete)
        print('\r' + update_text, end='')
        if status_update_fn is not None:
            status_update_fn(update_text, percent_complete)
    print('')

    return off_targets

def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
    if len(off_targets) == 0:
        return pd.DataFrame()

    # compute off-target predictions
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets[TARGET_COL], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets[GUIDE_COL], add_context_padding=True), [len(off_targets), -1]),
    ], axis=-1)
    lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
    lfc_estimate = calibrate_predictions(lfc_estimate, off_targets[MM_COL].to_numpy())
    off_targets[SCORE_COL] = score_predictions(lfc_estimate)

    return off_targets.reset_index(drop=True)

def tiger_exhibit(transcripts: pd.DataFrame, mode: str, check_off_targets: bool, status_update_fn=None):

    # load model
    if os.path.exists('cas13_model'):
        tiger = tf.keras.models.load_model('cas13_model')
    else:
        print('no saved model!')
        exit()

    # evaluate all on-target guides per transcript
    on_target_predictions = get_on_target_predictions(transcripts, tiger, status_update_fn)

    # initialize other outputs
    titration_predictions = off_target_predictions = None

    if mode == 'all' and not check_off_targets:
        off_target_candidates = None

    elif mode == 'top_guides':
        on_target_predictions = top_guides_per_transcript(on_target_predictions)
        off_target_candidates = on_target_predictions

    elif mode == 'titration':
        on_target_predictions = top_guides_per_transcript(on_target_predictions)
        titration_candidates = get_titration_candidates(on_target_predictions)
        titration_predictions = predict_off_target(titration_candidates, model=tiger)
        off_target_candidates = pd.concat([on_target_predictions, titration_predictions])

    else:
        raise NotImplementedError

    # check off-target effects for top guides
    if check_off_targets and off_target_candidates is not None:
        off_target_candidates = find_off_targets(off_target_candidates, status_update_fn)
        off_target_predictions = predict_off_target(off_target_candidates, model=tiger)
        if len(off_target_predictions) > 0:
            off_target_predictions = off_target_predictions.sort_values(SCORE_COL, ascending=False)
            off_target_predictions = off_target_predictions.reset_index(drop=True)

    # finalize tables
    for df in [on_target_predictions, titration_predictions, off_target_predictions]:
        if df is not None and len(df) > 0:
            for col in df.columns:
                if ID_COL in col and set(df[col].unique()) == {'ManualEntry'}:
                    del df[col]
            df[GUIDE_COL] = df[GUIDE_COL].apply(lambda s: s[::-1])  # reverse guide sequences
            df[TARGET_COL] = df[TARGET_COL].apply(lambda seq: seq[CONTEXT_5P:len(seq) - CONTEXT_3P])  # remove context

    return on_target_predictions, titration_predictions, off_target_predictions

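# Command-line usage sketch (flags as parsed in the __main__ block below):
#
#   python tiger.py --mode titration --check_off_targets --fasta_path transcripts/
#
# Results accumulate in on_target.csv, titration.csv, and off_target.csv.
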
if __name__ == '__main__':

    # common arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='titration')
    parser.add_argument('--check_off_targets', action='store_true', default=False)
    parser.add_argument('--fasta_path', type=str, default=None)
    args = parser.parse_args()

    # check for any existing results
    if os.path.exists('on_target.csv') or os.path.exists('titration.csv') or os.path.exists('off_target.csv'):
        raise FileExistsError('please rename or delete existing results')

    # load transcripts from a directory of fasta files
    if args.fasta_path is not None and os.path.exists(args.fasta_path):
        df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])

    # otherwise fall back to a simple test case: the first 50 nucleotides of the EIF3B-003 CDS
    else:
        df_transcripts = pd.DataFrame({
            ID_COL: ['ManualEntry'],
            SEQ_COL: ['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']})
        df_transcripts.set_index(ID_COL, inplace=True)

    # process in batches
    batch = 0
    num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
    num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
    for idx in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
        batch += 1
        print('Batch {:d} of {:d}'.format(batch, num_batches))

        # run batch
        idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
        df_on_target, df_titration, df_off_target = tiger_exhibit(
            transcripts=df_transcripts[idx:idx_stop],
            mode=args.mode,
            check_off_targets=args.check_off_targets
        )

        # save batch results
        df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
        if df_titration is not None:
            df_titration.to_csv('titration.csv', header=batch == 1, index=False, mode='a')
        if df_off_target is not None:
            df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')

        # clear session to prevent memory blow-up
        tf.keras.backend.clear_session()