File size: 7,639 Bytes
37d70c1
 
 
 
 
 
bc153d9
37d70c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import mols2grid
from ipywidgets import interact, widgets
import textwrap
import moses
from transformers import EncoderDecoderModel, RobertaTokenizer

from moses.metrics.utils import QED, SA, logP, NP, weight, get_n_rings
from moses.utils import mapper, get_mol

# @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
from typing import List

from util import filter_dataframe


@st.cache(suppress_st_warning=True)
def load_models():
    """Load both WarmMolGen encoder-decoder checkpoints.

    The result is memoized by the ``st.cache`` decorator.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favor of ``st.cache_resource`` -- confirm against the pinned version
    before changing the decorator.

    Returns:
        A 2-tuple ``(WarmMolGenOne model, WarmMolGenTwo model)``.
    """
    checkpoints = (
        "gokceuludogan/WarmMolGenOne",
        "gokceuludogan/WarmMolGenTwo",
    )
    first, second = (EncoderDecoderModel.from_pretrained(name) for name in checkpoints)
    return first, second


def count(smiles_list: List[str]):
    """Return the length of every string in ``smiles_list``, in input order."""
    return [len(entry) for entry in smiles_list]


def remove_none_elements(mol_list, smiles_list):
    """Drop entries whose molecule is ``None`` from both lists.

    Positions at which ``mol_list`` holds ``None`` are removed from
    ``mol_list`` and the same positions are removed from ``smiles_list``.
    Only ``None`` is treated as invalid; other falsy values are kept.

    Args:
        mol_list: Iterable of molecule objects, with ``None`` marking an
            entry to discard.
        smiles_list: Iterable of SMILES strings, matched to ``mol_list``
            by position.

    Returns:
        A tuple ``(filtered_mol_list, filtered_smiles_list, removed_len)``
        where ``removed_len`` is the number of ``None`` entries found in
        ``mol_list``.
    """
    # A set keeps the membership test below O(1) per element; the previous
    # list-based lookup made the whole filter quadratic.
    rejected_positions = set()
    filtered_mol_list = []
    for position, mol in enumerate(mol_list):
        if mol is None:
            rejected_positions.add(position)
        else:
            filtered_mol_list.append(mol)

    filtered_smiles_list = [
        smiles
        for position, smiles in enumerate(smiles_list)
        if position not in rejected_positions
    ]

    return filtered_mol_list, filtered_smiles_list, len(rejected_positions)


def format_list_numbers(lst):
    """Round each value in ``lst`` to three decimal places, in place.

    Every element is formatted with ``.3f`` and converted back to ``float``,
    then written to the same position. Nothing is returned.
    """
    for position, number in enumerate(lst):
        lst[position] = float(f"{number:.3f}")


def generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool):
    """Generate molecules for a target sequence and tabulate their properties.

    Args:
        model_name: ``'WarmMolGenOne'`` selects the first loaded model; any
            other value selects the second.
        num_mols: Number of sequences to return (cast to ``int``).
        max_new_tokens: Forwarded to ``generate`` as ``max_length`` (cast to
            ``int``).
        do_sample: Forwarded to ``generate`` as ``do_sample``.
        num_beams: Forwarded to ``generate`` as ``num_beams``.
        target: Target sequence text fed to the protein tokenizer.
        pool: Handed to moses' ``mapper`` for every mapping step.

    Returns:
        A ``pandas.DataFrame`` with columns ``SMILES``, ``QED``, ``SA``,
        ``logP``, ``NP`` and ``Weight`` for the molecules that parsed
        successfully; metric values are rounded to three decimals.
    """
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model1, model2 = load_models()
    inputs = protein_tokenizer(target, return_tensors="pt")

    if model_name == 'WarmMolGenOne':
        model = model1
    else:
        model = model2

    generation_kwargs = dict(
        decoder_start_token_id=mol_tokenizer.bos_token_id,
        eos_token_id=mol_tokenizer.eos_token_id,
        pad_token_id=mol_tokenizer.eos_token_id,
        max_length=int(max_new_tokens),
        num_return_sequences=int(num_mols),
        do_sample=do_sample,
        num_beams=num_beams,
    )
    outputs = model.generate(**inputs, **generation_kwargs)
    output_smiles = mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    st.write("### Generated Molecules")

    # Parse the decoded strings; entries that come back as None are dropped
    # from both the molecule list and the SMILES list.
    mol_list = mapper(pool)(get_mol, output_smiles)
    mol_list, output_smiles, removed_len = remove_none_elements(mol_list, output_smiles)
    if removed_len:
        st.write(f"#### Note that: {removed_len} numbers of generated invalid molecules are discarded.")

    # Column order of the resulting frame follows insertion order here.
    columns = {'SMILES': output_smiles}
    metrics = (("QED", QED), ("SA", SA), ("logP", logP), ("NP", NP), ("Weight", weight))
    for label, metric in metrics:
        scores = mapper(pool)(metric, mol_list)
        format_list_numbers(scores)
        columns[label] = scores

    return pd.DataFrame(columns)


