# METL_demo / app.py
# Hugging Face Space source (last commit c9b4304 "Update app.py" by jgpeters; 17 kB).
import gradio as gr
from transformers import AutoModel, AutoConfig, logging
from huggingface_hub import login
import os
import json
import torch
from pathlib import Path
import seaborn as sns
from collections import defaultdict
from gradio_modal import Modal
# Only report errors from the transformers library logger.
logging.set_verbosity_error()
# Fetch (or reuse from ./cache) the METL config and model from the Hugging Face Hub.
# trust_remote_code=True runs the model code published in the gitter-lab/METL repository.
# NOTE(review): metl_config is not referenced elsewhere in the visible source; the app reads metl.config instead.
metl_config = AutoConfig.from_pretrained('gitter-lab/METL', trust_remote_code=True, cache_dir='./cache')
# Single shared model object; load_model() calls metl.load_from_ident / metl.load_from_uuid on it.
metl = AutoModel.from_pretrained('gitter-lab/METL', trust_remote_code=True, cache_dir='./cache')
# Path of the PDB structure currently in use, or None when no structure is loaded.
# Written by get_file(), empty_pdb_path() and populate_example(); read by the viewer and predict().
pdb_path = None
RADIO_CSS = """
#indexing>div {
flex-direction: column;
width: 50px;
}
#pdbUpload {
height: 150px;
}
#modelPDBRow{
height: 150px;
}
#modelInputCol{
gap:0px;
}
#modelStatus {
margin: auto;
text-align: center;
}
.options {
height: 300px; !important
}
.multiModalText > label > div > textarea{
border-style: solid;
border-radius: var(--block-radius);
border-color: red; !important
padding-left: 5px;
border-color: var(--border-color-primary);
border-width: 1px;
}
.multiModalText > label > div > button.upload-button {
margin-right: 5px;
}
#wildTypeSequence > label > div > button.upload-button {
display: none;
}
.multiModalText > label > div > button.submit-button {
display: none;
}
.main {
width: 50vw;
min-width: 700px;
margin-left: auto;
margin-right: auto;
}
#moleculeRow > div:not(.form) {
flex-grow: 90;
overflow: hidden;
}
#variantCheck > div.wrap > label {
width: 100%;
}
div:has(> button.selectorButton){
flex-direction: row;
flex-wrap: nowrap;
}
.selectorButton{
width: 100px;
min-width: 100px;
}
.selectionHint{
text-align: center;
}
#helpModal{
position: absolute;
bottom: calc(100% - 2.58rem);
width: 5rem;
left: 101%;
}
.modal-container{
width: 45rem;
}
#helpModalText{
font-size:large;
}
p {
font-size: large;
}
li {
font-size: large;
}
"""
RED = "#DA667B"
GREEN = "#6B9A5F"
def _residue_number(mutation):
    """Return the residue number of a mutation such as ``"T17P"`` as ``"17"``.

    Surrounding whitespace is ignored. The number is round-tripped through
    ``int`` so only a plain integer literal is ever pasted into the viewer
    script; anything else raises ``ValueError``.
    """
    return str(int(mutation.strip()[1:-1]))


def _residue_style_code(variants):
    """Return 3Dmol.js statements that colour every residue mutated by ``variants``.

    Each variant gets its own palette colour. A residue mutated more than once
    (within one variant or across variants) is coloured dark red (``#570606``)
    instead. Empty mutation entries, such as those produced by a trailing
    comma, are ignored.

    Raises:
        ValueError: if a mutation does not contain an integer position.
    """
    if len(variants) <= 9:
        cmap = sns.color_palette('colorblind').as_hex()
        del cmap[-3]  # Doesn't show up well on the molecule
    else:
        # No colorblind-safe palette past 9 variants; fall back to evenly spaced hues.
        cmap = sns.color_palette('husl', len(variants)).as_hex()
    residues_per_variant = [
        [_residue_number(mutation) for mutation in variant.split(',') if mutation.strip()]
        for variant in variants
    ]
    occurrences = defaultdict(int)
    for residues in residues_per_variant:
        for residue in residues:
            occurrences[residue] += 1
    statements = []
    shared_done = set()
    for index, residues in enumerate(residues_per_variant):
        for residue in residues:
            if occurrences[residue] == 1:
                color = cmap[index]
            elif residue in shared_done:
                continue  # shared residues only need to be coloured once
            else:
                shared_done.add(residue)
                color = "#570606"
            statements.append(
                'viewer.getModel(0).setStyle({resi:[' + residue + ']}, '
                '{cartoon:{color:"' + color + '"}});\n'
            )
    return "".join(statements)


