# angusfung · "Update app.py" · commit 6c3a56c (verified)
"""
Kickstarter Success Prediction API
This module serves as the main FastAPI application for the Kickstarter Success Prediction service.
It provides endpoints for predicting the success probability of Kickstarter campaigns and
includes the Longformer embedding in the response for further analysis.
Author: Angus Fung
Date: April 2025
"""
import os
import json
import torch
import numpy as np
import logging
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from src.model import KickstarterModel
from src.explainer import KickstarterExplainer
from src.ProcessOneSingleCampaign import CampaignProcessor
# Configure logging: INFO and above, timestamped records written through a
# single console StreamHandler on the root logger. (basicConfig is a no-op if
# the root logger already has handlers.)
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Best-effort: allow numpy.core.multiarray.scalar through torch's weights-only
# unpickler. Any failure is logged as a warning and otherwise ignored.
# Note: lifespan() loads the checkpoint with weights_only=False, so this
# registration only matters for weights-only loads.
try:
    import numpy.core.multiarray
    _numpy_safe_globals = [numpy.core.multiarray.scalar]
    torch.serialization.add_safe_globals(_numpy_safe_globals)
except Exception as e:
    logger.warning(f"Failed to add safe globals: {str(e)}")
else:
    logger.info("Added numpy.core.multiarray.scalar to safe globals")
# Constants
# Numerical inputs to the model. predict() builds the feature vector in
# exactly this order, so the order is significant.
NUMERICAL_FIELDS = [
    'description_length',
    'funding_goal',
    'image_count',
    'video_count',
    'campaign_duration',
    'previous_projects_count',
    'previous_success_rate',
    'previous_pledged',
    'previous_funding_goal',
]

# Embedding inputs passed to the model, keyed by these names.
EMBEDDING_NAMES = [
    'description_embedding',
    'blurb_embedding',
    'risk_embedding',
    'subcategory_embedding',
    'category_embedding',
    'country_embedding',
]

