Spaces:
Sleeping
Sleeping
File size: 4,615 Bytes
0c8cec9 fda141d 0c8cec9 b93c8a7 0c8cec9 49831fb 0c8cec9 8e4db71 0c8cec9 fda141d b93c8a7 fda141d 0c8cec9 fda141d 0c8cec9 b93c8a7 0c8cec9 fda141d 0c8cec9 b93c8a7 0c8cec9 b93c8a7 0c8cec9 fda141d 0c8cec9 8e4db71 0c8cec9 e3cb71b 0c8cec9 f368d83 0c8cec9 fda141d 0c8cec9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.protein_solubility.task import ProteinSolubilityTask
from mammal.keys import (
CLS_PRED,
ENCODER_INPUTS_STR,
SCORES,
)
from mammal.model import Mammal
from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
# Module-level shortcuts to the task helpers defined on ProteinSolubilityTask,
# so the demo class below can call them without repeating the class name.
# Turns raw model output (tokens + scores) into a prediction dictionary.
process_model_output = ProteinSolubilityTask.process_model_output
# Turns a sample dictionary holding a protein sequence into model-ready input.
data_preprocessing = ProteinSolubilityTask.data_preprocessing
class PsTask(MammalTask):
    """Demo task that asks a Mammal model whether a protein is water-soluble.

    The task builds a prompt from a protein sequence, runs generation on the
    selected model, decodes the result, and exposes the flow as a Gradio
    widget group.
    """

    def __init__(self, model_dict):
        """Register the task and set its description, example and intro text.

        Args:
            model_dict: mapping from model name to its holder object; it is
                forwarded to the base class and later indexed by model name.
        """
        super().__init__(name="Protein Solubility", model_dict=model_dict)
        self.description = "Protein Solubility (PS)"
        # Default value shown in the protein sequence textbox.
        self.examples = {"protein_seq": "LLQTGIHVRVSQPSL"}
        # Markdown rendered at the top of the demo group.
        self.markup_text = """
# Mammal based protein solubility estimation
Given the protein sequence, estimate if it's water-soluble.
"""

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Build the model-ready sample dictionary, including the prompt.

        Args:
            sample_inputs (dict): inputs for the model; the protein sequence
                is read from the "protein_seq" key.
            model_holder (MammalObjectBroker): supplies the tokenizer op and
                the model whose device is used for preprocessing.

        Returns:
            dict: the sample dictionary produced by the task preprocessing.
        """
        # Preprocess a shallow copy so the caller's dictionary is not the
        # object handed to the preprocessing step.
        return data_preprocessing(
            sample_dict=dict(sample_inputs),
            protein_sequence_key="protein_seq",
            tokenizer_op=model_holder.tokenizer_op,
            device=model_holder.model.device,
        )

    def run_model(self, sample_dict, model: Mammal):
        """Run generation for a single sample and return the raw batch output.

        Args:
            sample_dict: preprocessed sample, as returned by
                ``crate_sample_dict``.
            model (Mammal): model used for generation.

        Returns:
            The batch dictionary returned by ``model.generate``.
        """
        generation_options = {
            "output_scores": True,
            "return_dict_in_generate": True,
            "max_new_tokens": 5,
        }
        # The model is called with a one-element batch.
        return model.generate([sample_dict], **generation_options)

    def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
        """Extract the decoded text, predicted class and scores.

        Args:
            batch_dict: output of ``run_model``; only the first batch entry
                is used.
            tokenizer_op (ModularTokenizerOp): tokenizer used for decoding.

        Returns:
            list: ``[decoded output, prediction, non-normalized score,
            normalized score]``.
        """
        first_prediction = batch_dict[CLS_PRED][0]
        parsed = process_model_output(
            tokenizer_op=tokenizer_op,
            decoder_output=first_prediction,
            decoder_output_scores=batch_dict[SCORES][0],
        )
        decoded_text = tokenizer_op._tokenizer.decode(first_prediction)
        # Both scores are unwrapped with .item() — presumably tensor/array
        # scalars; confirm against process_model_output.
        raw_score = parsed["not_normalized_scores"].item()
        scaled_score = parsed["normalized_scores"].item()
        return [decoded_text, parsed["pred"], raw_score, scaled_score]

    def create_and_run_prompt(self, model_name, protein_seq):
        """Run the whole flow for one protein sequence.

        Args:
            model_name: key of the model holder inside ``self.model_dict``.
            protein_seq: protein sequence to evaluate.

        Returns:
            tuple: the prompt followed by the items of ``decode_output``.
        """
        broker = self.model_dict[model_name]
        sample_dict = self.crate_sample_dict(
            sample_inputs={"protein_seq": protein_seq},
            model_holder=broker,
        )
        # Read the prompt before running the model, as the original flow does.
        prompt = sample_dict[ENCODER_INPUTS_STR]
        generated = self.run_model(sample_dict=sample_dict, model=broker.model)
        decoded_fields = self.decode_output(
            generated, tokenizer_op=broker.tokenizer_op
        )
        return (prompt, *decoded_fields)

    def create_demo(self, model_name_widget):
        """Build the Gradio widget group for this task.

        Args:
            model_name_widget: Gradio component whose value is passed as the
                model name to ``create_and_run_prompt``.

        Returns:
            The ``gr.Group`` holding the task widgets, initially hidden.
        """
        # NOTE(review): the row/column nesting below was reconstructed from
        # flattened source — confirm it matches the intended layout.
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)

            # Input: the protein sequence, pre-filled with the example.
            with gr.Row():
                sequence_input = gr.Textbox(
                    label="Protein sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_seq"],
                )

            with gr.Row():
                run_button = gr.Button(
                    "Run Mammal prompt for protein solubility estimation",
                    variant="primary",
                )

            # Outputs: the prompt, then decoded text, prediction and scores.
            with gr.Row():
                prompt_output = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded_output = gr.Textbox(label="Mammal output")
                class_output = gr.Textbox(label="Mammal prediction")
                with gr.Column():
                    raw_score_output = gr.Number(label="Non normalized score")
                    scaled_score_output = gr.Number(label="normalized score")

            # Output order must match the tuple from create_and_run_prompt.
            result_widgets = [
                prompt_output,
                decoded_output,
                class_output,
                raw_score_output,
                scaled_score_output,
            ]
            run_button.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_widget, sequence_input],
                outputs=result_widgets,
            )

        # The group starts hidden.
        demo.visible = False
        return demo
|