Spaces:
Running
Running
import html
import json
import os
from collections import Counter, defaultdict
from pathlib import Path

import gradio as gr
import seaborn as sns
import torch
from gradio_modal import Modal
from huggingface_hub import login
from transformers import AutoConfig, AutoModel, logging
# Only report errors from transformers; info/warning logging is suppressed.
logging.set_verbosity_error()

# Download (or reuse from ./cache) the METL config and model. trust_remote_code=True
# lets transformers execute the model code hosted in the 'gitter-lab/METL' repo.
# NOTE(review): metl_config is not referenced in the visible code; the UI and
# load_model read metl.config instead.
metl_config = AutoConfig.from_pretrained(
    'gitter-lab/METL',
    cache_dir='./cache',
    trust_remote_code=True,
)
metl = AutoModel.from_pretrained(
    'gitter-lab/METL',
    cache_dir='./cache',
    trust_remote_code=True,
)

# Path of the PDB file used for the molecule view and passed to METL as pdb_fn;
# None means no structure is available.
pdb_path = None
# Custom CSS passed to gr.Blocks(css=...). "#..." selectors target elem_ids set in
# the UI code below. Some "." selectors target elem_classes set there (e.g.
# .multiModalText, .selectorButton). Others (.options, .main, .modal-container)
# are not set in this file and presumably match Gradio / gradio_modal markup.
#
# "!important" must come before the semicolon. A stray "!important" after ";"
# starts an invalid declaration, and CSS parsers drop it together with everything
# up to the next ";". Previously that silently discarded "padding-left: 5px;".
RADIO_CSS = """
#indexing>div {
flex-direction: column;
width: 50px;
}
#pdbUpload {
height: 150px;
}
#modelPDBRow{
height: 150px;
}
#modelInputCol{
gap:0px;
}
#modelStatus {
margin: auto;
text-align: center;
}
.options {
height: 300px !important;
}
.multiModalText > label > div > textarea{
border-style: solid;
border-radius: var(--block-radius);
padding-left: 5px;
border-color: var(--border-color-primary);
border-width: 1px;
}
.multiModalText > label > div > button.upload-button {
margin-right: 5px;
}
#wildTypeSequence > label > div > button.upload-button {
display: none;
}
.multiModalText > label > div > button.submit-button {
display: none;
}
.main {
width: 50vw;
min-width: 700px;
margin-left: auto;
margin-right: auto;
}
#moleculeRow > div:not(.form) {
flex-grow: 90;
overflow: hidden;
}
#variantCheck > div.wrap > label {
width: 100%;
}
div:has(> button.selectorButton){
flex-direction: row;
flex-wrap: nowrap;
}
.selectorButton{
width: 100px;
min-width: 100px;
}
.selectionHint{
text-align: center;
}
#helpModal{
position: absolute;
bottom: calc(100% - 2.58rem);
width: 5rem;
left: 101%;
}
.modal-container{
width: 45rem;
}
#helpModalText{
font-size:large;
}
p {
font-size: large;
}
li {
font-size: large;
}
"""
# Hex colours for status messages rendered through get_color():
# RED for "Select Model" / failures, GREEN for a successfully loaded model.
RED, GREEN = "#DA667B", "#6B9A5F"
def _viewer_page(script):
    """Return the standalone HTML page hosting the 3Dmol viewer, with *script* appended to <body>."""
    return (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
overflow: hidden;
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
"""
        + script
        + """
</body>
</html>"""
    )


def _residue_style_code(variants):
    """Return 3Dmol.js statements that colour every residue mutated by *variants*.

    Each variant is a comma-separated list of mutations written
    <wt residue><position><mutant residue>, e.g. "T17P,T54F".
    A residue touched by exactly one mutation gets its variant's palette colour.
    A residue touched more than once (within or across variants) gets the
    shared duplicate colour #570606 (dark red).

    Raises:
        TypeError: *variants* is not a list/tuple.
        AttributeError: a variant is not a string.
        ValueError: a position is not an integer. Positions go through int(),
            so only numbers are ever interpolated into the generated JavaScript.
    """
    duplicate_color = "#570606"
    if not isinstance(variants, (list, tuple)):
        raise TypeError("variants must be a list of variant strings")

    if len(variants) <= 9:
        cmap = sns.color_palette('colorblind').as_hex()
        del cmap[-3]  # Doesn't show up well on the molecule
    else:
        # No colorblind-friendly palette past 9 items: one husl colour per variant.
        cmap = sns.color_palette('husl', len(variants)).as_hex()

    # (variant index, residue position) for every mutation, in input order.
    mutations = [
        (index, int(mutation.strip()[1:-1]))
        for index, variant in enumerate(variants)
        for mutation in variant.split(',')
    ]
    counts = Counter(residue for _, residue in mutations)

    colors = {}
    for index, residue in mutations:
        colors[residue] = duplicate_color if counts[residue] > 1 else cmap[index]

    return "".join(
        f'viewer.getModel(0).setStyle({{resi:[{residue}]}}, {{cartoon:{{color:"{color}"}}}});\n'
        for residue, color in colors.items()
    )


def generate_iframe_html(variants):
    """Build a standalone HTML page that renders the current PDB with 3Dmol.js.

    Reads the module-level ``pdb_path``. When it is None, the page contains
    only an empty viewer container and no script.

    Args:
        variants: None (no highlighting), a list of variant strings such as
            ``["T17P,T54F", "V28L,F51A"]``, or that list encoded as a JSON string.

    Returns:
        A complete HTML document string, or a short HTML error message when
        the variants cannot be parsed.
    """
    invalid_variants = '<span style="color: white;">The variants given were not in a valid JSON list format</span>'
    if pdb_path is None:
        return _viewer_page("")

    with open(pdb_path, 'r') as f:
        mol = f.read()

    residue_code = ""
    if variants is not None:
        if isinstance(variants, str):
            try:
                variants = json.loads(variants)
            except ValueError:
                return invalid_variants
        try:
            residue_code = _residue_style_code(variants)
        except (TypeError, ValueError, AttributeError):
            return invalid_variants

    # json.dumps yields a valid JS string literal no matter what the file holds
    # (backticks, quotes, "${"). Escaping "</" stops a literal "</script>" in the
    # data from closing the script element early.
    pdb_literal = json.dumps(mol).replace("</", "<\\/")
    script = (
        "\n<script>\n"
        f"let pdb = {pdb_literal};\n"
        "$(document).ready(function () {\n"
        '    let element = $("#container");\n'
        '    let config = { backgroundColor: "white" };\n'
        "    let viewer = $3Dmol.createViewer(element, config);\n"
        '    viewer.addModel(pdb, "pdb");\n'
        '    viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });\n'
        + residue_code
        + "    viewer.render();\n"
        "    viewer.zoomTo();\n"
        "})\n"
        "</script>\n"
    )
    return _viewer_page(script)
def get_iframe(variants):
    """Wrap the 3Dmol viewer page for *variants* in a sandboxed <iframe>.

    The page from generate_iframe_html(variants) is HTML-escaped (quotes
    included) before it goes into the single-quoted ``srcdoc`` attribute.
    Unescaped, any single quote in the PDB text ends the attribute and
    truncates the viewer. Browsers unescape the srcdoc value before rendering,
    so the page itself is unchanged.
    """
    page = html.escape(generate_iframe_html(variants), quote=True)
    return f"""<iframe id="molecule" style="width: 100%; height: 600px;" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{page}'></iframe>"""
def to_zero_based(variants):
    """Convert lines of 1-based JSON variant lists to 0-based residue positions.

    Args:
        variants: iterable of strings, each a JSON list of variants such as
            '["T17P,T54F", "V28L,F51A"]'. A variant is a comma-separated list
            of mutations <wt residue><1-based position><mutant residue>.
            Whitespace around individual mutations is ignored.

    Returns:
        One list per input line with every position decremented by one,
        e.g. [["T16P,T53F", "V27L,F50A"]].

    Raises:
        ValueError: a line is not valid JSON, or a position is not an integer.
        TypeError / AttributeError: a line does not decode to a list of strings.
    """
    def shift(mutation):
        # "T17P" -> "T16P"; strip so "T17P, T54F" style input is accepted.
        mutation = mutation.strip()
        position = int(mutation[1:-1]) - 1
        return f"{mutation[0]}{position}{mutation[-1]}"

    zero_based = []
    for line in variants:
        zero_based.append([
            ",".join(shift(mutation) for mutation in variant.split(','))
            for variant in json.loads(line)
        ])
    return zero_based
def get_lines_from_multimodal(modal_output):
    """Collect input lines from a MultimodalTextbox value.

    Args:
        modal_output: dict with "text" (typed string) and "files" (list of
            uploaded file paths).

    Returns:
        [] when there is neither text nor a file. [text] when no file was
        uploaded. Otherwise, the stripped, non-blank lines of the first
        uploaded file, followed by the typed text if any. Additional files are
        ignored.
    """
    text = modal_output['text']
    files = modal_output['files']
    if not files:
        return [text] if text else []

    # Context manager so the upload is closed promptly.
    with open(files[0], 'r', encoding='utf-8') as handle:
        stripped = [line.strip() for line in handle]
    # Blank lines would otherwise surface as "" entries that fail JSON parsing.
    lines = [line for line in stripped if line]
    if text:
        lines.append(text)
    return lines
def get_color(color, model_text):
    """Return *model_text* wrapped in an HTML <span> whose text colour is *color*."""
    style = "color: {};".format(color)
    return '<span style="{}">{}</span>'.format(style, model_text)
def get_file(filepath: str):
    """Remember an uploaded PDB file and return the viewer iframe for it.

    Stores *filepath* in the module-level ``pdb_path``, which later renders
    and predict() read. Prints the path (debug output) and renders the
    molecule with no variants highlighted.
    """
    global pdb_path
    pdb_path = filepath
    print(filepath)
    return get_iframe(None)
def empty_pdb_path(button):
    """Forget the current PDB file and return an empty HTML string.

    Args:
        button: unused; presumably the triggering component's value when this
            is wired as a Gradio event handler.
    """
    global pdb_path
    pdb_path = None
    cleared_html = ""
    return cleared_html
def load_model(model_id, _):
    """Switch the global METL model to *model_id*.

    Args:
        model_id: an IDENT, matched case-insensitively against the keys of
            ``metl.config.IDENT_UUID_MAP``, or a UUID, matched exactly against
            ``metl.config.UUID_URL_MAP``. Non-string values (e.g. nothing
            selected) are rejected.
        _: unused (the dropdown event passes the dropdown value twice).

    Returns:
        (status_html, run_button_update). The run button is enabled only when
        a model was loaded.
    """
    global metl
    if not isinstance(model_id, str):
        return get_color(RED, "Select Model"), gr.Button(interactive=False)

    config = metl.config
    if model_id.lower() in config.IDENT_UUID_MAP:
        metl.load_from_ident(model_id)
    elif model_id in config.UUID_URL_MAP:
        metl.load_from_uuid(model_id)
    else:
        return get_color(RED, "Model Load Failed"), gr.Button(interactive=False)

    return get_color(GREEN, f"{model_id} loaded"), gr.Button(interactive=True)
def update_pdb(variant_modal, indexing):
    """Refresh the variant checkboxes and the molecule view from the variants input.

    Only the first input line is shown: the typed text, or the first line of
    an uploaded file. With ``indexing == 1``, positions are converted to
    0-based via to_zero_based before display.

    Args:
        variant_modal: MultimodalTextbox value ({"text": ..., "files": [...]}).
        indexing: 1 for 1-based input positions, 0 for 0-based.

    Returns:
        (checkbox_update, molecule_html). There are always two values, matching
        the two outputs (variant CheckboxGroup, molecule HTML) this handler
        feeds.
    """
    lines = get_lines_from_multimodal(modal_output=variant_modal)
    if not lines:
        # Nothing to show: a bare CheckboxGroup() sets no properties, and the
        # structure (if any) is drawn without highlighted residues.
        return gr.CheckboxGroup(), get_iframe(variants=None)

    try:
        if indexing == 1:
            variants = to_zero_based(lines)[0]
        else:
            variants = json.loads(lines[0])
    except (ValueError, TypeError, AttributeError):
        variants = None
    if not isinstance(variants, list):
        return gr.CheckboxGroup(), get_color(RED, "The variants given were not in a valid JSON list format")

    return gr.CheckboxGroup(choices=variants, value=variants, visible=True), get_iframe(variants=variants)
def select_or_deselect_all(button_name, choices, variant_modal, indexing):
    """Handle the select-all / deselect-all buttons.

    The action follows the clicked button's label. A label containing "De"
    unticks every variant (keeping *choices*) and shows the molecule without
    highlights. Any other label re-reads the variants input and ticks and
    highlights all of them via update_pdb.

    Returns:
        (checkbox_update, molecule_html).
    """
    if "De" in button_name:
        return gr.CheckboxGroup(visible=True, choices=choices, value=[]), get_iframe(variants=None)
    # update_pdb does the parsing (and 1-based -> 0-based conversion) itself. The
    # unused duplicate parse that used to run here raised IndexError on empty input.
    return update_pdb(variant_modal=variant_modal, indexing=indexing)
def hide_variants(checkbox_values):
    """Redraw the molecule highlighting only the variants still ticked.

    Args:
        checkbox_values: the variant CheckboxGroup's current selection.

    Returns:
        The viewer <iframe> HTML from get_iframe.
    """
    selected = checkbox_values
    return get_iframe(selected)
def populate_example():
    """Fill the inputs with the bundled GB1 example and render it.

    Loads model ``metl-l-2m-3d-gb1`` and points the global ``pdb_path`` at
    ``./2qmt_p.pdb`` (relative to the working directory). Returns, in order:
    the model id, status HTML, run-button update, wild-type textbox value,
    variants textbox value, molecule HTML, the PDB path, the variant checkbox
    update, and two Button updates making the selector buttons visible.
    """
    global pdb_path
    example_model = "metl-l-2m-3d-gb1"
    example_wt = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
    example_variants = '["T17P,T54F", "V28L,F51A"]'
    # "T17P,V28L,F51A,T54F"
    pdb_path = './2qmt_p.pdb'

    status, pred_button = load_model(example_model, None)
    wt_value = {"text": example_wt, "files": []}
    variants_value = {"text": example_variants, "files": []}
    checkbox, iframe = update_pdb(variant_modal=variants_value, indexing=0)

    show_selectors = (gr.Button(visible=True), gr.Button(visible=True))
    return (example_model, status, pred_button, wt_value, variants_value,
            iframe, pdb_path, checkbox) + show_selectors
def hide_mutation():
    """Placeholder with no behaviour; not referenced in the visible code. Returns None."""
    return None
def predict(input_multi_modal, variant_multi_modal, variant_index_type):
    """Score variants of the wild-type sequence with the loaded METL model.

    Args:
        input_multi_modal: MultimodalTextbox value holding the wild-type
            sequence. Only the first line is used, with surrounding whitespace
            stripped.
        variant_multi_modal: MultimodalTextbox value holding one JSON list of
            variants per line. At most the first 100 lines are scored.
        variant_index_type: 1 if mutation positions are 1-based (converted to
            0-based before encoding), otherwise 0.

    Returns:
        (output_text, molecule_html, checkbox_update), matching the METL output
        box, molecule view and variant CheckboxGroup. output_text is one of:
        - ``str(predictions.tolist())`` for a single line;
        - a JSON array of {"wt", "variants", "logits"} records for several lines;
        - an error message.
    """
    max_predictions = 100
    input_sequences = get_lines_from_multimodal(input_multi_modal)
    variant_lines = get_lines_from_multimodal(variant_multi_modal)
    if not input_sequences or not variant_lines:
        err_out = "Invalid input. "
        if not input_sequences:
            err_out += "Input sequences were not given, but a wild type must be given."
        if not variant_lines:
            err_out += "Mutations were not given, but mutations must be given in JSON array format to predict with METL"
        # The third output is the variant CheckboxGroup (not a Button); a bare
        # CheckboxGroup() sets no properties on it.
        return err_out, get_iframe(None), gr.CheckboxGroup()

    try:
        if variant_index_type == 1:
            variants = to_zero_based(variant_lines)
        else:
            variants = [json.loads(line) for line in variant_lines]
    except (ValueError, TypeError, AttributeError):
        err_out = "One or more of the mutations given were not in a valid JSON list format"
        return err_out, get_iframe(None), gr.CheckboxGroup()

    metl.eval()
    sequence = input_sequences[0].strip()
    outputs = []
    for variant in variants[:max_predictions]:
        encoded_variants = metl.encoder.encode_variants(sequence, variant)
        with torch.no_grad():
            if pdb_path is not None:
                predictions = metl(torch.tensor(encoded_variants), pdb_fn=pdb_path)
            else:
                predictions = metl(torch.tensor(encoded_variants))
        outputs.append({
            "wt": sequence,
            "variants": variant,
            "logits": predictions.tolist(),
        })

    out_str = json.dumps(outputs) if len(outputs) > 1 else str(outputs[0]['logits'])
    # Display the first line of variants. They are already 0-based here, so
    # indexing=0 keeps update_pdb from subtracting 1 a second time.
    variants_dict = {"text": json.dumps(variants[0]), "files": []}
    checkbox, iframe = update_pdb(variant_modal=variants_dict, indexing=0)
    return out_str, iframe, checkbox
with gr.Blocks(css=RADIO_CSS) as demo:
    with gr.Row(equal_height=True, elem_id="modelPDBRow"):
        with gr.Column(elem_id="modelInputCol"):
            metl_model_id = gr.Dropdown(
                label="METL model IDENT or UUID",
                choices=list(metl.config.IDENT_UUID_MAP.keys()),
                allow_custom_value=False,
            )
            metl_model_status = gr.HTML(get_color(RED, "Select Model"), elem_id="modelStatus")
        with gr.Column():
            upload_pdb = gr.File(label="PDB File upload", elem_id="pdbUpload", file_types=[".pdb", ".txt"])
    with gr.Column():
        metl_seq_input = gr.MultimodalTextbox(
            label="Input Protein Sequence",
            interactive=True,
            elem_classes="multiModalText",
            elem_id="wildTypeSequence",
        )
        with gr.Row():
            metl_variants = gr.MultimodalTextbox(
                label="JSON variant list",
                scale=100,
                interactive=True,
                elem_classes="multiModalText",
                file_types=[".json", ".txt"],
            )
            # elem_id must be a string; the "#indexing>div" CSS rule targets it.
            variant_indexing = gr.Radio(choices=[0, 1], elem_id="indexing", min_width=80, label='Indexing', value=0)
            metl_update_pdb_display = gr.Button("Update PDB display", min_width=100)
        metl_output = gr.TextArea(label="Output from METL", interactive=False, show_copy_button=True)
        help_text = gr.Markdown("Load a model and if necessary, upload a pdb file to get started. File inputs for mutations must be 1 JSON list per line. When a file is uploaded with multiple mutations, the first mutation will be displayed on the molecule.")
        metl_run_button = gr.Button("Run Prediction", interactive=False)
        metl_load_example = gr.Button("Load Example")
    with gr.Row(elem_id="moleculeRow"):
        molecule = gr.HTML()
        with gr.Column(scale=10):
            with gr.Row():
                # NOTE: select_all is labelled "Deselect all" and deselect_all
                # "Select all". select_or_deselect_all dispatches on the label,
                # so each button does what its label says.
                select_all = gr.Button(elem_classes=["selectorButton"], value="Deselect all", visible=False)
                deselect_all = gr.Button(elem_classes=["selectorButton"], value="Select all", visible=False)
            show_hide_variants = gr.CheckboxGroup(show_label=False, visible=False, elem_id="variantCheck")
    with Modal(visible=False) as help_modal:
        # Kept flush-left so Markdown does not treat the list as a code block.
        modal_text = """
This is a demo for [METL](https://huggingface.co/gitter-lab/METL). The supported METL models are listed in the dropdown.
The specifics of each of these models may be found in the above 🤗 link or [here](https://github.com/gitter-lab/metl?tab=readme-ov-file).

To run this demo, follow these steps:

1. select a model in the provided dropdown.
2. upload a pdb file if it is required for your prediction.
3. paste in your wild type sequence
4. paste in your mutations in JSON list format, where each mutation in the list is a CSV string separated by double quotes (").
   - an example is provided when the "Load Example" button is pressed.
   - if a PDB is given, and mutations are in the corresponding text box then the update button may be used to display those mutations in an interactive molecule display using 3Dmol.js.
5. Press "Run Prediction"

For cases where many combinations of mutations are given, uploading a text file where each line is a new JSON list (as described above) will allow up to 100 different METL predictions!
"""
        gr.Markdown(modal_text, elem_id="helpModalText")
    help_alert = gr.Button("Help!", elem_id="helpModal")
    help_alert.click(lambda: Modal(visible=True), None, help_modal)

    ## Model and PDB event handlers
    metl_model_id.input(fn=load_model, inputs=[metl_model_id, metl_model_id], outputs=[metl_model_status, metl_run_button])
    # Clearing the upload must also forget pdb_path; otherwise predictions and
    # renders keep using the removed file.
    upload_pdb.clear(fn=empty_pdb_path, inputs=upload_pdb, outputs=molecule, show_progress=False)
    upload_pdb.upload(fn=get_file, inputs=upload_pdb, outputs=molecule, show_progress=False)
    metl_update_pdb_display.click(fn=update_pdb, inputs=[metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)

    ## Predicting event handlers
    metl_run_button.click(fn=predict, inputs=[metl_seq_input, metl_variants, variant_indexing], outputs=[metl_output, molecule, show_hide_variants], show_progress=False)

    ## Load example
    metl_load_example.click(
        fn=populate_example,
        outputs=[metl_model_id, metl_model_status, metl_run_button, metl_seq_input,
                 metl_variants, molecule, upload_pdb, show_hide_variants,
                 select_all, deselect_all],
        show_progress=False,
    )

    ## Event handlers for the molecule display
    select_all.click(fn=select_or_deselect_all, inputs=[select_all, show_hide_variants, metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)
    deselect_all.click(fn=select_or_deselect_all, inputs=[deselect_all, show_hide_variants, metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)
    show_hide_variants.input(fn=hide_variants, inputs=show_hide_variants, outputs=molecule, show_progress=False)

if __name__ == "__main__":
    demo.launch()
# test model: METL-L-2M-3D-GB1 | |
# test wild type: MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE | |
# test variants: ["T17P,T54F", "V28L,F51A", "T17P,V28L,F51A,T54F"] | |