Spaces:
Sleeping
Sleeping
import html
import json
import os
import subprocess
import tempfile
from time import time

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
import transformers
from transformers import GenerationConfig, pipeline, AutoTokenizer, AutoModelForCausalLM, EsmForProteinFolding
# Hub repository that provides the weights of the protein language model.
model_id = "Esperanto/Protein-Phi-3-mini"

# Weights are loaded in half precision. trust_remote_code=True allows
# transformers to run model code shipped inside the repository.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True)

# The tokenizer is taken from the Phi-3 instruct repository, not from model_id.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
)
# Reuse the end-of-sequence token for padding and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Text-generation pipeline used by generate_protein_sequence().
generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer)

# ESMFold model and its tokenizer, used by save_pdb() for structure prediction.
ESMFOLD_ID = "facebook/esmfold_v1"
esm_model = EsmForProteinFolding.from_pretrained(ESMFOLD_ID)
esm_tokenizer = AutoTokenizer.from_pretrained(ESMFOLD_ID)
def clean_protein_sequence(protein_seq):
    """Strip everything that is not a valid amino-acid letter from a string.

    Only the twenty upper-case one-letter codes ``ACDEFGHIKLMNPQRSTVWY`` are
    kept; lower-case letters, digits, punctuation and whitespace are dropped.
    The relative order of the surviving characters is preserved.

    Args:
        protein_seq: Raw text, e.g. the output of the language model.

    Returns:
        A string made only of valid amino-acid letters (possibly empty).
    """
    valid_amino_acids = frozenset("ACDEFGHIKLMNPQRSTVWY")
    return ''.join(filter(valid_amino_acids.__contains__, protein_seq))
def modify_b_factors(pdb_content, multiplier):
    """Scale the B-factor column of every ATOM record in a PDB string.

    The value in character columns 61-66 (``line[60:66]``) of each line that
    starts with ``ATOM`` is multiplied by ``multiplier`` and written back as a
    6-character field with two decimals. ``save_pdb`` calls this with 100 to
    turn the values written by ESMFold into percentages.

    Lines that do not start with ``ATOM`` are copied unchanged. ATOM lines
    whose B-factor field is missing, blank or non-numeric (for example a
    truncated record) are also copied unchanged instead of raising.

    Args:
        pdb_content: Full PDB file content as one string.
        multiplier: Factor applied to every B-factor.

    Returns:
        The PDB content with rescaled B-factors, lines joined by ``"\\n"``.
    """
    rescaled_lines = []
    for line in pdb_content.split('\n'):
        if line.startswith("ATOM"):
            try:
                b_factor = float(line[60:66])
            except ValueError:
                # Truncated record or unparsable field: leave the line alone.
                rescaled_lines.append(line)
                continue
            line = f"{line[:60]}{b_factor * multiplier:6.2f}{line[66:]}"
        rescaled_lines.append(line)
    return "\n".join(rescaled_lines)
def save_pdb(input_sequence):
    """Predict a structure with ESMFold and store it as a PDB file.

    The sequence is tokenized without special tokens, run through
    ``esm_model`` with gradients disabled, converted to PDB text, and its
    B-factors are multiplied by 100 via ``modify_b_factors``. The result is
    written to ``Protein-Llama-3-8B-Gradio/temporary_folder/protein.pdb``
    (the folder is created when missing; an existing file is overwritten).

    Args:
        input_sequence: Amino-acid sequence to fold.

    Returns:
        The mean of ``outputs.plddt.tolist()[0][0]`` after scaling each value
        by 100 and rounding to two decimals.
    """
    pdb_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")

    tokenized = esm_tokenizer([input_sequence], return_tensors="pt", add_special_tokens=False)
    with torch.no_grad():
        outputs = esm_model(**tokenized)

    # output_to_pdb returns one PDB string per batch entry; only one was sent.
    pdb_string = modify_b_factors(esm_model.output_to_pdb(outputs)[0], 100)

    # NOTE(review): [0][0] selects a single innermost list of the pLDDT
    # tensor. Whether that covers the whole sequence depends on the tensor's
    # shape, which is not visible here - confirm.
    plddt_percent = [round(score * 100, 2) for score in outputs.plddt.tolist()[0][0]]

    os.makedirs(os.path.dirname(pdb_path), exist_ok=True)
    with open(pdb_path, "w") as f:
        f.write(pdb_string)

    return np.mean(plddt_percent)
