MultivexAI commited on
Commit
d0307a5
·
verified ·
1 Parent(s): 8fabc2d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import os
5
+
6
+ # Define the model architecture
7
+ class AddModel(nn.Module):
8
+ def __init__(self):
9
+ super(AddModel, self).__init__()
10
+ self.fc1 = nn.Linear(2, 32)
11
+ self.relu1 = nn.ReLU()
12
+ self.fc2 = nn.Linear(32, 64)
13
+ self.relu2 = nn.ReLU()
14
+ self.fc3 = nn.Linear(64, 1)
15
+
16
+ def forward(self, x):
17
+ x = self.relu1(self.fc1(x))
18
+ x = self.relu2(self.fc2(x))
19
+ x = self.fc3(x)
20
+ return x
21
+
22
+ # Load the model from a specified path
23
+ def load_model(model_path):
24
+ model = AddModel()
25
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
26
+ model.eval() # Set the model to evaluation mode
27
+ return model
28
+
29
+ # Function to make predictions
30
+ def predict_sum(model, x1, x2):
31
+ with torch.no_grad():
32
+ input_tensor = torch.tensor([[x1, x2]], dtype=torch.float32)
33
+ prediction = model(input_tensor)
34
+ return prediction.item()
35
+
36
+ # Streamlit app
37
+ def main():
38
+ st.title("Sum Predictor using Neural Network")
39
+
40
+ # Specify the path to your model
41
+ model_path = "./models/best_model.pth" # Update with your model path
42
+ if os.path.exists(model_path):
43
+ model = load_model(model_path)
44
+ st.success("Model loaded successfully.")
45
+
46
+ # User input for prediction
47
+ x1 = st.number_input("Enter the first number:", value=0.0)
48
+ x2 = st.number_input("Enter the second number:", value=0.0)
49
+
50
+ if st.button("Predict"):
51
+ predicted_sum = predict_sum(model, x1, x2)
52
+ st.write(f"The predicted sum of {x1} and {x2} is: {predicted_sum:.2f}")
53
+ else:
54
+ st.error("Model file not found. Please upload the model.")
55
+
56
+ if __name__ == "__main__":
57
+ main()