def generate_iframe_html(variants):
    """Build a standalone HTML page that renders the current PDB with 3Dmol.js.

    The structure is read from the module-level ``pdb_path``. When it is
    ``None``, the page only contains an empty viewer container. Otherwise, the
    PDB text is embedded in an inline script, and every residue mutated by
    ``variants`` is recoloured (see ``_residue_style_code``).

    Args:
        variants: ``None``; a list of variant strings, each a comma-separated
            list of mutations such as ``"T17P,T54F"``; or a JSON string
            encoding such a list.

    Returns:
        The HTML document as a string. When a PDB is loaded and ``variants``
        cannot be parsed, a short error ``<span>`` is returned instead.
    """
    global pdb_path
    invalid_json = '<span style="color: white;">The variants given were not in a valid JSON list format</span>'
    script = ""
    if pdb_path is not None:
        with open(pdb_path, 'r') as f:
            mol = f.read()
        residue_code = ""
        if variants is not None:
            if isinstance(variants, str):
                try:
                    variants = json.loads(variants)
                except ValueError:
                    return invalid_json
            if not isinstance(variants, (list, tuple)) or not all(
                isinstance(variant, str) for variant in variants
            ):
                return invalid_json
            try:
                residue_code = _residue_style_code(variants)
            except ValueError:
                return ('<span style="color: white;">One or more mutations '
                        'did not contain an integer residue position</span>')
        # json.dumps yields a valid JavaScript string literal, so backticks,
        # quotes and "${" in the PDB text cannot break the script. Escaping
        # "</" stops a stray "</script>" from closing the element early.
        pdb_literal = json.dumps(mol).replace("</", "<\\/")
        script = (
            "<script>\n"
            "let pdb = " + pdb_literal + ";\n"
            "$(document).ready(function () {\n"
            '    let element = $("#container");\n'
            '    let config = { backgroundColor: "white" };\n'
            "    let viewer = $3Dmol.createViewer(element, config);\n"
            '    viewer.addModel(pdb, "pdb");\n'
            '    viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });\n'
            + residue_code +
            "    viewer.render();\n"
            "    viewer.zoomTo();\n"
            "})\n"
            "</script>\n"
        )
    return (
        "<!DOCTYPE html>\n"
        "<html>\n"
        "<head>\n"
        '<meta http-equiv="content-type" content="text/html; charset=UTF-8" />\n'
        "<style>\n"
        "body {\n"
        "    overflow: hidden;\n"
        "}\n"
        ".mol-container {\n"
        "    width: 100%;\n"
        "    height: 600px;\n"
        "    position: relative;\n"
        "}\n"
        ".mol-container select {\n"
        "    background-image: none;\n"
        "}\n"
        "</style>\n"
        '<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" '
        'integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" '
        'crossorigin="anonymous" referrerpolicy="no-referrer"></script>\n'
        '<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>\n'
        "</head>\n"
        "<body>\n"
        '<div id="container" class="mol-container"></div>\n'
        + script +
        "</body>\n"
        "</html>"
    )
