import numpy as np
import torch
import torch.nn as nn
import warnings
import gradio as gr
from transformers import AutoTokenizer, EsmForSequenceClassification

warnings.filterwarnings('ignore')
device = "cpu"
model_checkpoint1 = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint1)
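
# Illustrative note: for this checkpoint the tokenizer maps a peptide string to
# input_ids / attention_mask tensors, e.g. tokenizer("KLLK", max_length=18,
# padding="max_length", truncation=True, return_tensors='pt') -> shape (1, 18).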


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ESM-2 backbone with a 3000-way classification head, used as a feature extractor
        # (its forward pass runs under torch.no_grad(), so it is not fine-tuned here).
        self.bert1 = EsmForSequenceClassification.from_pretrained(model_checkpoint1, num_labels=3000)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.relu = nn.LeakyReLU()
        self.fc1 = nn.Linear(3000, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output_layer = nn.Linear(64, 2)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # Run the frozen backbone; its 3000-dim logits serve as the input features
        # for the small MLP classification head below.
        with torch.no_grad():
            bert_output = self.bert1(input_ids=x['input_ids'],
                                     attention_mask=x['attention_mask'])
        output_feature = self.dropout(bert_output["logits"])
        output_feature = self.dropout(self.relu(self.bn1(self.fc1(output_feature))))
        output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
        output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
        output_feature = self.dropout(self.output_layer(output_feature))
        return torch.softmax(output_feature, dim=1)
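
# Shape walk-through for MyModel.forward (illustrative, matching the layers above):
#   tokenizer output            -> input_ids / attention_mask: (batch, 18)
#   ESM-2 classification head   -> logits: (batch, 3000), used as a feature vector
#   fc1 + bn1 + LeakyReLU       -> (batch, 256)
#   fc2 + bn2 + LeakyReLU       -> (batch, 128)
#   fc3 + bn3 + LeakyReLU       -> (batch, 64)
#   output_layer + softmax      -> (batch, 2) probabilities over (non-AMP, AMP)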


def AMP(test_sequences, model):
    # The AMP function is kept unchanged; it only processes the test_sequences passed in.
    max_len = 18
    test_data = tokenizer(test_sequences, max_length=max_len, padding="max_length", truncation=True,
                          return_tensors='pt')
    test_data = test_data.to(device)  # no-op on CPU, but keeps the function device-safe
    model = model.to(device)
    model.eval()
    out_probability = []
    with torch.no_grad():
        predict = model(test_data)
        out_probability.extend(np.max(np.array(predict.cpu()), axis=1).tolist())
        test_argmax = np.argmax(predict.cpu(), axis=1).tolist()
    id2str = {0: "non-AMP", 1: "AMP"}
    return id2str[test_argmax[0]], out_probability[0]
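
# Illustrative sketch (not part of the original app): scoring several sequences in one
# batch with the same tokenizer and model. The helper name AMP_batch is hypothetical.
def AMP_batch(sequences, model):
    batch = tokenizer(sequences, max_length=18, padding="max_length",
                      truncation=True, return_tensors='pt').to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        probs = model(batch)  # (n, 2) softmax probabilities per sequence
    labels = torch.argmax(probs, dim=1).tolist()  # 0 = non-AMP, 1 = AMP
    id2str = {0: "non-AMP", 1: "AMP"}
    return [(id2str[label], probs[i, label].item()) for i, label in enumerate(labels)]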


def classify_sequence(sequence):
    # Check if the sequence is a valid amino acid sequence and has a length of at least 3
    valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")
    sequence = sequence.upper()

    if all(aa in valid_amino_acids for aa in sequence) and len(sequence) >= 3:
        result, probability = AMP(sequence, model)
        return "yes" if result == "AMP" else "no"
    else:
        return "Invalid Sequence"

# Load the trained classifier weights (strict=False tolerates missing/unexpected keys).
model = MyModel()
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')), strict=False)
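
# Quick sanity check (illustrative; assumes best_model.pth sits next to this script):
#   classify_sequence("KLLKKLLKLWKKLLKKLK")  -> "yes" or "no"
#   classify_sequence("KLX")                 -> "Invalid Sequence" (X is not a standard residue)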


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown(
            """
        ![image](https://github.com/wrab12/diff-amp/blob/main/111.png)
        """)
        gr.Markdown(
            """

        # Welcome to Antimicrobial Peptide Recognition Model
        This is an antimicrobial peptide recognition model derived from Diff-AMP, a branch of a comprehensive system that integrates generation, recognition, and optimization. Simply input a sequence, and the model will predict whether it is an antimicrobial peptide. Due to limited website capacity, we can only perform simple predictions here.
        If you require large-scale computations, please contact me at [email protected]. Feel free to reach out if you have any questions or inquiries.

            """)

        # Example input sequences shown below the interface
        examples = [
            ["KLLKKLLKLWKKLLKKLK"],
            ["FLGLLFHGVHHVGKWIHGLIHGHH"],
            ["GLMSTLKGAATNAAVTLLNKLQCKLTGTC"]
        ]

        # Create the Gradio interface and attach the examples above
        gr.Interface(
            fn=classify_sequence,
            inputs="text",
            outputs="text",
            title="AMP Sequence Detector",
            examples=examples
        )


    demo.launch()