aiqcamp commited on
Commit
4b85b50
·
verified ·
1 Parent(s): a476574

Create app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +591 -0
app-backup.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys
2
+
3
+ # install environment goods
4
+ #os.system("pip -q install dgl -f https://data.dgl.ai/wheels/cu113/repo.html")
5
+ os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
6
+ #os.system('pip install gradio')
7
+ os.environ["DGLBACKEND"] = "pytorch"
8
+ #os.system(f'pip install -r ./PROTEIN_GENERATOR/requirements.txt')
9
+ print('Modules installed')
10
+
11
+ #os.system('pip install --force gradio==3.36.1')
12
+ #os.system('pip install gradio_client==0.2.7')
13
+ #os.system('pip install \"numpy<2\"')
14
+ #os.system('pip install numpy --upgrade')
15
+ #os.system('pip install --force numpy==1.24.1')
16
+
17
+
18
+ if not os.path.exists('./SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'):
19
+ print('Downloading model weights 1')
20
+ os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt')
21
+ print('Successfully Downloaded')
22
+
23
+ if not os.path.exists('./SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'):
24
+ print('Downloading model weights 2')
25
+ os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt')
26
+ print('Successfully Downloaded')
27
+
28
+ import numpy as np
29
+ import gradio as gr
30
+ import py3Dmol
31
+ from io import StringIO
32
+ import json
33
+ import secrets
34
+ import copy
35
+ import matplotlib.pyplot as plt
36
+ from utils.sampler import HuggingFace_sampler
37
+ from utils.parsers_inference import parse_pdb
38
+ from model.util import writepdb
39
+ from utils.inpainting_util import *
40
+
41
+
42
+ plt.rcParams.update({'font.size': 13})
43
+
44
+ with open('./tmp/args.json','r') as f:
45
+ args = json.load(f)
46
+
47
+ # manually set checkpoint to load
48
+ args['checkpoint'] = None
49
+ args['dump_trb'] = False
50
+ args['dump_args'] = True
51
+ args['save_best_plddt'] = True
52
+ args['T'] = 25
53
+ args['strand_bias'] = 0.0
54
+ args['loop_bias'] = 0.0
55
+ args['helix_bias'] = 0.0
56
+
57
+
58
+
59
+ def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
60
+ secondary_structure, aa_bias, aa_bias_potential,
61
+ #target_charge, target_ph, charge_potential,
62
+ num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
63
+ contigs, pssm, seq_mask, str_mask, rewrite_pdb):
64
+
65
+ dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
66
+ og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'
67
+
68
+ model_args = copy.deepcopy(args)
69
+
70
+ # make sampler
71
+ S = HuggingFace_sampler(args=model_args)
72
+
73
+ # get random prefix
74
+ S.out_prefix = './tmp/'+secrets.token_hex(nbytes=10).upper()
75
+
76
+ # set args
77
+ S.args['checkpoint'] = None
78
+ S.args['dump_trb'] = False
79
+ S.args['dump_args'] = True
80
+ S.args['save_best_plddt'] = True
81
+ S.args['T'] = 20
82
+ S.args['strand_bias'] = 0.0
83
+ S.args['loop_bias'] = 0.0
84
+ S.args['helix_bias'] = 0.0
85
+ S.args['potentials'] = None
86
+ S.args['potential_scale'] = None
87
+ S.args['aa_composition'] = None
88
+
89
+
90
+ # get sequence if entered and make sure all chars are valid
91
+ alt_aa_dict = {'B':['D','N'],'J':['I','L'],'U':['C'],'Z':['E','Q'],'O':['K']}
92
+ if sequence not in ['',None]:
93
+ L = len(sequence)
94
+ aa_seq = []
95
+ for aa in sequence.upper():
96
+ if aa in alt_aa_dict.keys():
97
+ aa_seq.append(np.random.choice(alt_aa_dict[aa]))
98
+ else:
99
+ aa_seq.append(aa)
100
+
101
+ S.args['sequence'] = aa_seq
102
+ elif contigs not in ['',None]:
103
+ S.args['contigs'] = [contigs]
104
+ else:
105
+ S.args['contigs'] = [f'{seq_len}']
106
+ L = int(seq_len)
107
+
108
+ print('DEBUG: ',rewrite_pdb)
109
+ if rewrite_pdb not in ['',None]:
110
+ S.args['pdb'] = rewrite_pdb.name
111
+
112
+ if seq_mask not in ['',None]:
113
+ S.args['inpaint_seq'] = [seq_mask]
114
+ if str_mask not in ['',None]:
115
+ S.args['inpaint_str'] = [str_mask]
116
+
117
+ if secondary_structure in ['',None]:
118
+ secondary_structure = None
119
+ else:
120
+ secondary_structure = ''.join(['E' if x == 'S' else x for x in secondary_structure])
121
+ if L < len(secondary_structure):
122
+ secondary_structure = secondary_structure[:len(sequence)]
123
+ elif L == len(secondary_structure):
124
+ pass
125
+ else:
126
+ dseq = L - len(secondary_structure)
127
+ secondary_structure += secondary_structure[-1]*dseq
128
+
129
+
130
+ # potentials
131
+ potential_list = []
132
+ potential_bias_list = []
133
+
134
+ if aa_bias not in ['',None]:
135
+ potential_list.append('aa_bias')
136
+ S.args['aa_composition'] = aa_bias
137
+ if aa_bias_potential in ['',None]:
138
+ aa_bias_potential = 3
139
+ potential_bias_list.append(str(aa_bias_potential))
140
+ '''
141
+ if target_charge not in ['',None]:
142
+ potential_list.append('charge')
143
+ if charge_potential in ['',None]:
144
+ charge_potential = 1
145
+ potential_bias_list.append(str(charge_potential))
146
+ S.args['target_charge'] = float(target_charge)
147
+ if target_ph in ['',None]:
148
+ target_ph = 7.4
149
+ S.args['target_pH'] = float(target_ph)
150
+ '''
151
+
152
+ if hydrophobic_target_score not in ['',None]:
153
+ potential_list.append('hydrophobic')
154
+ S.args['hydrophobic_score'] = float(hydrophobic_target_score)
155
+ if hydrophobic_potential in ['',None]:
156
+ hydrophobic_potential = 3
157
+ potential_bias_list.append(str(hydrophobic_potential))
158
+
159
+ if pssm not in ['',None]:
160
+ potential_list.append('PSSM')
161
+ potential_bias_list.append('5')
162
+ S.args['PSSM'] = pssm.name
163
+
164
+
165
+ if len(potential_list) > 0:
166
+ S.args['potentials'] = ','.join(potential_list)
167
+ S.args['potential_scale'] = ','.join(potential_bias_list)
168
+
169
+
170
+ # normalise secondary_structure bias from range 0-0.3
171
+ S.args['secondary_structure'] = secondary_structure
172
+ S.args['helix_bias'] = helix_bias
173
+ S.args['strand_bias'] = strand_bias
174
+ S.args['loop_bias'] = loop_bias
175
+
176
+ # set T
177
+ if num_steps in ['',None]:
178
+ S.args['T'] = 20
179
+ else:
180
+ S.args['T'] = int(num_steps)
181
+
182
+ # noise
183
+ if 'normal' in noise:
184
+ S.args['sample_distribution'] = noise
185
+ S.args['sample_distribution_gmm_means'] = [0]
186
+ S.args['sample_distribution_gmm_variances'] = [1]
187
+ elif 'gmm2' in noise:
188
+ S.args['sample_distribution'] = noise
189
+ S.args['sample_distribution_gmm_means'] = [-1,1]
190
+ S.args['sample_distribution_gmm_variances'] = [1,1]
191
+ elif 'gmm3' in noise:
192
+ S.args['sample_distribution'] = noise
193
+ S.args['sample_distribution_gmm_means'] = [-1,0,1]
194
+ S.args['sample_distribution_gmm_variances'] = [1,1,1]
195
+
196
+
197
+
198
+ if secondary_structure not in ['',None] or helix_bias+strand_bias+loop_bias > 0:
199
+ S.args['checkpoint'] = dssp_checkpoint
200
+ S.args['d_t1d'] = 29
201
+ print('using dssp checkpoint')
202
+ else:
203
+ S.args['checkpoint'] = og_checkpoint
204
+ S.args['d_t1d'] = 24
205
+ print('using og checkpoint')
206
+
207
+
208
+ for k,v in S.args.items():
209
+ print(f"{k} --> {v}")
210
+
211
+ # init S
212
+ S.model_init()
213
+ S.diffuser_init()
214
+ S.setup()
215
+
216
+ # sampling loop
217
+ plddt_data = []
218
+ for j in range(S.max_t):
219
+ print(f'on step {j}')
220
+ output_seq, output_pdb, plddt = S.take_step_get_outputs(j)
221
+ plddt_data.append(plddt)
222
+ yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
223
+
224
+ output_seq, output_pdb, plddt = S.get_outputs()
225
+ yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
226
+
227
+ def get_plddt_plot(plddt_data, max_t):
228
+ x = [i+1 for i in range(len(plddt_data))]
229
+ fig, ax = plt.subplots(figsize=(15,6))
230
+ ax.plot(x,plddt_data,color='#661dbf', linewidth=3,marker='o')
231
+ ax.set_xticks([i+1 for i in range(max_t)])
232
+ ax.set_yticks([(i+1)/10 for i in range(10)])
233
+ ax.set_ylim([0,1])
234
+ ax.set_ylabel('model confidence (plddt)')
235
+ ax.set_xlabel('diffusion steps (t)')
236
+ return fig
237
+
238
+ def display_pdb(path_to_pdb):
239
+ '''
240
+ #function to display pdb in py3dmol
241
+ '''
242
+ pdb = open(path_to_pdb, "r").read()
243
+
244
+ view = py3Dmol.view(width=500, height=500)
245
+ view.addModel(pdb, "pdb")
246
+ view.setStyle({'model': -1}, {"cartoon": {'colorscheme':{'prop':'b','gradient':'roygb','min':0,'max':1}}})#'linear', 'min': 0, 'max': 1, 'colors': ["#ff9ef0","#a903fc",]}}})
247
+ view.zoomTo()
248
+ output = view._make_html().replace("'", '"')
249
+ print(view._make_html())
250
+ x = f"""<!DOCTYPE html><html></center> {output} </center></html>""" # do not use ' in this input
251
+
252
+ return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
253
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
254
+ allow-scripts allow-same-origin allow-popups
255
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
256
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
257
+
258
+ '''
259
+ return f"""<iframe style="width: 100%; height:700px" name="result" allow="midi; geolocation; microphone; camera;
260
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
261
+ allow-scripts allow-same-origin allow-popups
262
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
263
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
264
+ '''
265
+
266
+
267
+
268
+ # MOTIF SCAFFOLDING
269
+ def get_motif_preview(pdb_id, contigs):
270
+ '''
271
+ #function to display selected motif in py3dmol
272
+ '''
273
+ input_pdb = fetch_pdb(pdb_id=pdb_id.lower())
274
+
275
+ # rewrite pdb
276
+ parse = parse_pdb(input_pdb)
277
+ #output_name = './rewrite_'+input_pdb.split('/')[-1]
278
+ #writepdb(output_name, torch.tensor(parse_og['xyz']),torch.tensor(parse_og['seq']))
279
+ #parse = parse_pdb(output_name)
280
+ output_name = input_pdb
281
+
282
+ pdb = open(output_name, "r").read()
283
+ view = py3Dmol.view(width=500, height=500)
284
+ view.addModel(pdb, "pdb")
285
+
286
+ if contigs in ['',0]:
287
+ contigs = ['0']
288
+ else:
289
+ contigs = [contigs]
290
+
291
+ print('DEBUG: ',contigs)
292
+
293
+ pdb_map = get_mappings(ContigMap(parse,contigs))
294
+ print('DEBUG: ',pdb_map)
295
+ print('DEBUG: ',pdb_map['con_ref_idx0'])
296
+ roi = [x[1]-1 for x in pdb_map['con_ref_pdb_idx']]
297
+
298
+ colormap = {0:'#D3D3D3', 1:'#F74CFF'}
299
+ colors = {i+1: colormap[1] if i in roi else colormap[0] for i in range(parse['xyz'].shape[0])}
300
+ view.setStyle({"cartoon": {"colorscheme": {"prop": "resi", "map": colors}}})
301
+ view.zoomTo()
302
+ output = view._make_html().replace("'", '"')
303
+ print(view._make_html())
304
+ x = f"""<!DOCTYPE html><html></center> {output} </center></html>""" # do not use ' in this input
305
+
306
+ return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
307
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
308
+ allow-scripts allow-same-origin allow-popups
309
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
310
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""", output_name
311
+
312
+ def fetch_pdb(pdb_id=None):
313
+ if pdb_id is None or pdb_id == "":
314
+ return None
315
+ else:
316
+ os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_id}.pdb")
317
+ return f"{pdb_id}.pdb"
318
+
319
+ # MSA AND PSSM GUIDANCE
320
+ def save_pssm(file_upload):
321
+ filename = file_upload.name
322
+ orig_name = file_upload.orig_name
323
+ if filename.split('.')[-1] in ['fasta', 'a3m']:
324
+ return msa_to_pssm(file_upload)
325
+ return filename
326
+
327
+ def msa_to_pssm(msa_file):
328
+ # Define the lookup table for converting amino acids to indices
329
+ aa_to_index = {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10,
330
+ 'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20, '-': 21}
331
+ # Open the FASTA file and read the sequences
332
+ records = list(SeqIO.parse(msa_file.name, "fasta"))
333
+
334
+ assert len(records) >= 1, "MSA must contain more than one protein sequecne."
335
+
336
+ first_seq = str(records[0].seq)
337
+ aligned_seqs = [first_seq]
338
+ # print(aligned_seqs)
339
+ # Perform sequence alignment using the Needleman-Wunsch algorithm
340
+ aligner = Align.PairwiseAligner()
341
+ aligner.open_gap_score = -0.7
342
+ aligner.extend_gap_score = -0.3
343
+ for record in records[1:]:
344
+ alignment = aligner.align(first_seq, str(record.seq))[0]
345
+ alignment = alignment.format().split("\n")
346
+ al1 = alignment[0]
347
+ al2 = alignment[2]
348
+ al1_fin = ""
349
+ al2_fin = ""
350
+ percent_gap = al2.count('-')/ len(al2)
351
+ if percent_gap > 0.4:
352
+ continue
353
+ for i in range(len(al1)):
354
+ if al1[i] != '-':
355
+ al1_fin += al1[i]
356
+ al2_fin += al2[i]
357
+ aligned_seqs.append(str(al2_fin))
358
+ # Get the length of the aligned sequences
359
+ aligned_seq_length = len(first_seq)
360
+ # Initialize the position scoring matrix
361
+ matrix = np.zeros((22, aligned_seq_length))
362
+ # Iterate through the aligned sequences and count the amino acids at each position
363
+ for seq in aligned_seqs:
364
+ #print(seq)
365
+ for i in range(aligned_seq_length):
366
+ if i == len(seq):
367
+ break
368
+ amino_acid = seq[i]
369
+ if amino_acid.upper() not in aa_to_index.keys():
370
+ continue
371
+ else:
372
+ aa_index = aa_to_index[amino_acid.upper()]
373
+ matrix[aa_index, i] += 1
374
+ # Normalize the counts to get the frequency of each amino acid at each position
375
+ matrix /= len(aligned_seqs)
376
+ print(len(aligned_seqs))
377
+ matrix[20:,]=0
378
+
379
+ outdir = ".".join(msa_file.name.split('.')[:-1]) + ".csv"
380
+ np.savetxt(outdir, matrix[:21,:].T, delimiter=",")
381
+ return outdir
382
+
383
+ def get_pssm(fasta_msa, input_pssm):
384
+
385
+ if input_pssm not in ['',None]:
386
+ outdir = input_pssm.name
387
+ else:
388
+ outdir = save_pssm(fasta_msa)
389
+
390
+ pssm = np.loadtxt(outdir, delimiter=",", dtype=float)
391
+ fig, ax = plt.subplots(figsize=(15,6))
392
+ plt.imshow(torch.permute(torch.tensor(pssm),(1,0)))
393
+
394
+ return fig, outdir
395
+
396
+
397
+ #toggle options
398
+ def toggle_seq_input(choice):
399
+ if choice == "protein length":
400
+ return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
401
+ elif choice == "custom sequence":
402
+ return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
403
+
404
+ def toggle_secondary_structure(choice):
405
+ if choice == "sliders":
406
+ return gr.update(visible=True, value=None),gr.update(visible=True, value=None),gr.update(visible=True, value=None),gr.update(visible=False, value=None)
407
+ elif choice == "explicit":
408
+ return gr.update(visible=False, value=None),gr.update(visible=False, value=None),gr.update(visible=False, value=None),gr.update(visible=True, value=None)
409
+
410
+
411
+ # Define the Gradio interface
412
+ with gr.Blocks(theme='ParityError/Interstellar') as demo:
413
+
414
+
415
+
416
+ #with gr.Row().style(equal_height=False):
417
+ with gr.Row():
418
+ with gr.Column():
419
+ with gr.Tabs():
420
+ with gr.TabItem("Inputs"):
421
+ gr.Markdown("""## INPUTS""")
422
+ gr.Markdown("""#### Start Sequence
423
+ Specify the protein length for complete unconditional generation, or scaffold a motif (or your name) using the custom sequence input""")
424
+ seq_opt = gr.Radio(["protein length","custom sequence"], label="How would you like to specify the starting sequence?", value='protein length')
425
+
426
+ sequence = gr.Textbox(label="custom sequence", lines=1, placeholder='AMINO ACIDS: A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y\n MASK TOKEN: X', visible=False)
427
+ seq_len = gr.Slider(minimum=5.0, maximum=250.0, label="protein length", value=100, visible=True)
428
+
429
+ seq_opt.change(fn=toggle_seq_input,
430
+ inputs=[seq_opt],
431
+ outputs=[seq_len, sequence],
432
+ queue=False)
433
+
434
+ gr.Markdown("""### Optional Parameters""")
435
+ with gr.Accordion(label='Secondary Structure',open=True):
436
+ gr.Markdown("""Try changing the sliders or inputing explicit secondary structure conditioning for each residue""")
437
+ sec_str_opt = gr.Radio(["sliders","explicit"], label="How would you like to specify secondary structure?", value='sliders')
438
+
439
+ secondary_structure = gr.Textbox(label="secondary structure", lines=1, placeholder='HELIX = H STRAND = S LOOP = L MASK = X(must be the same length as input sequence)', visible=False)
440
+
441
+ with gr.Column():
442
+ helix_bias = gr.Slider(minimum=0.0, maximum=0.05, label="helix bias", visible=True)
443
+ strand_bias = gr.Slider(minimum=0.0, maximum=0.05, label="strand bias", visible=True)
444
+ loop_bias = gr.Slider(minimum=0.0, maximum=0.20, label="loop bias", visible=True)
445
+
446
+ sec_str_opt.change(fn=toggle_secondary_structure,
447
+ inputs=[sec_str_opt],
448
+ outputs=[helix_bias,strand_bias,loop_bias,secondary_structure],
449
+ queue=False)
450
+
451
+ with gr.Accordion(label='Amino Acid Compositional Bias',open=False):
452
+ gr.Markdown("""Bias sequence composition for particular amino acids by specifying the one letter code followed by the fraction to bias. This can be input as a list for example: W0.2,E0.1""")
453
+ with gr.Row():
454
+ aa_bias = gr.Textbox(label="aa bias", lines=1, placeholder='specify one letter AA and fraction to bias, for example W0.1 or M0.1,K0.1' )
455
+ aa_bias_potential = gr.Textbox(label="aa bias scale", lines=1, placeholder='AA Bias potential scale (recomended range 1.0-5.0)')
456
+
457
+ '''
458
+ with gr.Accordion(label='Charge Bias',open=False):
459
+ gr.Markdown("""Bias for a specified net charge at a particular pH using the boxes below""")
460
+ with gr.Row():
461
+ target_charge = gr.Textbox(label="net charge", lines=1, placeholder='net charge to target')
462
+ target_ph = gr.Textbox(label="pH", lines=1, placeholder='pH at which net charge is desired')
463
+ charge_potential = gr.Textbox(label="charge potential scale", lines=1, placeholder='charge potential scale (recomended range 1.0-5.0)')
464
+ '''
465
+
466
+ with gr.Accordion(label='Hydrophobic Bias',open=False):
467
+ gr.Markdown("""Bias for or against hydrophobic composition, to get more soluble proteins, bias away with a negative target score (ex. -5)""")
468
+ with gr.Row():
469
+ hydrophobic_target_score = gr.Textbox(label="hydrophobic score", lines=1, placeholder='hydrophobic score to target (negative score is good for solublility)')
470
+ hydrophobic_potential = gr.Textbox(label="hydrophobic potential scale", lines=1, placeholder='hydrophobic potential scale (recomended range 1.0-2.0)')
471
+
472
+ with gr.Accordion(label='Diffusion Params',open=False):
473
+ gr.Markdown("""Increasing T to more steps can be helpful for harder design challenges, sampling from different distributions can change the sequence and structural composition""")
474
+ with gr.Row():
475
+ num_steps = gr.Textbox(label="T", lines=1, placeholder='number of diffusion steps (25 or less will speed things up)')
476
+ noise = gr.Dropdown(['normal','gmm2 [-1,1]','gmm3 [-1,0,1]'], label='noise type', value='normal')
477
+
478
+ with gr.TabItem("Motif Selection"):
479
+
480
+ gr.Markdown("""### Motif Selection Preview""")
481
+ gr.Markdown('Contigs explained: to grab residues (seq and str) on a pdb chain you will provide the chain letter followed by a range of residues as indexed in the pdb file for example (A3-10) is the syntax to select residues 3-10 on chain A (the chain always needs to be specified). To add diffused residues to either side of this motif you can specify a range or discrete value without a chain letter infront. To add 15 residues before the motif and 20-30 residues (randomly sampled) after use the following syntax: 15,A3-10,20-30 commas are used to separate regions selected from the pdb and designed (diffused) resiudes which will be added. ')
482
+ pdb_id_code = gr.Textbox(label="PDB ID", lines=1, placeholder='INPUT PDB ID TO FETCH (ex. 1DPX)', visible=True)
483
+ contigs = gr.Textbox(label="contigs", lines=1, placeholder='specify contigs to grab particular residues from pdb ()', visible=True)
484
+ gr.Markdown('Using the same contig syntax, seq or str of input motif residues can be masked, allowing the model to hold strucutre fixed and design sequence or vice-versa')
485
+ with gr.Row():
486
+ seq_mask = gr.Textbox(label='seq mask',lines=1,placeholder='input residues to mask sequence')
487
+ str_mask = gr.Textbox(label='str mask',lines=1,placeholder='input residues to mask structure')
488
+ preview_viewer = gr.HTML()
489
+ rewrite_pdb = gr.File(label='PDB file')
490
+ preview_btn = gr.Button("Preview Motif")
491
+
492
+ with gr.TabItem("MSA to PSSM"):
493
+ gr.Markdown("""### MSA to PSSM Generation""")
494
+ gr.Markdown('input either an MSA or PSSM to guide the model toward generating samples within your family of interest')
495
+ with gr.Row():
496
+ fasta_msa = gr.File(label='MSA')
497
+ input_pssm = gr.File(label='PSSM (.csv)')
498
+ pssm = gr.File(label='Generated PSSM')
499
+ pssm_view = gr.Plot(label='PSSM Viewer')
500
+ pssm_gen_btn = gr.Button("Generate PSSM")
501
+
502
+
503
+ btn = gr.Button("GENERATE")
504
+
505
+ #with gr.Row():
506
+ with gr.Column():
507
+ gr.Markdown("""## OUTPUTS""")
508
+ gr.Markdown("""#### Confidence score for generated structure at each timestep""")
509
+ plddt_plot = gr.Plot(label='plddt at step t')
510
+ gr.Markdown("""#### Output protein sequnece""")
511
+ output_seq = gr.Textbox(label="sequence")
512
+ gr.Markdown("""#### Download PDB file""")
513
+ output_pdb = gr.File(label="PDB file")
514
+ gr.Markdown("""#### Structure viewer""")
515
+ output_viewer = gr.HTML()
516
+ '''
517
+ gr.Markdown("""### Don't know where to get started? Click on an example below to try it out!""")
518
+ gr.Examples(
519
+ [["","125",0.0,0.0,0.2,"","","","20","normal",'','','',None,'','',None],
520
+ ["","100",0.0,0.0,0.0,"","W0.2","2","20","normal",'','','',None,'','',None],
521
+ # ["","100",0.0,0.0,0.0,
522
+ # "XXHHHHHHHHHXXXXXXXHHHHHHHHHXXXXXXXHHHHHHHHXXXXSSSSSSSSSSSXXXXXXXXSSSSSSSSSSSSXXXXXXXSSSSSSSSSXXXXXXX",
523
+ # "","","25","normal",'','','',None,'','',None],
524
+ # ["XXXXXXXXXXXXXXXXXXXXXXXXXIPDXXXXXXXXXXXXXXXXXXXXXXPEPSEQXXXXXXXXXXXXXXXXXXXXXXXXXXIPDXXXXXXXXXXXXXXXXXXX",
525
+ # "",0.0,0.0,0.0,"","","","25","normal",'','','',None,'','',None],
526
+ # ["","",0.0,0.0,0.0,"","","","25","normal",'','',
527
+ # '9,D10-11,8,D20-20,4,D25-35,65,D101-101,2,D104-105,8,D114-116,15,D132-138,6,D145-145,2,D148-148,12,D161-161,3',
528
+ # './tmp/PSSM_lysozyme.csv',
529
+ # 'D25-25,D27-31,D33-35,D132-137',
530
+ # 'D26-26','./tmp/150l.pdb']
531
+ ],
532
+ inputs=[sequence,
533
+ seq_len,
534
+ helix_bias,
535
+ strand_bias,
536
+ loop_bias,
537
+ secondary_structure,
538
+ aa_bias,
539
+ aa_bias_potential,
540
+ #target_charge,
541
+ #target_ph,
542
+ #charge_potential,
543
+ num_steps,
544
+ noise,
545
+ hydrophobic_target_score,
546
+ hydrophobic_potential,
547
+ contigs,
548
+ pssm,
549
+ seq_mask,
550
+ str_mask,
551
+ rewrite_pdb],
552
+ outputs=[output_seq,
553
+ output_pdb,
554
+ output_viewer,
555
+ plddt_plot],
556
+ fn=protein_diffusion_model,
557
+ )
558
+ '''
559
+ preview_btn.click(get_motif_preview,[pdb_id_code, contigs],[preview_viewer, rewrite_pdb])
560
+
561
+ pssm_gen_btn.click(get_pssm,[fasta_msa,input_pssm],[pssm_view, pssm])
562
+
563
+ btn.click(protein_diffusion_model,
564
+ [sequence,
565
+ seq_len,
566
+ helix_bias,
567
+ strand_bias,
568
+ loop_bias,
569
+ secondary_structure,
570
+ aa_bias,
571
+ aa_bias_potential,
572
+ #target_charge,
573
+ #target_ph,
574
+ #charge_potential,
575
+ num_steps,
576
+ noise,
577
+ hydrophobic_target_score,
578
+ hydrophobic_potential,
579
+ contigs,
580
+ pssm,
581
+ seq_mask,
582
+ str_mask,
583
+ rewrite_pdb],
584
+ [output_seq,
585
+ output_pdb,
586
+ output_viewer,
587
+ plddt_plot])
588
+
589
+ demo.queue()
590
+ demo.launch(debug=True)
591
+