# Process-wide singletons; populated by lifespan() at application startup.
model = explainer = processor = device = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifecycle manager for the FastAPI application.

    On startup this:
    - selects the compute device (CUDA if available, otherwise CPU),
    - creates a writable cache directory under /tmp and points the Hugging Face
      cache environment variables at it,
    - builds the module-level ``processor``, ``model`` and ``explainer``.

    The model is always moved to ``device`` and switched to evaluation mode.
    This holds even when the checkpoint is missing or fails to load, so
    inference never runs in training mode or on a device that differs from the
    one handed to the explainer. On shutdown, the module-level references are
    released.

    Args:
        app: The FastAPI application instance (required by FastAPI; unused here).

    Yields:
        None: Control is yielded back to the application while it's running.
    """
    global model, explainer, processor, device
    logger.info("Starting application initialization...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Create cache directories in /tmp, which is writable.
    cache_dir = "/tmp/model_cache"
    os.makedirs(cache_dir, exist_ok=True)
    logger.info(f"Created cache directory at {cache_dir}")
    # NOTE(review): Hugging Face libraries generally read these variables when
    # first imported; if the src.* modules imported transformers at module load,
    # setting them here may come too late to take effect — confirm.
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir

    logger.info("Initializing CampaignProcessor...")
    processor = CampaignProcessor(data=[], lazy_load=True)

    model_path = "best_model.pth"
    hidden_dim = 256
    logger.info(f"Initializing KickstarterModel with hidden_dim={hidden_dim}...")
    model = KickstarterModel(hidden_dim=hidden_dim)

    if os.path.exists(model_path):
        logger.info(f"Loading model weights from {model_path}...")
        try:
            # weights_only=False lets torch.load unpickle arbitrary objects. This
            # is only acceptable because best_model.pth is a trusted file shipped
            # with the app — never point this at user-supplied files.
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
            # Accept both {'model_state_dict': ...} training checkpoints and bare
            # state dicts.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
            logger.info("Model loaded successfully!")
        except Exception as e:
            # logger.exception keeps the traceback for diagnosing bad checkpoints.
            logger.exception(f"Error loading model weights: {str(e)}")
            logger.info("Continuing with uninitialized model weights.")
    else:
        logger.warning(f"Model file not found: {model_path}")
        logger.info("Continuing with uninitialized model weights.")

    # Done unconditionally so that a missing or unloadable checkpoint still
    # leaves the model on `device` and in evaluation mode (no dropout).
    model.to(device)
    model.eval()

    logger.info("Initializing KickstarterExplainer...")
    explainer = KickstarterExplainer(model, device)
    logger.info("Application initialization completed successfully!")

    yield

    # Clean up resources on shutdown: drop references so the model (and any
    # cached GPU memory) can be reclaimed.
    logger.info("Cleaning up resources...")
    model = explainer = processor = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
app = FastAPI(
    title="Kickstarter Success Prediction API",
    description="API for predicting the success of Kickstarter campaigns",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware.
# Any origin may call the API. None of the endpoints in this file read cookies
# or auth headers, so credentials are disabled. FastAPI/Starlette document
# that a wildcard origin, method or header list must not be combined with
# allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """
    Describe the service and how to use it.

    Returns:
        dict: The API name and a hint to POST campaign data to /predict.
    """
    info = dict(
        message="Kickstarter Success Prediction API",
        description="Send a POST request to /predict with campaign data to get a prediction",
    )
    return info
# Widths of the zero vectors substituted for embeddings missing from the
# processed data. Any other embedding name falls back to 15.
_FALLBACK_EMBEDDING_DIMS = {
    'description_embedding': 768,
    'blurb_embedding': 384,
    'risk_embedding': 384,
    'subcategory_embedding': 100,
    'country_embedding': 100,
    'category_embedding': 15,
}


@app.post("/predict")
async def predict(request: Request):
    """
    Prediction endpoint for Kickstarter campaign success.

    The request body must be a JSON object of raw campaign fields. It is passed
    through ``preprocess_raw_data`` and then the explainer. The response contains:

    - ``success_probability``: the model output as a float
    - ``predicted_outcome``: "Success" if the probability is >= 0.52, else "Failure"
    - ``shap_values``: per-feature SHAP values, ordered by absolute magnitude
    - ``longformer_embedding``: the processed description embedding, when present

    Args:
        request: FastAPI request object containing campaign data as JSON.

    Returns:
        JSONResponse: Prediction results and explanations.

    Raises:
        HTTPException: 400 if the body is not valid JSON or not a JSON object;
            503 if the explainer has not been initialized; 500 for any other
            error during prediction.
    """
    try:
        logger.info("Received prediction request")

        # Client mistakes get a 400 instead of the generic 500 below.
        try:
            campaign_data = await request.json()
        except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
            raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}") from e
        if not isinstance(campaign_data, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        if explainer is None:
            raise HTTPException(status_code=503, detail="Model is not initialized")

        logger.info(f"Campaign data received: {json.dumps(campaign_data)[:100]}...")

        logger.info("Processing campaign data...")
        processed_data = preprocess_raw_data(campaign_data)
        logger.info("Campaign data processed successfully")

        # Keep the raw description embedding to echo back in the response.
        raw_longformer_embedding = processed_data.get('description_embedding')

        logger.info("Preparing inputs for model...")
        processed_inputs = {}
        for embedding_name in EMBEDDING_NAMES:
            if embedding_name in processed_data:
                processed_inputs[embedding_name] = torch.tensor(
                    processed_data[embedding_name], dtype=torch.float32
                ).unsqueeze(0)
            else:
                dim = _FALLBACK_EMBEDDING_DIMS.get(embedding_name, 15)
                processed_inputs[embedding_name] = torch.zeros((1, dim), dtype=torch.float32)
                logger.warning(f"Using zero tensor for missing embedding: {embedding_name}")

        # Feature order follows NUMERICAL_FIELDS; missing fields default to 0.
        numerical_features = [processed_data.get(field, 0) for field in NUMERICAL_FIELDS]
        processed_inputs['numerical_features'] = torch.tensor([numerical_features], dtype=torch.float32)

        logger.info("Running prediction and generating explanations...")
        prediction, shap_values = explainer.explain_prediction(processed_inputs)
        probability = float(prediction)
        logger.info(f"Prediction completed: {probability:.4f}")

        # Most influential features first.
        sorted_shap = {
            name: float(value)
            for name, value in sorted(shap_values.items(), key=lambda item: abs(item[1]), reverse=True)
        }

        result = {
            "success_probability": probability,
            # NOTE(review): decision threshold 0.52 is hard-coded; presumably tuned offline.
            "predicted_outcome": "Success" if probability >= 0.52 else "Failure",
            "shap_values": sorted_shap,
        }

        if raw_longformer_embedding is not None:
            # numpy arrays / tensors are not JSON-serializable; convert to lists.
            if hasattr(raw_longformer_embedding, "tolist"):
                raw_longformer_embedding = raw_longformer_embedding.tolist()
            result["longformer_embedding"] = raw_longformer_embedding

        logger.info("Returning prediction results")
        return JSONResponse(content=result)
    except HTTPException:
        # Already carries the intended status code.
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
def _coerce_numeric(field, value):
    """
    Convert a client-supplied numeric field to a finite ``float``.

    Args:
        field (str): Field name, used only in the error message.
        value: Raw value from the request (a number or a numeric string).

    Returns:
        float: The converted value.

    Raises:
        ValueError: If ``value`` is not a number or is NaN/infinite.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Field '{field}' must be numeric, got {value!r}") from err
    if not math.isfinite(number):
        raise ValueError(f"Field '{field}' must be a finite number, got {value!r}")
    return number


