# Source: MultivexAI, app.py (commit 42d8fb3 "Update app.py", verified; 1.62 kB).
import streamlit as st
import torch
import torch.nn as nn
import os
class AddModel(nn.Module):
    """Small fully connected network mapping 2 inputs to 1 output (2 -> 32 -> 64 -> 1).

    Each hidden ``nn.Linear`` layer is followed by a ReLU; the final layer is
    linear with no activation. The attribute names ``fc1``/``fc2``/``fc3``
    define the state-dict keys, so they must match any checkpoint passed to
    ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in this exact order so that parameter
        # ordering and state-dict keys stay stable.
        self.fc1 = nn.Linear(2, 32)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(32, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        """Apply the layers in sequence; ``x`` must have a last dimension of size 2."""
        for layer in (self.fc1, self.relu1, self.fc2, self.relu2, self.fc3):
            x = layer(x)
        return x
def load_model(model_path):
    """Create an ``AddModel`` and fill it with the weights stored at ``model_path``.

    Args:
        model_path: Path to a checkpoint whose contents are an ``AddModel``
            state dict (it is passed to ``load_state_dict``).

    Returns:
        The populated model, switched to evaluation mode.
    """
    network = AddModel()
    # map_location remaps every stored tensor onto the CPU, so a checkpoint
    # saved from a GPU still loads on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; if the checkpoint could ever
    # come from an untrusted source, consider weights_only=True (available in
    # newer PyTorch releases) — confirm the deployed torch version first.
    weights = torch.load(model_path, map_location=torch.device("cpu"))
    network.load_state_dict(weights)
    # Module.eval() returns the module itself, so it can be returned directly.
    return network.eval()
def predict_sum(model, x1, x2):
    """Evaluate ``model`` on the pair ``(x1, x2)`` and return its scalar output.

    The two operands are packed into a float32 tensor of shape (1, 2), i.e. a
    batch holding one sample, and the model is run with autograd disabled.

    Args:
        model: Callable taking a (1, 2) float32 tensor and returning a
            single-element tensor (for example an ``AddModel`` instance).
        x1: First operand.
        x2: Second operand.

    Returns:
        The model output converted to a Python number via ``.item()``.
    """
    batch = torch.tensor([x1, x2], dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    return output.item()
def main():
    """Render the Streamlit "Sum Predictor" page.

    Loads the checkpoint at ``./models/MA1T.pth`` (at most once per browser
    session), then lets the user enter two numbers and displays the model's
    predicted sum when "Predict" is clicked. If the checkpoint file is missing,
    only an error message is shown.
    """
    st.title("Sum Predictor using Neural Network")
    model_path = "./models/MA1T.pth"

    # Guard clause: without a checkpoint there is nothing else to render.
    if not os.path.exists(model_path):
        st.error("Model file not found. Please upload the model.")
        return

    # Streamlit re-executes this script on every widget interaction (including
    # each "Predict" click). Keep the loaded model in session state so the
    # checkpoint is read from disk once per session instead of on every rerun.
    cache_key = f"_loaded_model::{model_path}"
    if cache_key not in st.session_state:
        st.session_state[cache_key] = load_model(model_path)
    model = st.session_state[cache_key]
    st.success("Model loaded successfully.")

    x1 = st.number_input("Enter the first number:", value=0.0)
    x2 = st.number_input("Enter the second number:", value=0.0)
    if st.button("Predict"):
        predicted_sum = predict_sum(model, x1, x2)
        st.write(f"The predicted sum of {x1} and {x2} is: {predicted_sum:.2f}")


if __name__ == "__main__":
    main()