basilboy committed on
Commit
76b0555
·
verified ·
1 Parent(s): 121b388

Update utils.py

Browse files

handling for multiple models

Files changed (1) hide show
  1. utils.py +8 -7
utils.py CHANGED
@@ -6,17 +6,18 @@ def validate_sequence(sequence):
6
  valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
7
  return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
8
 
9
def load_model():
    """Load the serialized solubility classifier from disk and switch it to eval mode.

    Returns:
        torch.nn.Module: the deserialized model, in evaluation mode, on CPU.

    Raises:
        FileNotFoundError: if ``solubility_model.pth`` is not present in the CWD.
    """
    # weights_only=False is required on torch>=2.6 (the default flipped to True)
    # because the checkpoint stores a full pickled nn.Module, not a state_dict.
    # NOTE(review): pickle deserialization executes arbitrary code — only load
    # checkpoints from a trusted source.
    model = torch.load(
        'solubility_model.pth',
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    model.eval()  # disable dropout / use running batch-norm stats for inference
    return model
14
 
 
15
def predict(model, sequence):
    """Classify an amino-acid *sequence* with *model*.

    Returns the index of the most probable class as a Python ``int``.
    """
    # Tokenize with the ESM-2 vocabulary the model was trained against.
    tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    encoded = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    # Forward pass, then softmax over the class dimension and take the argmax.
    logits = model(**encoded).logits
    class_probs = F.softmax(logits, dim=-1)
    best_class = torch.argmax(class_probs, dim=-1)
    return best_class.item()
 
6
  valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
7
  return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
8
 
9
def load_model(model_name):
    """Load the serialized classifier named ``{model_name}_model.pth`` and put it in eval mode.

    Args:
        model_name: prefix of the checkpoint file (e.g. ``"solubility"`` loads
            ``solubility_model.pth``). May include a directory component.

    Returns:
        torch.nn.Module: the deserialized model, in evaluation mode, on CPU.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # weights_only=False is required on torch>=2.6 (the default flipped to True)
    # because the checkpoint stores a full pickled nn.Module, not a state_dict.
    # NOTE(review): pickle deserialization executes arbitrary code — only load
    # checkpoints from trusted sources; do not pass user-controlled names here.
    model = torch.load(
        f'{model_name}_model.pth',
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    model.eval()  # disable dropout / use running batch-norm stats for inference
    return model
14
 
15
+
16
def predict(model, sequence, calibration=0.85):
    """Classify an amino-acid *sequence* with *model*.

    Args:
        model: a sequence-classification model whose output exposes ``.logits``.
        sequence: the raw amino-acid string to classify.
        calibration: multiplier applied to the max softmax probability before it
            is reported as the confidence. Defaults to 0.85, the value that was
            previously hard-coded — presumably an ad-hoc damping of the model's
            over-confident softmax scores; TODO confirm intent with the author.

    Returns:
        tuple[int, float]: ``(predicted_label, confidence)`` where confidence is
        the calibrated max class probability.
    """
    # PERF NOTE(review): the tokenizer is re-instantiated on every call; hoist
    # it to module level (or cache it) if predict() is on a hot path.
    tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    # Inference only — no_grad avoids building an autograd graph (same outputs).
    with torch.no_grad():
        output = model(**tokenized_input)
    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)
    confidence = probabilities.max().item() * calibration
    return predicted_label.item(), confidence