Update utils.py
Browse file handling for multiple models
utils.py
CHANGED
@@ -6,17 +6,18 @@ def validate_sequence(sequence):
|
|
6 |
valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
|
7 |
return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
|
8 |
|
9 |
-
def load_model():
|
10 |
-
# Load
|
11 |
-
model = torch.load('
|
12 |
model.eval()
|
13 |
return model
|
14 |
|
|
|
15 |
def predict(model, sequence):
|
16 |
tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
|
17 |
tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
|
18 |
output = model(**tokenized_input)
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
return predicted_label.item()
|
|
|
6 |
valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
|
7 |
return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
|
8 |
|
9 |
+
def load_model(model_name):
    """Load a serialized PyTorch model from disk and switch it to eval mode.

    Parameters
    ----------
    model_name : str
        Base name of the checkpoint; weights are read from
        ``{model_name}_model.pth`` in the current working directory.

    Returns
    -------
    torch.nn.Module
        The deserialized model in evaluation mode.

    Raises
    ------
    FileNotFoundError
        If the checkpoint file does not exist.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint_path = f'{model_name}_model.pth'
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    model = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.eval()  # disable dropout / use running batch-norm statistics
    return model
|
14 |
|
15 |
+
|
16 |
def predict(model, sequence, *, confidence_scale=0.85):
    """Classify a protein sequence and return its label and confidence.

    Parameters
    ----------
    model : torch.nn.Module
        A sequence-classification model whose output exposes ``.logits``.
    sequence : str
        Single amino-acid sequence to classify (batch size 1 is assumed;
        ``.item()`` below would raise on a multi-element batch).
    confidence_scale : float, keyword-only
        Damping factor applied to the top softmax probability.
        Defaults to 0.85 to preserve the original behavior.
        NOTE(review): this constant is unexplained — confirm whether the
        calibration is intentional.

    Returns
    -------
    tuple[int, float]
        ``(predicted_label, confidence)``.
    """
    # NOTE(review): re-downloading the tokenizer on every call is wasteful;
    # kept here to preserve the existing interface — consider hoisting it
    # to module level or caching it.
    tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip building the autograd graph (saves memory/time,
    # identical output values).
    with torch.no_grad():
        output = model(**tokenized_input)
    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)
    # Global max over the (1, num_classes) tensor == probability of the
    # predicted class for a single-sequence batch.
    confidence = probabilities.max().item() * confidence_scale
    return predicted_label.item(), confidence