LfOreVEr commited on
Commit
5750b33
·
verified ·
1 Parent(s): 3f84a3d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +759 -0
app.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tiger
3
+ import cas9att
4
+ import cas9attvcf
5
+ import cas9off
6
+ import cas12
7
+ import cas12lstm
8
+ import cas12lstmvcf
9
+ import pandas as pd
10
+ import streamlit as st
11
+ import plotly.graph_objs as go
12
+ import numpy as np
13
+ from pathlib import Path
14
+ import zipfile
15
+ import io
16
+ import gtracks
17
+ import subprocess
18
+ import cyvcf2
19
+
20
+
21
+
22
+ # title and documentation
23
+ st.markdown(Path('crisprTool.md').read_text(), unsafe_allow_html=True)
24
+ st.divider()
25
+
26
+ CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
27
+
28
+ selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
29
+ cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.h5'
30
+ cas12lstm_path = 'cas12_model/BiLSTM_Cpf1_weights.h5'
31
+
32
+ #plot functions
33
def generate_coolbox_plot(bigwig_path, region, output_image_path):
    """Plot one BigWig track over ``region`` and save the figure to disk.

    Args:
        bigwig_path: Value handed to the ``BigWig`` track constructor.
        region: Region argument forwarded unchanged to ``frame.plot``.
        output_image_path: Forwarded to ``frame.plot`` as ``savefig``.
    """
    # NOTE(review): ``CoolBox`` and ``BigWig`` are not imported anywhere in the
    # visible part of this file, so calling this helper would raise NameError
    # as written — confirm whether a ``coolbox`` import is missing.
    frame = CoolBox()
    # Tracks are attached to the frame with ``+=``.
    frame += BigWig(bigwig_path)
    frame.plot(region, savefig=output_image_path)
37
+
38
def generate_pygenometracks_plot(bigwig_file_path, region, output_image_path):
    """Write a one-track INI config for ``bigwig_file_path`` and plot ``region``.

    Args:
        bigwig_file_path: Path substituted into the ``file =`` entry of the
            generated track configuration.
        region: Region string of the form ``"<chrom>:<start>-<end>"``; it is
            split on ``:`` and ``-`` and the two coordinates are converted
            with ``int``.
        output_image_path: Forwarded to ``plot_tracks`` as ``out_file_name``.

    Raises:
        IndexError, ValueError: If ``region`` does not have the
            ``chrom:start-end`` shape or the coordinates are not integers.
    """
    # NOTE(review): ``plot_tracks`` is not imported in the visible part of this
    # file, so this helper would raise NameError as written — confirm the
    # intended import.

    # Define the configuration for pyGenomeTracks
    tracks = """
    [bigwig]
    file = {}
    height = 4
    color = blue
    min_value = 0
    max_value = 10
    """.format(bigwig_file_path)

    # Write the configuration to a temporary INI file
    # (a fixed relative path, so it lands in the current working directory and
    # is overwritten on every call; nothing here deletes it afterwards).
    config_file_path = "pygenometracks.ini"
    with open(config_file_path, 'w') as configfile:
        configfile.write(tracks)

    # Define the region to plot
    region_dict = {'chrom': region.split(':')[0],
                   'start': int(region.split(':')[1].split('-')[0]),
                   'end': int(region.split(':')[1].split('-')[1])}

    # Generate the plot
    plot_tracks(tracks_file=config_file_path,
                region=region_dict,
                out_file_name=output_image_path)