def read_prot(molpath):
    """Return the full text content of the PDB file at ``molpath``.

    Args:
        molpath: Path of the file to read (opened in text mode).

    Returns:
        The file content as a single string.
    """
    with open(molpath, "r") as fp:
        return fp.read()
def get_cov2_pdb():
    """Download PDB entry ``6vxx`` from RCSB and return it as text.

    Returns:
        The body of ``https://files.rcsb.org/download/6vxx.pdb`` as a string.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never handed to the viewer as if it were PDB data.
        requests.RequestException: On connection problems or timeout.
    """
    pdb_id = '6vxx'
    url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    # A timeout keeps the UI from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Check that the request was successful before using the body.
    response.raise_for_status()
    return response.text
def protein_visual_html(input_pdb):
    """Build an ``<iframe>`` that renders a PDB structure with 3Dmol.js.

    Args:
        input_pdb: Either the literal string ``'cov2'`` (the structure is then
            downloaded through ``get_cov2_pdb``) or the path of a PDB file on
            disk (read through ``read_prot``).

    Returns:
        HTML markup for an iframe whose ``srcdoc`` holds a complete page that
        loads jQuery and 3Dmol.js and shows the structure as a cartoon.

    The PDB text is embedded as a JSON string literal, and the whole page is
    HTML-escaped before it goes into the ``srcdoc`` attribute. Apostrophes,
    backticks, backslashes, ``&`` or ``</script>`` in the PDB file therefore
    cannot break the script or terminate the attribute early.
    """
    if input_pdb == 'cov2':
        mol = get_cov2_pdb()
    else:
        mol = read_prot(input_pdb)

    # A JSON string is a valid JavaScript string literal. "</" is rewritten
    # so the data can never close the surrounding <script> element.
    pdb_js_literal = json.dumps(mol).replace("</", "<\\/")

    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = """ + pdb_js_literal + """;
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""
    )

    # The browser decodes character references in srcdoc before parsing the
    # page, so escaping here yields exactly the document built above.
    srcdoc = html.escape(x, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def predict_structure(input_sequence):
    """Return viewer HTML for the structure belonging to ``input_sequence``.

    For the hard-coded example sequence the structure stored in
    ``Protein-Llama-3-8B-Gradio/sars_cov.pdb`` is shown, so the demo responds
    instantly. Any other sequence is folded with ``save_pdb`` and the PDB file
    it writes is visualized.

    Args:
        input_sequence: Amino-acid sequence to visualize.

    Returns:
        The iframe markup produced by ``protein_visual_html``.
    """
    # Hard coding the SARS-CoV 2 protein sequence and structure for instant demo purposes
    if input_sequence == 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI':
        return protein_visual_html('Protein-Llama-3-8B-Gradio/sars_cov.pdb')

    # save_pdb writes the prediction to a fixed location; the mean pLDDT it
    # returns is not displayed here.
    save_pdb(input_sequence)
    # Visualize the file that was just written. Previously the string 'cov2'
    # was passed here, which displayed the downloaded 6vxx structure and
    # discarded the prediction.
    pdb_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")
    return protein_visual_html(pdb_path)
# Sequence returned for the "SARS-CoV-2 Spike Protein (example)" choice without running the model.
_EXAMPLE_SEQUENCE = 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI'


def _extract_sequence_text(generated_text):
    """Return the text between the ``Seq=<`` marker and the next ``>``."""
    marker = 'Seq=<'
    marker_idx = generated_text.find(marker)
    # Start after the marker: its letter "S" is itself a valid amino-acid
    # code and would otherwise survive cleaning as a spurious first residue.
    start_idx = 0 if marker_idx == -1 else marker_idx + len(marker)
    end_idx = generated_text.find('>', start_idx)
    if end_idx == -1:
        # No closing bracket was generated: use everything up to the end.
        end_idx = len(generated_text)
    return generated_text[start_idx:end_idx]


def generate_protein_sequence(sequence, seq_length, property=''):
    """Generate a protein sequence that continues ``sequence``.

    Args:
        sequence: Starting amino acids typed by the user (may be empty).
        seq_length: Maximum number of new tokens to generate.
        property: Optional class that conditions the generation. ``None`` or
            an empty string requests unconditioned generation. The value
            ``'SARS-CoV-2 Spike Protein (example)'`` returns a hard-coded
            sequence without running the model. (The parameter name shadows
            the builtin but is kept for interface compatibility.)

    Returns:
        A 3-tuple ``(cleaned_sequence, elapsed_seconds, tokens_per_second)``
        matching the three outputs wired to the Gradio click handler. For the
        hard-coded example ``tokens_per_second`` is 0.
    """
    enzymes = ["Non-Hemolytic", "Soluble", "Oxidoreductase", "Transferase", "Hydrolase", "Lyase", "Isomerase", "Ligase", "Translocase"]
    sequence = sequence or ''

    if property == 'SARS-CoV-2 Spike Protein (example)':
        start_time = time()
        # Three values, not four: the click handler has exactly three outputs.
        return _EXAMPLE_SEQUENCE, time() - start_time, 0

    if not property:
        # Nothing selected (None from the dropdown, or the '' default).
        input_prompt = 'Seq=<' + sequence
    elif property in enzymes:
        input_prompt = '[Generate ' + property.lower() + ' protein] ' + 'Seq=<' + sequence
    else:
        input_prompt = '[Generate ' + property + ' protein] ' + 'Seq=<' + sequence

    start_time = time()
    generated_text = generator(input_prompt, temperature=0.5,
                               top_k=40,
                               top_p=0.9,
                               do_sample=True,
                               repetition_penalty=1.2,
                               max_new_tokens=int(seq_length),
                               num_return_sequences=1)[0]["generated_text"]
    elapsed = time() - start_time

    cleaned_seq = clean_protein_sequence(_extract_sequence_text(generated_text))

    tokens = tokenizer.encode(cleaned_seq, add_special_tokens=False)
    # Guard against a zero interval on coarse clocks.
    tokens_per_second = len(tokens) / elapsed if elapsed > 0 else 0.0
    return cleaned_seq, elapsed, tokens_per_second
# Introductory text rendered at the top of the page.
intro_markdown = '''
### Interactive protein sequence generation and visualization.
Generating novel protein sequences possessing desired properties, termed as protein engineering, is crucial for industries like drug development and chemical synthesis.
This model supports two types of generation, uncontrollable and controllable. Uncontrollable generation refers to generating any viable protein sequence, whereas controllable refers to generating proteins having a desired property or a characteristic.
### Usage
For uncontrollable generation, input any starting amino acids and press 'Submit' without choosing a property. For controllable generation, choose any of the ten properties supported by this model before pressing 'Submit'. The important inference metrics will be displayed along with the generated output.
### Example
As an example, the protein sequence corresponding to the SARS-CoV-2 Spike Protein is given. This example does not run inference on the model as the sequence is hard-coded, however it showcases the flow of interacting with the demo.
'''

# Build the Gradio interface: seed text, length and class selection, a submit
# button, the generated sequence and two timing metrics.
with gr.Blocks() as demo:
    gr.Markdown(intro_markdown)

    with gr.Row():
        input_text = gr.Textbox(
            label="Enter starting amino acids for protein sequence generation",
            placeholder="Example input: MK",
        )

    with gr.Row():
        seq_length = gr.Slider(
            2,
            200,
            value=30,
            step=1,
            label="Length",
            info="Choose the number of tokens to generate",
        )
        # Choices offered in the dropdown; the first entry is the hard-coded example.
        classes = [
            "SARS-CoV-2 Spike Protein (example)",
            'Tetratricopeptide-like helical domain superfamily',
            'CheY-like superfamily',
            'S-adenosyl-L-methionine-dependent methyltransferase superfamily',
            'Thioredoxin-like superfamily',
            "Non-Hemolytic",
            "Soluble",
            "Oxidoreductase",
            "Transferase",
            "Hydrolase",
            "Lyase",
            "Isomerase",
            "Ligase",
            "Translocase",
        ]
        protein_property = gr.Dropdown(classes, label="Class")

    with gr.Row():
        btn = gr.Button("Submit")

    with gr.Row():
        output_text = gr.Textbox(label="Generated protein sequence will appear here")

    with gr.Row():
        infer_time = gr.Number(label="Inference Time (s)", precision=2)
        tokens_per_sec = gr.Number(label="Tokens/sec", precision=2)

    # Structure visualization is currently disabled:
    # with gr.Row():
    #     btn_vis = gr.Button("Visualize")
    # with gr.Row():
    #     structure_visual = gr.HTML()

    btn.click(
        generate_protein_sequence,
        inputs=[input_text, seq_length, protein_property],
        outputs=[output_text, infer_time, tokens_per_sec],
    )
    # btn_vis.click(predict_structure, inputs=output_text, outputs=[structure_visual])

# Run the Gradio interface
demo.launch()