def preprocess_raw_data(campaign_data):
    """
    Preprocess raw campaign data using CampaignProcessor.

    The heavy lifting is done by ``processor.process_campaign``. According to
    this module's design, that step covers text embeddings (description, blurb,
    risks), logarithmic transformation of monetary values, country name to ISO
    alpha-2 conversion, category/country encoding and numerical feature
    extraction.

    Afterwards, client-supplied values override the processor's output:
    - Non-transformed fields (counts, lengths, durations, success rate) are
      coerced to float. Absent or null values keep the processor's value.
    - Log-transformed fields (funding goal, previous pledged/goal) are only
      overridden, with the raw value, when ``bypass_transformation`` is truthy.

    If both ``previous_success_rate`` and ``previous_projects_count`` are given
    and usable, ``previous_successful_projects`` is derived from them before
    processing.

    Args:
        campaign_data (dict): Raw campaign data with text and numerical features.

    Returns:
        dict: Processed data with embeddings and numerical features.

    Raises:
        Exception: If preprocessing fails, including a non-numeric override
            value. The original error is chained as ``__cause__``.
    """
    try:
        logger.info("Processing campaign with CampaignProcessor...")

        country_name = campaign_data.get('raw_country')
        if country_name:
            logger.info(f"Found country in input data: '{country_name}' (will be converted to ISO alpha-2 code)")

        # Deep copy so the caller's payload is never mutated.
        import copy
        prepared_data = copy.deepcopy(campaign_data)

        previous_fields = (
            'previous_projects_count', 'previous_success_rate',
            'previous_pledged', 'previous_funding_goal',
        )
        for field in previous_fields:
            logger.info(f"Input {field}: {prepared_data.get(field, 'N/A')}")

        # Derive the number of previous successful projects from rate x count.
        # Null or non-numeric inputs skip the derivation instead of failing the
        # whole request.
        rate = campaign_data.get('previous_success_rate')
        count = campaign_data.get('previous_projects_count')
        if rate is not None and count is not None:
            try:
                success_rate = _coerce_numeric('previous_success_rate', rate)
                projects_count = int(_coerce_numeric('previous_projects_count', count))
            except ValueError as err:
                logger.warning(f"Skipping previous_successful_projects derivation: {err}")
            else:
                if projects_count > 0:
                    prepared_data['previous_successful_projects'] = round(success_rate * projects_count)
                    logger.info(f"Calculated previous_successful_projects: {prepared_data['previous_successful_projects']} " +
                                f"from success rate: {success_rate} and count: {projects_count}")

        processed_data = processor.process_campaign(prepared_data, idx=0)

        # SELECTIVE OVERRIDE: fields that do not undergo logarithmic transformation.
        non_transformed_fields = [
            'description_length', 'image_count', 'video_count',
            'campaign_duration', 'previous_projects_count', 'previous_success_rate'
        ]
        # Fields that SHOULD undergo logarithmic transformation.
        transformed_fields = [
            'funding_goal', 'previous_funding_goal', 'previous_pledged'
        ]

        for field in non_transformed_fields:
            if campaign_data.get(field) is None:
                continue  # absent or explicit null: keep the processor's value
            # Validate here so bad input fails with a clear message rather than
            # an opaque tensor-construction error later in predict().
            processed_data[field] = _coerce_numeric(field, campaign_data[field])
            logger.info(f"Using provided value for {field}: {campaign_data[field]}")

        bypass_transformation = bool(campaign_data.get('bypass_transformation', False))
        for field in transformed_fields:
            if campaign_data.get(field) is None:
                continue
            if bypass_transformation:
                processed_data[field] = _coerce_numeric(field, campaign_data[field])
                logger.warning(
                    f"Bypassing logarithmic transformation for {field} as requested. "
                    "This may affect model performance."
                )
            else:
                logger.info(f"Using logarithmically transformed {field} value for better model performance.")

        for field in previous_fields:
            logger.info(f"Final {field}: {processed_data.get(field, 'N/A')}")
        logger.info("Preprocessing complete with numerical transformations applied")
        return processed_data
    except Exception as e:
        logger.error(f"Error preprocessing raw data: {str(e)}", exc_info=True)
        raise Exception(f"Error preprocessing raw data: {str(e)}") from e
