# stock-predictor / app.py
# Author: hariharan220
# Commit 3baf523 (verified): "Create app.py"
import logging
from typing import List

import joblib
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
# The FastAPI application object; every route in this module is registered on it.
app = FastAPI(
    title="Universal Stock Prediction API",
)
class StockPredictionInput(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # Ticker symbol as sent by the client; the handler upper-cases it.
    stock_symbol: str
    # Price window as a list of rows, each a list of numbers. The handler
    # rejects anything whose shape is not (SEQ_LENGTH, NUM_FEATURES).
    prices: List[List[float]]
# Required geometry of `StockPredictionInput.prices`: SEQ_LENGTH rows of
# NUM_FEATURES values each. /predict returns 400 for any other shape.
SEQ_LENGTH, NUM_FEATURES = 150, 10
# Load the trained Keras model once, at import time. If this fails the module
# import itself fails, so the API never starts without a model.
try:
    model = tf.keras.models.load_model("universal_stock_model_with_fundamentals.h5")
except Exception as e:
    logging.error(f"❌ Failed to load model: {e}")
    # Chain the original error so the real cause (missing file, incompatible
    # format, ...) stays visible in the traceback.
    raise RuntimeError("Model file not found or corrupted!") from e
else:
    logging.info("✅ Model loaded successfully.")
@app.get("/")
def home():
    """Return a fixed welcome message for the root URL."""
    greeting = "Welcome to Stock Prediction API"
    return dict(message=greeting)
# Only the first _N_SCALED_FEATURES columns of each row go through the
# per-stock scaler; any remaining column(s) are passed to the model unchanged.
# The same count sizes the dummy row for the inverse transform.
# NOTE(review): presumably the scalers were fit on exactly these 9 columns --
# confirm against the training script.
_N_SCALED_FEATURES = 9


def _parse_prices(prices):
    """Convert the request's price rows into a validated float32 array.

    Args:
        prices: Nested list of floats taken from the request body.

    Returns:
        ndarray of shape ``(SEQ_LENGTH, NUM_FEATURES)`` containing only
        finite values.

    Raises:
        HTTPException: 400 if rows have different lengths, the shape is
            wrong, or any value is NaN/infinite (including float32 overflow).
    """
    try:
        prices_array = np.array(prices, dtype=np.float32)
    except ValueError:
        # Ragged rows cannot form a 2-D numeric array. That is a client error,
        # so report 400 instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail=(
                f"Expected input shape ({SEQ_LENGTH}, {NUM_FEATURES}) "
                "but rows have inconsistent lengths"
            ),
        ) from None
    if prices_array.shape != (SEQ_LENGTH, NUM_FEATURES):
        raise HTTPException(status_code=400, detail=f"Expected input shape ({SEQ_LENGTH}, {NUM_FEATURES}) but got {prices_array.shape}")
    if not np.isfinite(prices_array).all():
        raise HTTPException(
            status_code=400,
            detail="Input prices must be finite numbers (no NaN or infinity)",
        )
    return prices_array


def _load_scaler(stock):
    """Return the scaler stored in ``scaler_{stock}.pkl``, or None.

    None means "use an identity transform", the same best-effort fallback as
    before for any symbol without a usable scaler file.
    """
    # `stock` comes straight from the request body and is spliced into a
    # file path. Never let it name a file outside the working directory.
    if "/" in stock or "\\" in stock:
        logging.warning(f"⚠️ Invalid stock symbol {stock!r}; using identity transform.")
        return None
    # SECURITY NOTE: joblib.load unpickles the file, which can execute
    # arbitrary code. Only deploy scaler files from a trusted source.
    try:
        scaler = joblib.load(f"scaler_{stock}.pkl")
    except FileNotFoundError:
        logging.warning(f"⚠️ Scaler not found for {stock}. Using identity transform.")
        return None
    except Exception:
        # The file exists but could not be loaded. Keep the best-effort
        # fallback, but log the real error instead of calling it "not found".
        logging.exception(f"⚠️ Failed to load scaler for {stock}. Using identity transform.")
        return None
    logging.info(f"✅ Scaler loaded for stock: {stock}")
    return scaler


def _build_model_input(prices_array, scaler):
    """Scale the leading feature columns and add a batch axis.

    Returns an array of shape ``(1, SEQ_LENGTH, NUM_FEATURES)``.
    """
    to_scale = prices_array[:, :_N_SCALED_FEATURES]
    passthrough = prices_array[:, _N_SCALED_FEATURES:]
    if scaler is not None:
        to_scale = scaler.transform(to_scale)
    return np.expand_dims(np.concatenate([to_scale, passthrough], axis=1), axis=0)


def _inverse_scale_price(predicted_price, scaler):
    """Map a model output from scaled space back to price units.

    The prediction is written into column 0 of an otherwise-zero row, and only
    column 0 of the inverse-transformed row is read back.
    NOTE(review): this assumes the target price is the scaler's first feature
    and that the scaler inverts each column independently -- confirm.

    Returns the unscaled price as a float. If there is no scaler, or the
    inverse transform fails (logged), the input value is returned unchanged.
    """
    if scaler is None:
        return predicted_price
    dummy_row = np.zeros((1, _N_SCALED_FEATURES))
    dummy_row[0, 0] = predicted_price
    try:
        return float(scaler.inverse_transform(dummy_row)[0, 0])
    except Exception as e:
        # Best-effort: fall back to the scaled value rather than fail the request.
        logging.warning(f"⚠️ Error in inverse transformation: {e}")
        return predicted_price


def _predict_sync(stock, prices):
    """Blocking part of /predict: validate, scale, run the model, unscale.

    Raises:
        HTTPException: 400 for malformed input (see ``_parse_prices``), or
            500 if the final prediction is NaN/infinite. A non-finite value
            cannot be rendered as a JSON response.
    """
    prices_array = _parse_prices(prices)
    scaler = _load_scaler(stock)
    # verbose=0 suppresses Keras' per-call progress bar in the server output.
    prediction = model.predict(_build_model_input(prices_array, scaler), verbose=0)
    predicted_price = _inverse_scale_price(float(prediction[0][0]), scaler)
    if not np.isfinite(predicted_price):
        raise HTTPException(status_code=500, detail="Model produced a non-finite prediction")
    return {"stock": stock, "predicted_price": round(predicted_price, 2)}


@app.post("/predict")
async def predict_stock(data: StockPredictionInput):
    """Predict a price for ``data.stock_symbol`` from its price window.

    Returns:
        ``{"stock": <upper-cased symbol>, "predicted_price": <float, 2 dp>}``

    Raises:
        HTTPException: 400 for malformed ``prices``; 500 if the model output
            is not finite.
    """
    stock = data.stock_symbol.upper()
    logging.info(f"🔍 Processing request for stock: {stock}")
    # joblib.load and model.predict block. Run them in FastAPI's worker thread
    # pool so one request does not stall the event loop for all the others.
    # NOTE(review): this allows concurrent model.predict calls on the shared
    # model; confirm this is acceptable for the deployed TF/Keras version.
    return await run_in_threadpool(_predict_sync, stock, data.prices)