|
import os |
|
import gradio as gr |
|
from prediction import run_sequence_prediction |
|
import torch |
|
import torchvision.transforms as T |
|
from celle.utils import process_image |
|
from celle_main import instantiate_from_config |
|
from omegaconf import OmegaConf |
|
from huggingface_hub import hf_hub_download |
|
|
|
def bold_predicted_letters(input_string: str, output_string: str) -> str:
    """Wrap the residues predicted for each ``<mask>`` run in ``**`` markers.

    Both strings are upper-cased first, so the result is upper-case. The input is
    walked in lock-step with the output. Each maximal run of consecutive
    ``<MASK>`` tokens is replaced by the same number of characters taken from
    ``output_string``, surrounded by ``**``. Every other input character is copied
    through unchanged. A literal ``<`` that does not start a mask token is copied,
    but does not advance the position in ``output_string``.

    Args:
        input_string: sequence as typed by the user, possibly containing ``<mask>``.
        output_string: predicted sequence with one character per residue.

    Returns:
        The upper-cased input with each mask run replaced by ``**<predicted>**``.
    """
    token = "<MASK>"
    width = len(token)
    source = input_string.upper()
    predicted = output_string.upper()
    total = len(source)

    pieces = []
    src_pos = 0  # position in the (upper-cased) input
    out_pos = 0  # matching position in the predicted sequence

    while src_pos < total:
        if source.startswith(token, src_pos):
            # Extend over every directly following mask token.
            run_end = src_pos + width
            while run_end < total and source.startswith(token, run_end):
                run_end += width
            run_length = (run_end - src_pos) // width
            pieces.append(f"**{predicted[out_pos:out_pos + run_length]}**")
            out_pos += run_length
            src_pos = run_end
            continue

        char = source[src_pos]
        pieces.append(char)
        src_pos += 1
        if char != "<":
            out_pos += 1

    return "".join(pieces)
|
|
|
def diff_texts(string):
    """Convert ``**``-delimited text into ``(character, label)`` pairs.

    The output feeds ``gr.HighlightedText``. Characters between a pair of ``**``
    markers are labelled ``"+"``; all other characters are labelled ``None``.
    Any remaining lone ``*`` characters are dropped.

    Args:
        string: marked-up text such as ``"M**AB**SK"`` (the output of
            ``bold_predicted_letters``).

    Returns:
        A list with one ``(character, label)`` tuple per kept character.
    """
    highlighted = []
    # After splitting on the delimiter, every odd-indexed segment lies inside a
    # "**...**" pair. Each delimiter therefore toggles the bold state exactly once,
    # including for an empty span ("****"). A per-character scan miscounts that
    # case and bolds the text that follows.
    for segment_index, segment in enumerate(string.split("**")):
        label = "+" if segment_index % 2 else None
        highlighted.extend((letter, label) for letter in segment if letter != "*")
    return highlighted
