# Spaces:
# Sleeping
# Sleeping
# File size: 2,003 Bytes
# d0307a5 533fab5 25018ea d0307a5 533fab5 25018ea 461b1c6 d0307a5 461b1c6 d0307a5 461b1c6 d0307a5 461b1c6 d0307a5 c66cdac d0307a5 c66cdac d0307a5
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
import streamlit as st
import os
def safe_import(module_name):
    """Import a module by name, returning ``None`` when it is unavailable.

    Args:
        module_name: Absolute module name. Dotted names such as
            ``"xml.dom"`` resolve to the named submodule itself rather
            than to its top-level package.

    Returns:
        The imported module object, or ``None`` if the import raises
        ``ImportError`` (which includes ``ModuleNotFoundError``).
    """
    # Function-scope import keeps this helper self-contained.
    import importlib

    try:
        # import_module returns the module that was asked for; the builtin
        # __import__ would hand back the top-level package for dotted names.
        return importlib.import_module(module_name)
    except ImportError:
        return None
# Resolve torch through safe_import so that a missing dependency produces an
# on-page error instead of an ImportError traceback.
torch = safe_import('torch')
if torch is None:
    # Tell the user what is going on, then end this script run: nothing
    # below can work without torch.
    st.error("Torch is not installed yet. Please wait a moment for the dependencies to install.")
    st.stop()
import torch.nn as nn
# architecture
class AddModel(nn.Module):
    """Small fully connected network mapping two input features to one output.

    Layer stack: Linear(2 -> 32), ReLU, Linear(32 -> 64), ReLU,
    Linear(64 -> 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict key prefixes (fc1.weight, ...),
        # so they must stay in sync with any saved checkpoint.
        self.fc1 = nn.Linear(2, 32)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(32, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        """Run ``x`` through the three linear layers with ReLU in between.

        Args:
            x: Tensor whose last dimension has size 2 (the input width of
                ``fc1``).

        Returns:
            Tensor with the same leading dimensions and a last dimension
            of size 1.
        """
        hidden = self.fc1(x)
        hidden = self.relu1(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu2(hidden)
        return self.fc3(hidden)
def load_model(model_path):
    """Build an ``AddModel`` and populate it from a saved state dict.

    Args:
        model_path: Path to a file readable by ``torch.load`` that holds a
            state dict compatible with ``AddModel``.

    Returns:
        The ``AddModel`` instance, switched to evaluation mode.
    """
    net = AddModel()
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    net.load_state_dict(weights)
    net.eval()  # evaluation mode
    return net
def predict_sum(model, x1, x2):
    """Feed the pair ``(x1, x2)`` to ``model`` and return its scalar output.

    Args:
        model: Callable that accepts a float32 tensor of shape ``(1, 2)``
            and returns a single-element tensor.
        x1: First number.
        x2: Second number.

    Returns:
        The model output converted to a Python number via ``.item()``.
    """
    # One row holding both operands.
    pair = torch.tensor([[x1, x2]], dtype=torch.float32)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        return model(pair).item()
# Streamlit app
def main():
    """Render the sum-predictor page.

    Shows an error when the model file is missing. Otherwise loads the
    model, collects two numbers and, once "Predict" is pressed, writes the
    model's prediction next to the exact sum.
    """
    st.title("Sum Predictor using Neural Network")

    model_path = "MA1T.pth"  # Update with your model path if necessary

    # Guard clause: without the checkpoint there is nothing to show.
    if not os.path.exists(model_path):
        st.error("Model file not found. Please upload the model.")
        return

    model = load_model(model_path)
    st.success("Model loaded successfully.")

    x1 = st.number_input("Enter the first number:", value=0.0)
    x2 = st.number_input("Enter the second number:", value=0.0)

    # Nothing more to render until the button is pressed.
    if not st.button("Predict"):
        return

    predicted_sum = predict_sum(model, x1, x2)
    correct_sum = x1 + x2  # Exact answer, shown for comparison
    st.write(f"The predicted sum of {x1} and {x2} is: {predicted_sum:.2f}")
    st.write(f"The correct sum of {x1} and {x2} is: {correct_sum:.2f}")
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()