def get_iframe(variants):
    """Wrap the 3Dmol viewer page for ``variants`` in a sandboxed ``<iframe>``.

    The page is delivered through the ``srcdoc`` attribute. It is HTML-escaped
    first, so quote characters inside it (for example the ``'`` in PDB atom
    names such as ``O5'``) cannot terminate the attribute early. Browsers
    decode the entities again before rendering, so the iframe receives the
    original page.

    Args:
        variants: forwarded unchanged to ``generate_iframe_html``.

    Returns:
        The ``<iframe>`` element as an HTML string.
    """
    from html import escape

    page = generate_iframe_html(variants)
    return (
        '<iframe id="molecule" style="width: 100%; height: 600px;" name="result" '
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" '
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups '
        'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" '
        'allowpaymentrequest="" frameborder="0" '
        "srcdoc='" + escape(page, quote=True) + "'></iframe>"
    )
def to_zero_based(variants):
    """Convert lines of 1-based variants to 0-based residue numbering.

    Args:
        variants: lines of text, each a JSON list of variant strings. A variant
            is a comma-separated list of mutations such as ``"T17P"`` (original
            residue, position, new residue). Whitespace around each mutation is
            ignored, so ``"T17P, T54F"`` is accepted.

    Returns:
        One list per input line with every position decremented by one, e.g.
        ``['["T17P,T54F"]']`` -> ``[["T16P,T53F"]]``.

    Raises:
        ValueError: if a line is not valid JSON or a position is not an integer.
    """
    def shift(mutation):
        mutation = mutation.strip()
        position = int(mutation[1:-1]) - 1
        return f"{mutation[0]}{position}{mutation[-1]}"

    return [
        [",".join(shift(mutation) for mutation in variant.split(',')) for variant in json.loads(line)]
        for line in variants
    ]
def get_lines_from_multimodal(modal_output):
    """Collect the input lines from a Gradio MultimodalTextbox value.

    Args:
        modal_output: dict with ``"text"`` (typed text) and ``"files"`` (list of
            uploaded file paths).

    Returns:
        ``[]`` when there is neither text nor a file. With text only, ``[text]``.
        With a file, the stripped non-blank lines of the first file, followed
        by the typed text if any.
    """
    text = modal_output['text']
    files = modal_output['files']
    if not files:
        return [text] if text else []
    # Close the upload promptly; blank lines (e.g. trailing newlines) are not
    # variant lines and would otherwise fail JSON parsing downstream.
    with open(files[0], 'r', encoding='utf-8') as handle:
        lines = [line.strip() for line in handle]
    lines = [line for line in lines if line]
    if text:
        lines.append(text)
    return lines
def get_color(color, model_text):
    """Return ``model_text`` wrapped in a ``<span>`` whose CSS colour is ``color``."""
    template = '<span style="color: {};">{}</span>'
    return template.format(color, model_text)
def get_file(filepath: str):
    """Remember an uploaded PDB file and return a viewer showing it.

    Stores ``filepath`` in the module-level ``pdb_path``, echoes it to stdout,
    and returns the iframe HTML with no variants highlighted.
    """
    global pdb_path
    pdb_path = filepath
    print(filepath)
    return get_iframe(variants=None)
def empty_pdb_path(button):
    """Forget the current PDB file and return an empty string.

    ``button`` only exists to satisfy the event-handler signature; it is ignored.
    """
    global pdb_path
    pdb_path = None
    cleared_html = ""
    return cleared_html
def load_model(model_id, _):
    """Load the METL weights identified by ``model_id`` into the shared model.

    ``model_id`` is either an IDENT or a UUID. An IDENT is recognised when its
    lower-cased form is a key of ``metl.config.IDENT_UUID_MAP``. A UUID is
    recognised when it is a key of ``metl.config.UUID_URL_MAP``. The second
    argument is ignored.

    Returns:
        ``(status_html, run_button_update)``. The run button is made
        interactive only when a loader was found and returned.
    """
    global metl
    if not isinstance(model_id, str):
        return get_color(RED, "Select Model"), gr.Button(interactive=False)
    config = metl.config
    if model_id.lower() in config.IDENT_UUID_MAP:
        loader = metl.load_from_ident
    elif model_id in config.UUID_URL_MAP:
        loader = metl.load_from_uuid
    else:
        return get_color(RED, "Model Load Failed"), gr.Button(interactive=False)
    loader(model_id)
    return get_color(GREEN, f"{model_id} loaded"), gr.Button(interactive=True)
