MultivexAI commited on
Commit
461b1c6
·
verified ·
1 Parent(s): d7739bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import torch.nn as nn
4
  import os
5
 
6
- # the model architecture
7
  class AddModel(nn.Module):
8
  def __init__(self):
9
  super(AddModel, self).__init__()
@@ -19,24 +19,23 @@ class AddModel(nn.Module):
19
  x = self.fc3(x)
20
  return x
21
 
22
- # load the model from a specified path
23
  def load_model(model_path):
24
  model = AddModel()
25
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
26
- model.eval() # Set the model to evaluation mode
27
  return model
28
 
29
- #predictions
30
  def predict_sum(model, x1, x2):
31
  with torch.no_grad():
32
  input_tensor = torch.tensor([[x1, x2]], dtype=torch.float32)
33
  prediction = model(input_tensor)
34
  return prediction.item()
35
 
 
36
  def main():
37
  st.title("Sum Predictor using Neural Network")
38
-
39
- model_path = "./models/MA1T.pth"
40
  if os.path.exists(model_path):
41
  model = load_model(model_path)
42
  st.success("Model loaded successfully.")
 
3
  import torch.nn as nn
4
  import os
5
 
6
+ # architecture
7
  class AddModel(nn.Module):
8
  def __init__(self):
9
  super(AddModel, self).__init__()
 
19
  x = self.fc3(x)
20
  return x
21
 
 
22
  def load_model(model_path):
23
  model = AddModel()
24
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
25
+ model.eval() # evaluation mode
26
  return model
27
 
 
28
  def predict_sum(model, x1, x2):
29
  with torch.no_grad():
30
  input_tensor = torch.tensor([[x1, x2]], dtype=torch.float32)
31
  prediction = model(input_tensor)
32
  return prediction.item()
33
 
34
+ # Streamlit app
35
  def main():
36
  st.title("Sum Predictor using Neural Network")
37
+
38
+ model_path = "MA1T.pth" # Update with your model path if necessary
39
  if os.path.exists(model_path):
40
  model = load_model(model_path)
41
  st.success("Model loaded successfully.")