def warm_molgen_demo():
    """Render the demo page body.

    Collects the generation parameters from widgets, generates molecules when
    the submit button is pressed (or when no dataframe is stored in the
    session yet), shows the filtered molecules in a grid, and prints a code
    snippet reproducing the generation call with the chosen parameters.
    """
    # NOTE(review): the widgets below are created inside both the form and
    # the sidebar context managers; some go through `st.sidebar.*`, while the
    # text area and submit button are called on `st` directly. Which
    # container ends up hosting each widget depends on Streamlit's context
    # handling -- confirm before restructuring.
    with st.form("my_form"):
        with st.sidebar:
            st.sidebar.subheader("Configurable parameters")

            # Which of the two checkpoints to generate with.
            model_name = st.sidebar.selectbox(
                "Model Selector",
                options=[
                    "WarmMolGenOne",
                    "WarmMolGenTwo",
                ],
                index=0,
            )

            num_mols = st.sidebar.number_input(
                "Number of generated molecules",
                min_value=0,
                max_value=20,
                value=10,
                help="The number of molecules to be generated.",
            )

            # Despite the variable name, this value is shown in the snippet
            # below as `max_length`.
            max_new_tokens = st.sidebar.number_input(
                "Maximum length",
                min_value=0,
                max_value=1024,
                value=128,
                help="The maximum length of the sequence to be generated.",
            )
            do_sample = st.sidebar.selectbox(
                "Sampling?",
                (True, False),
                help="Whether or not to use sampling; use beam decoding otherwise.",
            )
            target = st.text_area(
                "Target Sequence",
                "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
            )
            generate_new_molecules = st.form_submit_button("Generate Molecules")

    # No beam count when sampling; otherwise one beam per requested molecule.
    num_beams = None if do_sample is True else int(num_mols)

    # Passed through to generate_molecules unchanged.
    # presumably a worker count for moses' mapper -- TODO confirm
    pool = 1

    # Regenerate whenever the form is submitted.
    if generate_new_molecules:
        st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams,
                                                 target, pool)
    # First run of the session: nothing stored yet, so generate with the
    # current widget values.
    if 'df' not in st.session_state:
        st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams,
                                                 target, pool)
    df = st.session_state.df

    # filter_dataframe comes from the local `util` module; its filtering
    # behavior is defined there.
    filtered_df = filter_dataframe(df)
    if filtered_df.empty:
        st.markdown(
            """
            <span style='color: blue; font-size: 30px;'>No molecules were found with specified properties.</span>
        """,
            unsafe_allow_html=True
        )
    else:
        # Render the grid to HTML and embed it in a scrollable component.
        raw_html = mols2grid.display(filtered_df, height=1000)._repr_html_()
        components.html(raw_html, width=900, height=450, scrolling=True)

    # Show a standalone snippet that mirrors the generation call with the
    # currently selected parameters interpolated in.
    st.markdown("## How to Generate")
    generation_code = f"""
    from transformers import EncoderDecoderModel, RobertaTokenizer
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/{model_name}")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}")
    inputs = protein_tokenizer("{target}", return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id,
                             eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id,
                             max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})
    mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    """
    # dedent strips the common leading indentation of the literal above.
    st.code(textwrap.dedent(generation_code))


# Page-level configuration is issued before any other page element.
# NOTE(review): Streamlit's docs have required set_page_config to be the
# first Streamlit command on the page -- keep it first.
st.set_page_config(layout='wide', page_title="WarmMolGen Demo", page_icon="🔥")

# Same title in the main area and in the sidebar.
st.markdown("# WarmMolGen Demo")
st.sidebar.header("WarmMolGen Demo")

# Introductory text shown above the demo widgets.
_INTRO_MARKDOWN = """
    This demo illustrates WarmMolGen models' generation capabilities.
    Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.
    Please enter an input sequence below 👇  and configure parameters from the sidebar 👈 to generate molecules!
    See below for saving the output molecules and the code snippet generating them!
"""
st.markdown(_INTRO_MARKDOWN)

# Build the interactive part of the page.
warm_molgen_demo()