def update_pdb(variant_modal, indexing):
    """Refresh the variant checkboxes and the molecule view from the variant box.

    Args:
        variant_modal: MultimodalTextbox value (``{"text": ..., "files": [...]}``)
            whose first line is a JSON list of variants.
        indexing: ``1`` when positions are 1-based and must be shifted to
            0-based; any other value uses the positions as given.

    Returns:
        ``(checkbox_update, molecule_html)``. Both values are always returned,
        because the event handlers wire this function to two output components.
    """
    lines = get_lines_from_multimodal(modal_output=variant_modal)
    if not lines:
        # Nothing to show: return no-op updates for both outputs.
        return gr.CheckboxGroup(), gr.HTML()
    invalid = get_color(RED, "The variants given were not in a valid JSON list format")
    try:
        if indexing == 1:
            variants = to_zero_based(lines)[0]
        else:
            variants = json.loads(lines[0])
    except ValueError:
        return gr.CheckboxGroup(), invalid
    if not isinstance(variants, list):
        return gr.CheckboxGroup(), invalid
    return gr.CheckboxGroup(choices=variants, value=variants, visible=True), get_iframe(variants=variants)
def select_or_deselect_all(button_name, choices, variant_modal, indexing):
    """Tick or untick every variant checkbox and redraw the molecule to match.

    Args:
        button_name: label of the clicked button. Labels containing "De"
            (e.g. "Deselect all") clear the selection; anything else re-selects
            every variant.
        choices: the CheckboxGroup's current *value*, i.e. only the ticked
            variants, because Gradio passes a CheckboxGroup's value as an input.
            It is not used, because it cannot restore unticked variants.
        variant_modal: variant MultimodalTextbox value, forwarded to
            ``update_pdb`` to rebuild the full variant list.
        indexing: forwarded to ``update_pdb``.

    Returns:
        ``(checkbox_update, molecule_html)``.
    """
    if "De" in button_name:
        # Only clear the selection. Re-setting choices from the ticked values
        # would silently drop the unticked variants from the list.
        return gr.CheckboxGroup(visible=True, value=[]), get_iframe(variants=None)
    return update_pdb(variant_modal=variant_modal, indexing=indexing)
def hide_variants(checkbox_values):
    """Redraw the molecule so that only the ticked variants are highlighted."""
    ticked = checkbox_values
    iframe_html = get_iframe(variants=ticked)
    return iframe_html
def populate_example():
    """Fill every input with the GB1 example and return the matching UI updates.

    Side effects: points the module-level ``pdb_path`` at ``./2qmt_p.pdb`` and
    loads the ``metl-l-2m-3d-gb1`` model through ``load_model``.

    Returns:
        A 10-tuple for the "Load Example" outputs:
        - model id, status HTML, run-button update;
        - wild-type box value, variants box value;
        - molecule iframe, PDB path, checkbox update;
        - visibility updates for the two selector buttons.
    """
    global pdb_path
    example_model = "metl-l-2m-3d-gb1"
    wild_type = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
    # The combined variant "T17P,V28L,F51A,T54F" is another example that could be added here.
    example_variants = '["T17P,T54F", "V28L,F51A"]'
    pdb_path = './2qmt_p.pdb'
    status, pred_button = load_model(example_model, None)
    wt_dict = dict(text=wild_type, files=[])
    variants_dict = dict(text=example_variants, files=[])
    default_checkbox, iframe = update_pdb(variant_modal=variants_dict, indexing=0)
    selector_buttons = [gr.Button(visible=True) for _ in range(2)]
    return (
        example_model,
        status,
        pred_button,
        wt_dict,
        variants_dict,
        iframe,
        pdb_path,
        default_checkbox,
        *selector_buttons,
    )