|
|
|
class model:
    """Lazily load a CELL-E 2 checkpoint and run the Gradio prediction callback.

    Only one model is held at a time. When a different ``model_name`` is
    requested, the previously downloaded checkpoint file is deleted. The new
    checkpoint and its configs are then fetched from ``HuangLab/<model_name>`` on
    the Hugging Face Hub.
    """

    def __init__(self):
        # Currently loaded (compiled) model; None until the first request.
        self.model = None
        # Hub repository name of the loaded model; None while nothing is loaded.
        self.model_name = None
        # Local path of the downloaded checkpoint; removed when switching models.
        self.model_path = None

    def _load_model(self, model_name, device):
        """Download ``HuangLab/<model_name>`` and make it the active model.

        ``self.model_name`` is only set once the model has been built. A failed
        download or instantiation therefore forces a reload on the next request,
        instead of reusing a model that no longer exists.
        """
        if self.model_path is not None:
            try:
                os.remove(self.model_path)  # reclaim disk space from the old checkpoint
            except FileNotFoundError:
                pass
            self.model_path = None
        # Release the previous model before building the next one.
        self.model = None
        self.model_name = None

        repo_id = f"HuangLab/{model_name}"
        model_ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
        # Record the path immediately so the file is cleaned up on the next switch
        # even if loading below fails.
        self.model_path = model_ckpt_path
        model_config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
        # Not referenced directly here; presumably the config points at these by
        # relative path, which is why the model is built from the download dir.
        hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
        hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")

        config = OmegaConf.load(model_config_path)
        if config["model"]["params"]["ckpt_path"] is None:
            config["model"]["params"]["ckpt_path"] = model_ckpt_path
        config["model"]["params"]["condition_model_path"] = None
        config["model"]["params"]["vqgan_model_path"] = None

        base_path = os.getcwd()
        os.chdir(os.path.dirname(model_ckpt_path))
        try:
            loaded = instantiate_from_config(config.model).to(device)
            loaded = torch.compile(loaded, mode="max-autotune")
        finally:
            # Always restore the working directory, even if instantiation fails.
            os.chdir(base_path)

        self.model = loaded
        self.model_name = model_name

    def gradio_demo(self, model_name, sequence_input, image):
        """Predict the masked residues of ``sequence_input`` for the Gradio UI.

        Args:
            model_name: repository name under ``HuangLab`` on the Hugging Face Hub.
            sequence_input: amino-acid sequence containing ``<mask>`` tokens.
            image: value from ``gr.ImageMask``. ``image["image"]`` is the uploaded
                nucleus image and ``image["mask"]`` is the drawn localization.

        Returns:
            ``(threshold_image, nucleus_crop, highlighted_sequence)``. The first two
            are PIL images; the last is a list of ``(character, label)`` pairs from
            ``diff_texts``.

        Raises:
            gr.Error: if no nucleus image was provided.
        """
        if image is None:
            raise gr.Error("Please upload a nucleus image.")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.model_name != model_name:
            self._load_model(model_name, device)

        # NOTE(review): neither dropdown option contains "Finetuned", so the
        # OpenCell model currently receives HPA preprocessing. Confirm this is intended.
        dataset = "OpenCell" if "Finetuned" in model_name else "HPA"

        to_tensor = T.ToTensor()
        nucleus_tensor = to_tensor(image["image"].convert("L"))
        protein_tensor = to_tensor(image["mask"].convert("L"))
        processed_images = process_image(
            torch.stack([nucleus_tensor, protein_tensor], dim=0), dataset
        )

        nucleus_image = processed_images[0].unsqueeze(0)
        protein_image = processed_images[1].unsqueeze(0)
        # Scale the drawn localization by its peak value. If nothing was drawn the
        # peak is 0, and dividing would fill the tensor with NaN.
        peak = torch.max(protein_image)
        if peak != 0:
            protein_image = protein_image / peak

        predictions = run_sequence_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            protein_image=protein_image,
            model=self.model,
            device=device,
        )

        predicted_sequence = predictions[0]
        for special_token in ("<pad>", "<cls>", "<eos>"):
            predicted_sequence = predicted_sequence.replace(special_token, "")

        highlighted = diff_texts(bold_predicted_letters(sequence_input, predicted_sequence))

        to_pil = T.ToPILImage()
        return to_pil(protein_image[0, 0]), to_pil(nucleus_image[0, 0]), highlighted
|
|
|
# One shared holder, so a loaded checkpoint is reused across requests.
base_class = model()

with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Inputs")
    gr.Markdown(
        "Select the prediction model. **Note the first run may take ~2-3 minutes, but will take 3-4 seconds afterwards.**"
    )
    gr.Markdown(
        "- ```CELL-E_2_HPA_2560``` is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- ```CELL-E_2_OpenCell_2560``` is trained on OpenCell and is better suited for live-cell predictions on HEK cells."
    )

    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )

    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include ```<mask>``` for a prediction to be run."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value="M<mask><mask><mask><mask><mask>SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )

    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )

    with gr.Row(equal_height=True):
        # The drawn strokes become image["mask"] in gradio_demo.
        nucleus_image = gr.ImageMask(
            label="Nucleus Image",
            interactive=True,
            image_mode="L",
            brush_color="#ffffff",
            type="pil",
        )

    with gr.Row():
        gr.Markdown("## Outputs")

    with gr.Row(equal_height=True):
        nucleus_crop = gr.Image(
            label="Nucleus Image (Crop)",
            image_mode="L",
            type="pil",
        )
        mask = gr.Image(
            label="Threshold Image",
            image_mode="L",
            type="pil",
        )

    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")

    with gr.Row(equal_height=True):
        # Characters labelled "+" by diff_texts are the predicted residues.
        predicted_sequence = gr.HighlightedText(
            label="Predicted Sequence",
            combine_adjacent=True,
            show_legend=False,
            color_map={"+": "green"},
        )

    with gr.Row():
        button = gr.Button("Run Model")

    inputs = [model_name, sequence_input, nucleus_image]
    # Order must match gradio_demo's return value: (threshold image, nucleus crop, sequence).
    outputs = [mask, nucleus_crop, predicted_sequence]
    button.click(base_class.gradio_demo, inputs, outputs)

# Keep at most one pending event in the queue. Presumably this is because
# base_class holds shared, mutable model state.
demo.queue(max_size=1).launch()
|
|