63
+
64
@st.cache_data
def convert_df(df):
    """Return ``df`` rendered as CSV text and encoded to UTF-8 bytes.

    The ``st.cache_data`` decorator memoizes the conversion so it is not
    recomputed on every Streamlit rerun.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
68
+
69
+
70
def mode_change_callback():
    """Enable or disable the off-target checkbox for the selected run mode.

    For the 'all' and 'titration' run modes the off-target option is cleared
    and its checkbox disabled; for every other mode the checkbox is enabled.
    """
    # TODO: support titration
    modes_without_off_target = {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}
    off_target_unavailable = st.session_state.mode in modes_without_off_target

    if off_target_unavailable:
        # Clear any previous selection before locking the checkbox.
        st.session_state.check_off_targets = False
    st.session_state.disable_off_target_checkbox = off_target_unavailable
76
+
77
+
78
def progress_update(update_text, percent_complete):
    """Render a status message and progress bar in the ``progress`` placeholder.

    Args:
        update_text: Message written above the progress bar.
        percent_complete: Completion on a 0-100 scale; it is divided by 100
            before being passed to ``st.progress``.
    """
    # ``progress`` is the module-level placeholder created by the page script.
    status_area = progress.container()
    with status_area:
        st.write(update_text)
        completed_fraction = percent_complete / 100
        st.progress(completed_fraction)
82
+
83
+
84
def initiate_run():
    """Reset run state, load the requested transcripts, and validate them.

    The entry method and the user's input are read from ``st.session_state``.
    When validation passes and at least one transcript is present, the
    transcripts are stored in ``st.session_state.transcripts``. When a check
    fails, a message is stored in ``st.session_state.input_error`` instead.
    """
    # initialize state variables
    st.session_state.transcripts = None
    st.session_state.input_error = None
    st.session_state.on_target = None
    st.session_state.titration = None
    st.session_state.off_target = None

    # initialize transcript DataFrame
    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    # manual entry
    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [st.session_state.manual_entry]
        }).set_index(tiger.ID_COL)

    # fasta file upload
    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        uploaded_fasta = st.session_state.fasta_entry
        if uploaded_fasta is not None:
            # The upload's name comes from the client: keep only its final
            # component so the scratch copy always lands in the working directory.
            fasta_path = os.path.basename(uploaded_fasta.name)
            try:
                with open(fasta_path, 'w') as f:
                    f.write(uploaded_fasta.getvalue().decode('utf-8'))
                transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
            finally:
                # Remove the scratch copy even when decoding or loading fails.
                if os.path.exists(fasta_path):
                    os.remove(fasta_path)

    # convert to upper case as used by tokenizer
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    # ensure all transcripts have unique identifiers
    if transcripts.index.has_duplicates:
        st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"

    # ensure all transcripts only contain the characters listed in tiger.NUCLEOTIDE_TOKENS
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

    # ensure all transcripts satisfy length requirements
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

    # run model if we have any transcripts
    elif len(transcripts) > 0:
        st.session_state.transcripts = transcripts
129
+
130
def parse_gene_annotations(file_path):
    """Build a gene-symbol -> Ensembl-gene-ID mapping from a tab-delimited file.

    The first line must be a header containing the columns ``Approved symbol``
    and ``Ensembl gene ID``. Each following row contributes one entry, provided
    it has enough fields to reach both columns; shorter rows (including blank
    lines and rows whose trailing fields are empty) are skipped.

    Args:
        file_path: Path of the tab-delimited annotation file.

    Returns:
        dict mapping each approved symbol to the Ensembl gene ID on the same
        row. When a symbol occurs more than once, the last row wins.

    Raises:
        ValueError: If either required column is missing from the header.
        OSError: If the file cannot be opened.
    """
    gene_dict = {}
    with open(file_path, 'r') as file:
        # Trim each header name individually instead of the whole line, so an
        # empty leading column name cannot shift the column indices.
        header_line = file.readline().rstrip('\r\n')
        headers = [name.strip() for name in header_line.split('\t')]
        symbol_idx = headers.index('Approved symbol')
        ensembl_idx = headers.index('Ensembl gene ID')
        last_needed_idx = max(symbol_idx, ensembl_idx)

        for line in file:
            # Strip only trailing whitespace: stripping the left side as well
            # would swallow leading tabs and move every field of a row whose
            # first column is empty one position to the left.
            values = line.rstrip().split('\t')
            if len(values) > last_needed_idx:
                gene_dict[values[symbol_idx]] = values[ensembl_idx]
    return gene_dict
142
+
143
# Build the gene symbol -> Ensembl gene ID mapping from the gene annotation file
144
+ gene_annotations = parse_gene_annotations('Human_genes_HUGO_02242024_annotation.txt')
145
+ gene_symbol_list = list(gene_annotations.keys()) # List of gene symbols for the autocomplete feature
146
+ # Check if the selected model is Cas9
147
+ if selected_model == 'Cas9':
148
+ # Use a radio button to select enzymes, making sure only one can be selected at a time
149
+ target_selection = st.radio(
150
+ "Select either on-target, on-target with mutation or off-target:",
151
+ ('on-target', 'mutation', 'off-target'),
152
+ key='target_selection'
153
+ )
154
+ if 'current_gene_symbol' not in st.session_state:
155
+ st.session_state['current_gene_symbol'] = ""
156
+
157
+ # Define a function to clean up old files
158
def clean_up_old_files(gene_symbol):
    """Delete the output files previously generated for ``gene_symbol``.

    Removes the GenBank, BED, CSV and track-plot image files named after the
    gene symbol from the current working directory. Files that do not exist
    are ignored.

    Args:
        gene_symbol: Gene symbol used as the file-name prefix.
    """
    stale_suffixes = (
        "_crispr_targets.gb",
        "_crispr_targets.bed",
        "_crispr_predictions.csv",
        # The on-target view also writes this image; without removing it here
        # one plot per previously selected gene is left behind.
        "_gtracks_plot.png",
    )
    for suffix in stale_suffixes:
        path = f"{gene_symbol}{suffix}"
        try:
            os.remove(path)
        except FileNotFoundError:
            # Nothing was generated for this output type; that is expected.
            pass
165
+
166
+
167
+ # Gene symbol entry with autocomplete-like feature
168
+ gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
169
+ format_func=lambda x: x if x else "")
170
+
171
+ # Handle gene symbol change and file cleanup
172
+ if gene_symbol != st.session_state['current_gene_symbol'] and gene_symbol:
173
+ if st.session_state['current_gene_symbol']:
174
+ # Clean up files only if a different gene symbol is entered and a previous symbol exists
175
+ clean_up_old_files(st.session_state['current_gene_symbol'])
176
+ # Update the session state with the new gene symbol
177
+ st.session_state['current_gene_symbol'] = gene_symbol
178
+
179
+ if target_selection == 'on-target':
180
+ # Prediction button
181
+ predict_button = st.button('Predict on-target')
182
+
183
+ if 'exons' not in st.session_state:
184
+ st.session_state['exons'] = []
185
+
186
+ # Process predictions
187
+ if predict_button and gene_symbol:
188
+ with st.spinner('Predicting... Please wait'):
189
+ predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
190
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
191
+ st.session_state['on_target_results'] = sorted_predictions
192
+ st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
193
+ st.session_state['exons'] = exons # Store exon data
194
+
195
+ # Notify the user once the process is completed successfully.
196
+ st.success('Prediction completed!')
197
+ st.session_state['prediction_made'] = True
198
+
199
+ if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
200
+ ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
201
+ col1, col2, col3 = st.columns(3)
202
+ with col1:
203
+ st.markdown("**Genome**")
204
+ st.markdown("Homo sapiens")
205
+ with col2:
206
+ st.markdown("**Gene**")
207
+ st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
208
+ with col3:
209
+ st.markdown("**Nuclease**")
210
+ st.markdown("SpCas9")
211
+ # Include "Target" in the DataFrame's columns
212
+ try:
213
+ df = pd.DataFrame(st.session_state['on_target_results'],
214
+ columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"])
215
+ st.dataframe(df)
216
+ except ValueError as e:
217
+ st.error(f"DataFrame creation error: {e}")
218
+ # Optionally print or log the problematic data for debugging:
219
+ print(st.session_state['on_target_results'])
220
+
221
+ # Initialize Plotly figure
222
+ fig = go.Figure()
223
+
224
+ EXON_BASE = 0 # Base position for exons and CDS on the Y axis
225
+ EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
226
+
227
+ # Plot Exons as small markers on the X-axis
228
+ for exon in st.session_state['exons']:
229
+ exon_start, exon_end = exon['start'], exon['end']
230
+ fig.add_trace(go.Bar(
231
+ x=[(exon_start + exon_end) / 2],
232
+ y=[EXON_HEIGHT],
233
+ width=[exon_end - exon_start],
234
+ base=EXON_BASE,
235
+ marker_color='rgba(128, 0, 128, 0.5)',
236
+ name='Exon'
237
+ ))
238
+
239
+ VERTICAL_GAP = 0.2 # Gap between different ranks
240
+
241
+ # Define max and min Y values based on strand and rank
242
+ MAX_STRAND_Y = 0.1 # Maximum Y value for positive strand results
243
+ MIN_STRAND_Y = -0.1 # Minimum Y value for negative strand results
244
+
245
+ # Iterate over top 5 sorted predictions to create the plot
246
+ for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1): # Only top 5
247
+ chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
248
+ midpoint = (int(start) + int(end)) / 2
249
+
250
+ # Vertical position based on rank, modified by strand
251
+ y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand == '1' or strand == '+' else (
252
+ MIN_STRAND_Y + (i - 1) * VERTICAL_GAP)
253
+
254
+ fig.add_trace(go.Scatter(
255
+ x=[midpoint],
256
+ y=[y_value],
257
+ mode='markers+text',
258
+ marker=dict(symbol='triangle-up' if strand == '1' or strand == '+' else 'triangle-down',
259
+ size=12),
260
+ text=f"Rank: {i}", # Text label
261
+ hoverinfo='text',
262
+ hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == '1' or strand == '+' else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
263
+ ))
264
+
265
+ # Update layout for clarity and interaction
266
+ fig.update_layout(
267
+ title='Top 5 gRNA Sequences by Prediction Score',
268
+ xaxis_title='Genomic Position',
269
+ yaxis_title='Strand',
270
+ yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
271
+ showlegend=False,
272
+ hovermode='x unified',
273
+ )
274
+
275
+ # Display the plot
276
+ st.plotly_chart(fig)
277
+
278
+ if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
279
+ gene_symbol = st.session_state['current_gene_symbol']
280
+ gene_sequence = st.session_state['gene_sequence']
281
+
282
+ # Define file paths
283
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
284
+ bed_file_path = f"{gene_symbol}_crispr_targets.bed"
285
+ csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
286
+ plot_image_path = f"{gene_symbol}_gtracks_plot.png"
287
+
288
+
289
+ # Generate files
290
+ cas9att.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
291
+ cas9att.create_bed_file_from_df(df, bed_file_path)
292
+ cas9att.create_csv_from_df(df, csv_file_path)
293
+
294
+ # Prepare an in-memory buffer for the ZIP file
295
+ zip_buffer = io.BytesIO()
296
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
297
+ # For each file, add it to the ZIP file
298
+ zip_file.write(genbank_file_path)
299
+ zip_file.write(bed_file_path)
300
+ zip_file.write(csv_file_path)
301
+
302
+
303
+ # Important: move the cursor to the beginning of the BytesIO buffer before reading it
304
+ zip_buffer.seek(0)
305
+
306
+ # Specify the region you want to visualize
307
+ min_start = df['Start Pos'].min()
308
+ max_end = df['End Pos'].max()
309
+ chromosome = df['Chr'].mode()[0] # Assumes most common chromosome is the target
310
+ region = f"{chromosome}:{min_start}-{max_end}"
311
+
312
+ # Generate the pyGenomeTracks plot
313
+ gtracks_command = f"gtracks {region} {bed_file_path} {plot_image_path}"
314
+ subprocess.run(gtracks_command, shell=True)
315
+ st.image(plot_image_path)
316
+
317
+ # Display the download button for the ZIP file
318
+ st.download_button(
319
+ label="Download GenBank, BED, CSV files as ZIP",
320
+ data=zip_buffer.getvalue(),
321
+ file_name=f"{gene_symbol}_files.zip",
322
+ mime="application/zip"
323
+ )
324
+ elif target_selection == 'mutation':
325
+ # Prediction button
326
+ predict_button = st.button('Predict on-target')
327
+ vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz')
328
+
329
+ if 'exons' not in st.session_state:
330
+ st.session_state['exons'] = []
331
+
332
+ # Process predictions
333
+ if predict_button and gene_symbol:
334
+ with st.spinner('Predicting... Please wait'):
335
+ predictions, gene_sequence, exons = cas9attvcf.process_gene(gene_symbol, vcf_reader, cas9att_path)
336
+ full_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)
337
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
338
+ st.session_state['full_results'] = full_predictions
339
+ st.session_state['on_target_results'] = sorted_predictions
340
+ st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
341
+ st.session_state['exons'] = exons # Store exon data
342
+
343
+ # Notify the user once the process is completed successfully.
344
+ st.success('Prediction completed!')
345
+ st.session_state['prediction_made'] = True
346
+
347
+ if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
348
+ ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
349
+ col1, col2, col3 = st.columns(3)
350
+ with col1:
351
+ st.markdown("**Genome**")
352
+ st.markdown("Homo sapiens")
353
+ with col2:
354
+ st.markdown("**Gene**")
355
+ st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
356
+ with col3:
357
+ st.markdown("**Nuclease**")
358
+ st.markdown("SpCas9")
359
+ # Include "Target" in the DataFrame's columns
360
+ try:
361
+ df = pd.DataFrame(st.session_state['on_target_results'],
362
+ columns=["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript", "Exon",
363
+ "Target",
364
+ "gRNA", "Prediction", "Is Mutation"])
365
+ df_full = pd.DataFrame(st.session_state['full_results'],
366
+ columns=["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript",
367
+ "Exon", "Target",
368
+ "gRNA", "Prediction", "Is Mutation"])
369
+ st.dataframe(df)
370
+ except ValueError as e:
371
+ st.error(f"DataFrame creation error: {e}")
372
+ # Optionally print or log the problematic data for debugging:
373
+ print(st.session_state['on_target_results'])
374
+
375
+ if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
376
+ gene_symbol = st.session_state['current_gene_symbol']
377
+ gene_sequence = st.session_state['gene_sequence']
378
+
379
+ # Define file paths
380
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
381
+ bed_file_path = f"{gene_symbol}_crispr_targets.bed"
382
+ csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
383
+ plot_image_path = f"{gene_symbol}_gtracks_plot.png"
384
+
385
+ # Generate files
386
+ cas9att.generate_genbank_file_from_df(df_full, gene_sequence, gene_symbol, genbank_file_path)
387
+ cas9att.create_bed_file_from_df(df_full, bed_file_path)
388
+ cas9att.create_csv_from_df(df_full, csv_file_path)
389
+
390
+ # Prepare an in-memory buffer for the ZIP file
391
+ zip_buffer = io.BytesIO()
392
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
393
+ # For each file, add it to the ZIP file
394
+ zip_file.write(genbank_file_path)
395
+ zip_file.write(bed_file_path)
396
+ zip_file.write(csv_file_path)
397
+
398
+ # Display the download button for the ZIP file
399
+ st.download_button(
400
+ label="Download GenBank, BED, CSV files as ZIP",
401
+ data=zip_buffer.getvalue(),
402
+ file_name=f"{gene_symbol}_files.zip",
403
+ mime="application/zip"
404
+ )
405
+
406
+ elif target_selection == 'off-target':
407
+ ENTRY_METHODS = dict(
408
+ manual='Manual entry of target sequence',
409
+ txt="txt file upload"
410
+ )
411
+ if __name__ == '__main__':
412
+ # app initialization for Cas9 off-target
413
+ if 'target_sequence' not in st.session_state:
414
+ st.session_state.target_sequence = None
415
+ if 'input_error' not in st.session_state:
416
+ st.session_state.input_error = None
417
+ if 'off_target_results' not in st.session_state:
418
+ st.session_state.off_target_results = None
419
+
420
+ # target sequence entry
421
+ st.selectbox(
422
+ label='How would you like to provide target sequences?',
423
+ options=ENTRY_METHODS.values(),
424
+ key='entry_method',
425
+ disabled=st.session_state.target_sequence is not None
426
+ )
427
+ if st.session_state.entry_method == ENTRY_METHODS['manual']:
428
+ st.text_input(
429
+ label='Enter on/off sequences:',
430
+ key='manual_entry',
431
+ placeholder='Enter on/off sequences like:GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG',
432
+ disabled=st.session_state.target_sequence is not None
433
+ )
434
+ elif st.session_state.entry_method == ENTRY_METHODS['txt']:
435
+ st.file_uploader(
436
+ label='Upload a txt file:',
437
+ key='txt_entry',
438
+ disabled=st.session_state.target_sequence is not None
439
+ )
440
+
441
+ # prediction button
442
+ if st.button('Predict off-target'):
443
+ if st.session_state.entry_method == ENTRY_METHODS['manual']:
444
+ user_input = st.session_state.manual_entry
445
+ if user_input: # Check if user_input is not empty
446
+ predictions = cas9off.process_input_and_predict(user_input, input_type='manual')
447
+ elif st.session_state.entry_method == ENTRY_METHODS['txt']:
448
+ uploaded_file = st.session_state.txt_entry
449
+ if uploaded_file is not None:
450
+ # Read the uploaded file content
451
+ file_content = uploaded_file.getvalue().decode("utf-8")
452
+ predictions = cas9off.process_input_and_predict(file_content, input_type='manual')
453
+
454
+ st.session_state.off_target_results = predictions
455
+ else:
456
+ predictions = None
457
+ progress = st.empty()
458
+
459
+ # input error display
460
+ error = st.empty()
461
+ if st.session_state.input_error is not None:
462
+ error.error(st.session_state.input_error, icon="🚨")
463
+ else:
464
+ error.empty()
465
+
466
+ # off-target results display
467
+ off_target_results = st.empty()
468
+ if st.session_state.off_target_results is not None:
469
+ with off_target_results.container():
470
+ if len(st.session_state.off_target_results) > 0:
471
+ st.write('Off-target predictions:', st.session_state.off_target_results)
472
+ st.download_button(
473
+ label='Download off-target predictions',
474
+ data=convert_df(st.session_state.off_target_results),
475
+ file_name='off_target_results.csv',
476
+ mime='text/csv'
477
+ )
478
+ else:
479
+ st.write('No significant off-target effects detected!')
480
+ else:
481
+ off_target_results.empty()
482
+
483
+ # running the CRISPR-Net model for off-target predictions
484
+ if st.session_state.target_sequence is not None:
485
+ st.session_state.off_target_results = cas9off.predict_off_targets(
486
+ target_sequence=st.session_state.target_sequence,
487
+ status_update_fn=progress_update
488
+ )
489
+ st.session_state.target_sequence = None
490
+ st.experimental_rerun()
491
+
492
+ elif selected_model == 'Cas12':
493
def visualize_genomic_data(df=None, gene_symbol=None):
    """Plot exons and the top five gRNA predictions held in session state.

    Exons are read from ``st.session_state['exons']`` and predictions from
    ``st.session_state['on_target_results']``; a missing key is treated as an
    empty list. Each prediction is accessed by column name (``'Start Pos'``,
    ``'End Pos'``, ``'Strand'``, ``'Chr'``, ``'Target'``, ``'gRNA'``,
    ``'Transcript'``, ``'Prediction'``).

    Args:
        df: Optional prediction table. When given together with
            ``gene_symbol`` it is passed to ``generate_and_download_files``
            after the plot. Previously this function referenced an undefined
            ``df`` and always failed at that step.
        gene_symbol: Optional gene symbol used to name the generated files.
    """
    fig = go.Figure()

    EXON_BASE = 0  # Base position for exons and CDS on the Y axis
    EXON_HEIGHT = 0.02  # How 'tall' the exon markers should appear

    # Plot Exons as small markers on the X-axis
    for exon in st.session_state.get('exons') or []:
        try:
            exon_start, exon_end = int(exon['start']), int(exon['end'])
            fig.add_trace(go.Bar(
                x=[(exon_start + exon_end) / 2],
                y=[EXON_HEIGHT],
                width=[exon_end - exon_start],
                base=EXON_BASE,
                marker_color='rgba(128, 0, 128, 0.5)',
                name='Exon'
            ))
        except ValueError:
            st.error("Error in exon positions. Exon positions should be numeric.")

    VERTICAL_GAP = 0.2  # Gap between different ranks

    # Define max and min Y values based on strand and rank
    MAX_STRAND_Y = 0.1  # Maximum Y value for positive strand results
    MIN_STRAND_Y = -0.1  # Minimum Y value for negative strand results

    # Iterate over top 5 sorted predictions to create the plot
    top_predictions = (st.session_state.get('on_target_results') or [])[:5]
    for i, prediction in enumerate(top_predictions, start=1):
        try:
            start, end = int(prediction['Start Pos']), int(prediction['End Pos'])
            midpoint = (start + end) / 2
            strand = prediction['Strand']
            # Plus-strand hits stack downwards from the top of the chart,
            # minus-strand hits stack upwards from the bottom.
            y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand in ['1', '+'] else (
                MIN_STRAND_Y + (i - 1) * VERTICAL_GAP)

            fig.add_trace(go.Scatter(
                x=[midpoint],
                y=[y_value],
                mode='markers+text',
                marker=dict(symbol='triangle-up' if strand in ['1', '+'] else 'triangle-down', size=12),
                text=f"Rank: {i}",
                hoverinfo='text',
                hovertext=f"Rank: {i}<br>Chromosome: {prediction['Chr']}<br>Target Sequence: {prediction['Target']}<br>gRNA: {prediction['gRNA']}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand in ['1', '+'] else '-'}<br>Transcript: {prediction['Transcript']}<br>Prediction: {prediction['Prediction']:.4f}",
            ))
        except ValueError:
            st.error("Error in prediction positions. Start and end positions should be numeric.")

    # Update layout for clarity and interaction
    fig.update_layout(
        title='Top 5 gRNA Sequences by Prediction Score',
        xaxis_title='Genomic Position',
        yaxis_title='Strand',
        yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
        showlegend=False,
        hovermode='x unified',
    )

    # Display the plot
    st.plotly_chart(fig)

    # File generation and download (only when the caller supplied the data)
    if df is not None and gene_symbol:
        generate_and_download_files(df, gene_symbol)
556
+
557
+
558
def generate_and_download_files(df, gene_symbol):
    """Write CSV, GenBank and BED files for ``df`` and offer them as one ZIP.

    The three files are written to the current working directory with names
    derived from ``gene_symbol``, packed into an in-memory ZIP archive, and
    exposed through a Streamlit download button.

    Args:
        df: Prediction table; written as CSV and handed to the ``cas12lstm``
            GenBank and BED writers.
        gene_symbol: Prefix for the generated file names and the ZIP name.
    """
    targets_stem = f"{gene_symbol}_crispr_targets"
    genbank_file_path = f"{targets_stem}.gb"
    bed_file_path = f"{targets_stem}.bed"
    csv_file_path = f"{gene_symbol}_crispr_predictions.csv"

    # Same write order as before: CSV first, then GenBank, then BED.
    df.to_csv(csv_file_path, index=False)
    # The GenBank and BED writers are provided by the cas12lstm module.
    cas12lstm.generate_genbank_file_from_df(df, gene_symbol, genbank_file_path)
    cas12lstm.create_bed_file_from_df(df, bed_file_path)

    archive_buffer = io.BytesIO()
    with zipfile.ZipFile(archive_buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
        for member_path in (genbank_file_path, bed_file_path, csv_file_path):
            archive.write(member_path)
    archive_buffer.seek(0)

    st.download_button(
        "Download GenBank, BED, CSV files as ZIP",
        data=archive_buffer.getvalue(),
        file_name=f"{gene_symbol}_files.zip",
        mime="application/zip",
    )
575
+
576
+
577
def display_results(predictions, gene_sequence, exons, gene_symbol):
    """Show the Cas12 prediction table and offer the result files for download.

    Args:
        predictions: Rows with nine fields each, in the column order used for
            the table below.
        gene_sequence: Accepted for interface compatibility; not used here.
        exons: Accepted for interface compatibility; not used here.
        gene_symbol: Gene symbol shown in the summary and used to name the
            generated files.
    """
    st.success('Prediction completed!')
    ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')
    st.write("**Genome:** Homo sapiens")
    st.write(f"**Gene:** {gene_symbol} : {ensembl_id} (primary)")
    st.write("**Nuclease:** Cas12")
    df = pd.DataFrame(predictions,
                      columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA",
                               "Prediction"])
    st.dataframe(df)

    # The previous call target, ``visualize_and_generate_files``, is not
    # defined anywhere in this file, so every prediction ended in a NameError.
    # Use the file-generation helper that does exist.
    generate_and_download_files(df, gene_symbol)
590
+
591
+
592
+ cas12target_selection = st.radio(
593
+ "Select either regular or mutation:",
594
+ ('regular', 'mutation'),
595
+ key='cas12target_selection'
596
+ )
597
+ if 'current_gene_symbol' not in st.session_state:
598
+ st.session_state['current_gene_symbol'] = ""
599
+
600
def clean_up_old_files(gene_symbol):
    """Remove the GenBank, BED and CSV outputs left over for ``gene_symbol``.

    Only files that currently exist in the working directory are deleted.

    Args:
        gene_symbol: Gene symbol used as the file-name prefix.
    """
    stale_suffixes = ('_crispr_targets.gb', '_crispr_targets.bed', '_crispr_predictions.csv')
    candidate_paths = (f"{gene_symbol}{ending}" for ending in stale_suffixes)
    # Both the generator and ``filter`` are lazy, so each path is checked and
    # removed in turn.
    for stale_path in filter(os.path.exists, candidate_paths):
        os.remove(stale_path)
605
+
606
+ gene_symbol = st.selectbox(
607
+ 'Enter a Gene Symbol:',
608
+ [''] + gene_symbol_list,
609
+ key='gene_symbol',
610
+ format_func=lambda x: x if x else ""
611
+ )
612
+
613
+ if gene_symbol != st.session_state['current_gene_symbol']:
614
+ if st.session_state['current_gene_symbol']:
615
+ clean_up_old_files(st.session_state['current_gene_symbol'])
616
+ st.session_state['current_gene_symbol'] = gene_symbol
617
+
618
+ if cas12target_selection == 'regular':
619
+ if st.button('Predict cas12 Regular'):
620
+ with st.spinner('Predicting... Please wait'):
621
+ predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas12lstm_path)
622
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
623
+ display_results(sorted_predictions, gene_sequence, exons, gene_symbol)
624
+ elif cas12target_selection == 'mutation':
625
+ vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz')
626
+ if st.button('Predict cas12 Mutation'):
627
+ with st.spinner('Predicting... Please wait'):
628
+ predictions, gene_sequence, exons = cas12lstmvcf.process_gene(gene_symbol, vcf_reader, cas12lstm_path)
629
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
630
+ display_results(sorted_predictions, gene_sequence, exons, gene_symbol)
631
+
632
+ elif selected_model == 'Cas13d':
633
+ ENTRY_METHODS = dict(
634
+ manual='Manual entry of single transcript',
635
+ fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
636
+ )
637
+
638
+ if __name__ == '__main__':
639
+ # app initialization
640
+ if 'mode' not in st.session_state:
641
+ st.session_state.mode = tiger.RUN_MODES['all']
642
+ st.session_state.disable_off_target_checkbox = True
643
+ if 'entry_method' not in st.session_state:
644
+ st.session_state.entry_method = ENTRY_METHODS['manual']
645
+ if 'transcripts' not in st.session_state:
646
+ st.session_state.transcripts = None
647
+ if 'input_error' not in st.session_state:
648
+ st.session_state.input_error = None
649
+ if 'on_target' not in st.session_state:
650
+ st.session_state.on_target = None
651
+ if 'titration' not in st.session_state:
652
+ st.session_state.titration = None
653
+ if 'off_target' not in st.session_state:
654
+ st.session_state.off_target = None
655
+
656
+ # mode selection
657
+ col1, col2 = st.columns([0.65, 0.35])
658
+ with col1:
659
+ st.radio(
660
+ label='What do you want to predict?',
661
+ options=tuple(tiger.RUN_MODES.values()),
662
+ key='mode',
663
+ on_change=mode_change_callback,
664
+ disabled=st.session_state.transcripts is not None,
665
+ )
666
+ with col2:
667
+ st.checkbox(
668
+ label='Find off-target effects (slow)',
669
+ key='check_off_targets',
670
+ disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
671
+ )
672
+
673
+ # transcript entry
674
+ st.selectbox(
675
+ label='How would you like to provide transcript(s) of interest?',
676
+ options=ENTRY_METHODS.values(),
677
+ key='entry_method',
678
+ disabled=st.session_state.transcripts is not None
679
+ )
680
+ if st.session_state.entry_method == ENTRY_METHODS['manual']:
681
+ st.text_input(
682
+ label='Enter a target transcript:',
683
+ key='manual_entry',
684
+ placeholder='Upper or lower case',
685
+ disabled=st.session_state.transcripts is not None
686
+ )
687
+ elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
688
+ st.file_uploader(
689
+ label='Upload a fasta file:',
690
+ key='fasta_entry',
691
+ disabled=st.session_state.transcripts is not None
692
+ )
693
+
694
+ # let's go!
695
+ st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
696
+ progress = st.empty()
697
+
698
+ # input error
699
+ error = st.empty()
700
+ if st.session_state.input_error is not None:
701
+ error.error(st.session_state.input_error, icon="🚨")
702
+ else:
703
+ error.empty()
704
+
705
+ # on-target results
706
+ on_target_results = st.empty()
707
+ if st.session_state.on_target is not None:
708
+ with on_target_results.container():
709
+ st.write('On-target predictions:', st.session_state.on_target)
710
+ st.download_button(
711
+ label='Download on-target predictions',
712
+ data=convert_df(st.session_state.on_target),
713
+ file_name='on_target.csv',
714
+ mime='text/csv'
715
+ )
716
+ else:
717
+ on_target_results.empty()
718
+
719
+ # titration results
720
+ titration_results = st.empty()
721
+ if st.session_state.titration is not None:
722
+ with titration_results.container():
723
+ st.write('Titration predictions:', st.session_state.titration)
724
+ st.download_button(
725
+ label='Download titration predictions',
726
+ data=convert_df(st.session_state.titration),
727
+ file_name='titration.csv',
728
+ mime='text/csv'
729
+ )
730
+ else:
731
+ titration_results.empty()
732
+
733
+ # off-target results
734
+ off_target_results = st.empty()
735
+ if st.session_state.off_target is not None:
736
+ with off_target_results.container():
737
+ if len(st.session_state.off_target) > 0:
738
+ st.write('Off-target predictions:', st.session_state.off_target)
739
+ st.download_button(
740
+ label='Download off-target predictions',
741
+ data=convert_df(st.session_state.off_target),
742
+ file_name='off_target.csv',
743
+ mime='text/csv'
744
+ )
745
+ else:
746
+ st.write('We did not find any off-target effects!')
747
+ else:
748
+ off_target_results.empty()
749
+
750
+ # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
751
+ if st.session_state.transcripts is not None:
752
+ st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
753
+ transcripts=st.session_state.transcripts,
754
+ mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
755
+ check_off_targets=st.session_state.check_off_targets,
756
+ status_update_fn=progress_update
757
+ )
758
+ st.session_state.transcripts = None
759
+ st.experimental_rerun()