def hide_mutation():
    """Placeholder handler: does nothing and returns ``None``."""
    return None
def predict(input_multi_modal, variant_multi_modal, variant_index_type):
    """Score the given variants of a wild-type sequence with the loaded METL model.

    Args:
        input_multi_modal: MultimodalTextbox value holding the wild-type
            sequence; only its first line is used.
        variant_multi_modal: MultimodalTextbox value whose lines are JSON lists
            of variants (comma-separated mutations such as ``"T17P,T54F"``).
        variant_index_type: ``1`` if mutation positions are 1-based; otherwise
            they are used as given.

    Returns:
        ``(output_text, molecule_html, checkbox_update)``, matching the outputs
        wired to the "Run Prediction" button. At most the first 100 variant
        lines are scored. The PDB path is passed to the model only when one
        is set.
    """
    max_predictions = 100
    input_sequences = get_lines_from_multimodal(input_multi_modal)
    variants = get_lines_from_multimodal(variant_multi_modal)
    if not input_sequences or not variants:
        err_out = "Invalid input. "
        if not input_sequences:
            err_out += "Input sequences were not given, but a wild type must be given."
        if not variants:
            err_out += "Mutations were not given, but mutations must be given in JSON array format to predict with METL"
        # The third output is the variant CheckboxGroup, so return a checkbox
        # update here (clearing the selection to match the unhighlighted view).
        return err_out, get_iframe(None), gr.CheckboxGroup(value=[])
    try:
        if variant_index_type == 1:
            variants = to_zero_based(variants)
        else:
            variants = [json.loads(variant) for variant in variants]
    except (ValueError, TypeError, AttributeError):
        err_out = "One or more of the mutations given were not in a valid JSON list format"
        return err_out, get_iframe(None), gr.CheckboxGroup(value=[])
    metl.eval()
    sequence = input_sequences[0]
    outputs = []
    for variant in variants[:max_predictions]:
        encoded_variants = metl.encoder.encode_variants(sequence, variant)
        with torch.no_grad():
            if pdb_path is not None:
                predictions = metl(torch.tensor(encoded_variants), pdb_fn=pdb_path)
            else:
                predictions = metl(torch.tensor(encoded_variants))
        outputs.append({
            "wt": sequence,
            "variants": variant,
            "logits": predictions.tolist()
        })
    out_str = json.dumps(outputs) if len(outputs) > 1 else str(outputs[0]['logits'])
    variants_dict = {
        "text": json.dumps(variants[0]),
        "files": []
    }
    # Positions were already converted above when variant_index_type == 1, so
    # indexing=0 stops update_pdb from subtracting 1 a second time.
    checkbox, iframe = update_pdb(variant_modal=variants_dict, indexing=0)
    return out_str, iframe, checkbox