@app.get("/debug")
async def debug():
    """
    Debug endpoint for checking API status and component health.

    Reports:
    - whether the module-level singletons are initialized, and the device;
    - the cache environment variables;
    - existence of the cache directory and model file, and /tmp writability;
    - disk usage of /tmp;
    - connectivity to huggingface.co;
    - whether the Longformer tokenizer can be loaded.

    The connectivity and tokenizer checks are blocking and can take seconds
    (network request with a 5 s timeout, possible model download). They run in
    FastAPI's threadpool so they do not stall other requests on the event loop.

    NOTE(review): this exposes environment details to any caller (CORS allows
    all origins); consider restricting it in production.

    Returns:
        JSONResponse: Comprehensive diagnostic information.
    """
    from fastapi.concurrency import run_in_threadpool

    def check_internet():
        # Blocking HTTP request; executed in a worker thread.
        try:
            import requests
            response = requests.get("https://huggingface.co", timeout=5)
        except Exception as e:
            return {"status": "error", "message": f"Error connecting to internet: {str(e)}"}
        ok = response.status_code == 200
        return {
            "status": "connected" if ok else "error",
            "status_code": response.status_code,
            # Only claim success when the request actually succeeded.
            "message": (
                "Successfully connected to huggingface.co"
                if ok
                else f"huggingface.co responded with HTTP {response.status_code}"
            ),
        }

    def check_tokenizer():
        # May download the tokenizer into the cache; executed in a worker thread.
        try:
            from transformers import AutoTokenizer
            cache_dir = "/tmp/model_cache"
            os.makedirs(cache_dir, exist_ok=True)
            test_model_name = "allenai/longformer-base-4096"
            AutoTokenizer.from_pretrained(test_model_name, cache_dir=cache_dir)
            return {"status": "success", "message": f"Successfully loaded {test_model_name} tokenizer"}
        except Exception as e:
            return {"status": "error", "message": f"Error loading tokenizer: {str(e)}"}

    def check_disk_space():
        try:
            import shutil
            total, used, free = shutil.disk_usage("/tmp")
        except Exception as e:
            return {"status": "error", "message": f"Error checking disk space: {str(e)}"}
        return {
            "status": "ok",
            "total_gb": total / (1024**3),
            "used_gb": used / (1024**3),
            "free_gb": free / (1024**3),
            "percent_used": (used / total) * 100 if total else 0.0,
        }

    internet_check = await run_in_threadpool(check_internet)
    tokenizer_check = await run_in_threadpool(check_tokenizer)
    disk_space = check_disk_space()

    # Reading module-level globals needs no `global` declaration.
    debug_info = {
        "api_status": "running",
        "device": str(device),
        "model_loaded": model is not None,
        "explainer_loaded": explainer is not None,
        "processor_loaded": processor is not None,
        "cuda_available": torch.cuda.is_available(),
        "environment_variables": {
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE", "Not set"),
            "HF_HOME": os.environ.get("HF_HOME", "Not set"),
        },
        "model_cache_exists": os.path.exists("/tmp/model_cache"),
        "model_file_exists": os.path.exists("best_model.pth"),
        "tmp_directory_writable": os.access("/tmp", os.W_OK),
        "internet_connectivity": internet_check,
        "tokenizer_test": tokenizer_check,
        "disk_space": disk_space,
    }
    return JSONResponse(content=debug_info)
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all network
    # interfaces, port 7860.
    import uvicorn

    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)