"""FastAPI service that serves price predictions from a pre-trained Keras model.

Each request supplies a window of SEQ_LENGTH rows x NUM_FEATURES columns.
The first NUM_FEATURES - 1 columns are optionally transformed by a
per-symbol scaler loaded from ``scaler_<SYMBOL>.pkl``; the last column is
passed to the model unchanged.
"""

import logging
import re
from typing import List

import joblib
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# FastAPI initialization
app = FastAPI(title="Universal Stock Prediction API")


# Request model structure
class StockPredictionInput(BaseModel):
    """Request body for ``POST /predict``."""

    stock_symbol: str
    # Rows of feature values; the handler enforces shape (SEQ_LENGTH, NUM_FEATURES).
    prices: List[List[float]]


# Constants
SEQ_LENGTH = 150
NUM_FEATURES = 10
# Only the leading columns go through the per-symbol scaler; the last one is passed through.
NUM_SCALED_FEATURES = NUM_FEATURES - 1
MODEL_PATH = "universal_stock_model_with_fundamentals.h5"
# The symbol becomes part of a file name, so restrict it to characters that
# cannot form a path (no "/" or "\"), preventing reads outside the working dir.
_SYMBOL_PATTERN = re.compile(r"[A-Z0-9.^=_-]{1,20}")

# Load the trained model once at import time; the service cannot run without it.
try:
    model = tf.keras.models.load_model(MODEL_PATH)
    logger.info("✅ Model loaded successfully.")
except Exception as e:
    logger.exception("❌ Failed to load model: %s", e)
    raise RuntimeError("Model file not found or corrupted!") from e


@app.get("/")
def home():
    """Return a static welcome message."""
    return {"message": "Welcome to Stock Prediction API"}


def _parse_prices(prices: List[List[float]]) -> np.ndarray:
    """Convert request rows to a float32 array, raising HTTP 400 on malformed input."""
    try:
        prices_array = np.array(prices, dtype=np.float32)
    except ValueError:
        # Ragged rows (differing lengths) cannot form a 2-D float array.
        raise HTTPException(
            status_code=400,
            detail=f"Expected input shape ({SEQ_LENGTH}, {NUM_FEATURES}) but rows have differing lengths",
        ) from None
    if prices_array.shape != (SEQ_LENGTH, NUM_FEATURES):
        raise HTTPException(
            status_code=400,
            detail=f"Expected input shape ({SEQ_LENGTH}, {NUM_FEATURES}) but got {prices_array.shape}",
        )
    # NaN/inf (including float64 values that overflow float32) would poison the prediction.
    if not np.isfinite(prices_array).all():
        raise HTTPException(
            status_code=400,
            detail="Input prices must be finite numbers (no NaN or infinity)",
        )
    return prices_array


def _load_scaler(stock: str):
    """Return the scaler stored in ``scaler_<stock>.pkl``, or None to use the identity transform."""
    try:
        # NOTE(security): joblib.load unpickles the file and can execute arbitrary code.
        # Only trusted scaler files may live in the working directory; the caller
        # validates *stock* so a client cannot choose an arbitrary path.
        scaler = joblib.load(f"scaler_{stock}.pkl")
    except FileNotFoundError:
        logger.warning("⚠️ Scaler not found for %s. Using identity transform.", stock)
        return None
    except Exception:
        # Still best-effort, but keep the traceback: a corrupt or incompatible
        # scaler file is a deployment problem, not a missing-file case.
        logger.exception("⚠️ Failed to load scaler for %s. Using identity transform.", stock)
        return None
    logger.info("✅ Scaler loaded for stock: %s", stock)
    return scaler


def _scale_features(prices_array: np.ndarray, scaler) -> np.ndarray:
    """Scale the leading feature columns with *scaler* (if any) and add a batch axis."""
    features_to_scale = prices_array[:, :NUM_SCALED_FEATURES]
    extra_feature = prices_array[:, NUM_SCALED_FEATURES:]
    if scaler is not None:
        features_to_scale = scaler.transform(features_to_scale)
    prices_scaled = np.concatenate([features_to_scale, extra_feature], axis=1)
    # Model input is a batch of one: (1, SEQ_LENGTH, NUM_FEATURES).
    return np.expand_dims(prices_scaled, axis=0)


def _inverse_scale_target(predicted_price: float, scaler) -> float:
    """Map a model output back through the scaler's first column; best-effort."""
    if scaler is None:
        return predicted_price
    try:
        # NOTE(review): assumes a column-wise scaler (e.g. MinMax/Standard) whose
        # column 0 is the prediction target, so the other columns are placeholders — confirm.
        dummy_input = np.zeros((1, NUM_SCALED_FEATURES))
        dummy_input[0, 0] = predicted_price
        return float(scaler.inverse_transform(dummy_input)[0, 0])
    except Exception as e:
        # Falls back to the still-scaled value, as before; the warning flags it.
        logger.warning("⚠️ Error in inverse transformation: %s", e)
        return predicted_price


@app.post("/predict")
def predict_stock(data: StockPredictionInput):
    """Predict a price for ``data.stock_symbol`` from the supplied feature window.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool: ``joblib.load`` and ``model.predict`` block and would otherwise
    stall the event loop for every concurrent request.

    Returns:
        ``{"stock": <upper-cased symbol>, "predicted_price": <float rounded to 2 dp>}``.

    Raises:
        HTTPException(400): invalid symbol, malformed/wrong-shaped or non-finite prices.
        HTTPException(500): the model produced a non-finite prediction.
    """
    stock = data.stock_symbol.strip().upper()
    logger.info("🔍 Processing request for stock: %s", stock)
    if not _SYMBOL_PATTERN.fullmatch(stock):
        raise HTTPException(status_code=400, detail="Invalid stock symbol")

    # Validate before touching the disk so bad requests fail cheaply.
    prices_array = _parse_prices(data.prices)
    scaler = _load_scaler(stock)
    model_input = _scale_features(prices_array, scaler)

    # verbose=0 stops Keras from printing a progress bar on every request.
    # NOTE(review): concurrent predict calls from threadpool workers — confirm
    # thread-safety for the deployed TF version.
    prediction = model.predict(model_input, verbose=0)
    predicted_price = _inverse_scale_target(float(prediction[0][0]), scaler)

    # A NaN/inf result cannot be serialized to JSON; report it explicitly.
    if not np.isfinite(predicted_price):
        logger.error("Non-finite prediction for %s: %r", stock, predicted_price)
        raise HTTPException(status_code=500, detail="Model produced a non-finite prediction")

    return {"stock": stock, "predicted_price": round(predicted_price, 2)}