# Gradio UI: layout, event wiring and launch.
with gr.Blocks(css=RADIO_CSS) as demo:
    with gr.Row(equal_height=True, elem_id="modelPDBRow"):
        with gr.Column(elem_id="modelInputCol"):
            metl_model_id = gr.Dropdown(label="METL model IDENT or UUID", choices=list(metl.config.IDENT_UUID_MAP.keys()), allow_custom_value=False)
            metl_model_status = gr.HTML(get_color(RED, "Select Model"), elem_id="modelStatus")
        with gr.Column():
            upload_pdb = gr.File(label="PDB File upload", elem_id="pdbUpload", file_types=[".pdb", ".txt"])
    with gr.Column():
        metl_seq_input = gr.MultimodalTextbox(label="Input Protein Sequence", interactive=True, elem_classes="multiModalText", elem_id="wildTypeSequence")
        with gr.Row():
            metl_variants = gr.MultimodalTextbox(label="JSON variant list", scale=100, interactive=True, elem_classes="multiModalText", file_types=[".json", ".txt"])
            # elem_id must be a plain string so the "#indexing" CSS rule can target it.
            variant_indexing = gr.Radio(choices=[0, 1], elem_id="indexing", min_width=80, label='Indexing', value=0)
            metl_update_pdb_display = gr.Button("Update PDB display", min_width=100)
    metl_output = gr.TextArea(label="Output from METL", interactive=False, show_copy_button=True)
    help_text = gr.Markdown("Load a model and if necessary, upload a pdb file to get started. File inputs for mutations must be 1 JSON list per line. When a file is uploaded with multiple mutations, the first mutation will be displayed on the molecule.")
    metl_run_button = gr.Button("Run Prediction", interactive=False)
    metl_load_example = gr.Button("Load Example")
    with gr.Row(elem_id="moleculeRow"):
        molecule = gr.HTML()
        with gr.Column(scale=10):
            with gr.Row():
                # NOTE: the variable names and labels are swapped ("select_all" is labelled
                # "Deselect all"). select_or_deselect_all decides by label, so each button
                # does what its label says.
                select_all = gr.Button(elem_classes=["selectorButton"], value="Deselect all", visible=False)
                deselect_all = gr.Button(elem_classes=["selectorButton"], value="Select all", visible=False)
            show_hide_variants = gr.CheckboxGroup(show_label=False, visible=False, elem_id="variantCheck")
    with Modal(visible=False) as help_modal:
        modal_text = """
This is a demo for [METL](https://huggingface.co/gitter-lab/METL). The supported METL models are listed in the dropdown.
The specifics of each of these models may be found in the above 🤗 link or [here](https://github.com/gitter-lab/metl?tab=readme-ov-file).
To run this demo, follow these steps:
1. select a model in the provided dropdown.
2. upload a pdb file if it is required for your prediction.
3. paste in your wild type sequence
4. paste in your mutations in JSON list format, where each mutation in the list is a CSV string separated by double quotes (").
    - an example is provided when the "Load Example" button is pressed.
    - if a PDB is given, and mutations are in the corresponding text box then update button may be used to display those mutations in an interactive molecule display using 3Dmol.js.
5. Press "Run Prediction"

For cases where many combinations of mutations are given, uploading a text file where each line is a new JSON list (as described above) will allow up to 100 different METL predictions!
"""
        gr.Markdown(modal_text, elem_id="helpModalText")
    help_alert = gr.Button("Help!", elem_id="helpModal")
    help_alert.click(lambda: Modal(visible=True), None, help_modal)
    ## Model and PDB event handlers
    metl_model_id.input(fn=load_model, inputs=[metl_model_id, metl_model_id], outputs=[metl_model_status, metl_run_button])
    # Clearing the upload must also forget the stored path; otherwise the
    # viewer and predict() keep using the removed structure.
    upload_pdb.clear(fn=empty_pdb_path, inputs=upload_pdb, outputs=molecule, show_progress=False)
    upload_pdb.upload(fn=get_file, inputs=upload_pdb, outputs=molecule, show_progress=False)
    metl_update_pdb_display.click(fn=update_pdb, inputs=[metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)
    ## Predicting event handlers
    metl_run_button.click(fn=predict, inputs=[metl_seq_input, metl_variants, variant_indexing], outputs=[metl_output, molecule, show_hide_variants], show_progress=False)
    ## Load example
    metl_load_example.click(fn=populate_example, outputs=[metl_model_id, metl_model_status, metl_run_button, metl_seq_input,
                                                          metl_variants, molecule, upload_pdb, show_hide_variants,
                                                          select_all, deselect_all], show_progress=False)
    ## Event handlers for the molecule display
    select_all.click(fn=select_or_deselect_all, inputs=[select_all, show_hide_variants, metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)
    deselect_all.click(fn=select_or_deselect_all, inputs=[deselect_all, show_hide_variants, metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)
    show_hide_variants.input(fn=hide_variants, inputs=show_hide_variants, outputs=molecule, show_progress=False)
demo.launch()
# test model: METL-L-2M-3D-GB1
# test wild type: MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE
# test variants: ["T17P,T54F", "V28L,F51A", "T17P,V28L,F51A,T54F"]