|
""" |
|
Kickstarter Success Prediction API |
|
|
|
This module serves as the main FastAPI application for the Kickstarter Success Prediction service. |
|
It provides endpoints for predicting the success probability of Kickstarter campaigns and |
|
includes the Longformer embedding in the response for further analysis. |
|
|
|
Author: Angus Fung |
|
Date: April 2025 |
|
""" |
|
|
|
import os |
|
import json |
|
import torch |
|
import numpy as np |
|
import logging |
|
from pathlib import Path |
|
from contextlib import asynccontextmanager |
|
from fastapi import FastAPI, HTTPException, Request, Response |
|
from fastapi.responses import JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
from src.model import KickstarterModel |
|
from src.explainer import KickstarterExplainer |
|
from src.ProcessOneSingleCampaign import CampaignProcessor |
|
|
|
|
|
# Root logging setup performed once at import time: INFO and above, written by
# a StreamHandler using a timestamp / logger-name / level / message layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Module logger shared by the endpoints and helpers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Best-effort: allow torch's weights_only unpickler to reconstruct NumPy
# scalars stored in checkpoints. Any failure (e.g. a torch build without
# torch.serialization.add_safe_globals) is logged and startup continues.
try:
    try:
        # NumPy 2.x moved the implementation to ``numpy._core`` and deprecates
        # the ``numpy.core`` alias (attribute access emits DeprecationWarning).
        from numpy._core.multiarray import scalar as _numpy_scalar
    except ImportError:
        # Older NumPy releases only provide the legacy module path.
        from numpy.core.multiarray import scalar as _numpy_scalar

    torch.serialization.add_safe_globals([_numpy_scalar])
    logger.info("Added numpy.core.multiarray.scalar to safe globals")
except Exception as e:
    logger.warning(f"Failed to add safe globals: {str(e)}")
|
|
|
|
|
# Numeric inputs fed to the model. The list order defines the column order of
# the numerical feature tensor assembled for each request, so keep it stable.
NUMERICAL_FIELDS = [
    'description_length',
    'funding_goal',
    'image_count',
    'video_count',
    'campaign_duration',
    'previous_projects_count',
    'previous_success_rate',
    'previous_pledged',
    'previous_funding_goal',
]

# Names of the embedding inputs passed to the model/explainer, one tensor each.
EMBEDDING_NAMES = [
    'description_embedding',
    'blurb_embedding',
    'risk_embedding',
    'subcategory_embedding',
    'category_embedding',
    'country_embedding',
]
|
|
|
|
|
# Shared runtime state. ``lifespan`` assigns these at startup (via ``global``);
# until then every slot is None.
model = explainer = processor = device = None
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifecycle manager for the FastAPI application.

    On startup this selects the compute device, prepares the model cache
    directory, and builds the shared ``processor``, ``model`` and ``explainer``
    globals used by the request handlers. Weights are loaded from
    ``best_model.pth`` when present; if the file is missing or loading fails,
    the model keeps its freshly initialized weights and the problem is logged.
    In every case the model is moved to the selected device and put in eval
    mode before the explainer is built.

    Args:
        app: The FastAPI application instance (not used by this function).

    Yields:
        None: Control is yielded back to the application while it's running.
    """
    global model, explainer, processor, device

    logger.info("Starting application initialization...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    cache_dir = "/tmp/model_cache"
    os.makedirs(cache_dir, exist_ok=True)
    logger.info(f"Created cache directory at {cache_dir}")

    # NOTE(review): Hugging Face libraries may read these variables when they
    # are first imported; if the src.* modules import transformers at module
    # load, setting them here could be too late — confirm.
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir

    logger.info("Initializing CampaignProcessor...")
    processor = CampaignProcessor(data=[], lazy_load=True)

    model_path = "best_model.pth"
    hidden_dim = 256

    logger.info(f"Initializing KickstarterModel with hidden_dim={hidden_dim}...")
    model = KickstarterModel(hidden_dim=hidden_dim)

    if os.path.exists(model_path):
        logger.info(f"Loading model weights from {model_path}...")
        try:
            # SECURITY: weights_only=False lets torch.load unpickle arbitrary
            # objects, which can execute code. Only ship trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
            # Accept both {'model_state_dict': ...} training checkpoints and
            # bare state dicts; anything else fails in load_state_dict below.
            state_dict = (
                checkpoint.get('model_state_dict', checkpoint)
                if isinstance(checkpoint, dict)
                else checkpoint
            )
            model.load_state_dict(state_dict)
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model weights: {str(e)}")
            logger.info("Continuing with uninitialized model weights.")
    else:
        logger.warning(f"Model file not found: {model_path}")
        logger.info("Continuing with uninitialized model weights.")

    # Done unconditionally (previously only after a successful load) so the
    # model always sits on the same device handed to the explainer and runs
    # in inference mode (eval() disables training-only layers such as dropout).
    model.to(device)
    model.eval()

    logger.info("Initializing KickstarterExplainer...")
    explainer = KickstarterExplainer(model, device)

    logger.info("Application initialization completed successfully!")

    try:
        yield
    finally:
        # Runs on shutdown, including when the server exits with an error.
        logger.info("Cleaning up resources...")
|
|
|
# ASGI application; ``lifespan`` performs the startup/shutdown work.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    description="API for predicting the success of Kickstarter campaigns",
    title="Kickstarter Success Prediction API",
)

# NOTE(review): wildcard origins combined with allow_credentials=True means
# Starlette may reflect any caller's Origin for credentialed requests. This is
# harmless if the API uses no cookies/auth; otherwise prefer an explicit
# origin list.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
@app.get("/")
async def root():
    """
    Landing endpoint that tells callers what the service is and how to use it.

    Returns:
        dict: ``message`` holds the API name; ``description`` explains that
        predictions are requested with a POST to /predict.
    """
    api_name = "Kickstarter Success Prediction API"
    usage = "Send a POST request to /predict with campaign data to get a prediction"
    return dict(message=api_name, description=usage)
|
|
|
# Probability at or above which a campaign is reported as "Success".
_SUCCESS_THRESHOLD = 0.52

# Widths of the zero tensors substituted when an embedding is missing from the
# processed data. NOTE(review): these presumably mirror KickstarterModel's
# input sizes — keep them in sync with the model definition.
_DEFAULT_EMBEDDING_DIMS = {
    'description_embedding': 768,
    'blurb_embedding': 384,
    'risk_embedding': 384,
    'subcategory_embedding': 100,
    'country_embedding': 100,
    'category_embedding': 15,
}
# Width used for any embedding name not listed above.
_FALLBACK_EMBEDDING_DIM = 15


@app.post("/predict")
async def predict(request: Request):
    """
    Prediction endpoint for Kickstarter campaign success.

    This endpoint processes campaign data and returns:
    - Success probability
    - Predicted outcome (Success/Failure, threshold ``_SUCCESS_THRESHOLD``)
    - SHAP values for feature importance, ordered by absolute magnitude
    - Longformer embedding of the campaign description (when available)

    Args:
        request: FastAPI request object containing campaign data as a JSON object

    Returns:
        JSONResponse: Prediction results and explanations

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON
            object; 500 if preprocessing, inference or explanation fails.
    """
    try:
        logger.info("Received prediction request")

        # Bad input is a client error: report 400 instead of falling into the
        # generic 500 handler below.
        try:
            campaign_data = await request.json()
        except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
            logger.warning(f"Invalid JSON in prediction request: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}") from e
        if not isinstance(campaign_data, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")

        logger.info(f"Campaign data received: {json.dumps(campaign_data)[:100]}...")

        logger.info("Processing campaign data...")
        processed_data = preprocess_raw_data(campaign_data)
        logger.info("Campaign data processed successfully")

        # Keep the description embedding to echo back in the response.
        raw_longformer_embedding = processed_data.get('description_embedding')

        logger.info("Preparing inputs for model...")
        processed_inputs = {}
        for embedding_name in EMBEDDING_NAMES:
            if embedding_name in processed_data:
                # unsqueeze(0) adds a leading batch dimension of size 1.
                processed_inputs[embedding_name] = torch.tensor(
                    processed_data[embedding_name], dtype=torch.float32
                ).unsqueeze(0)
            else:
                dim = _DEFAULT_EMBEDDING_DIMS.get(embedding_name, _FALLBACK_EMBEDDING_DIM)
                processed_inputs[embedding_name] = torch.zeros((1, dim), dtype=torch.float32)
                logger.warning(f"Using zero tensor for missing embedding: {embedding_name}")

        # Missing numeric fields default to 0; column order follows NUMERICAL_FIELDS.
        numerical_features = [processed_data.get(field, 0) for field in NUMERICAL_FIELDS]
        processed_inputs['numerical_features'] = torch.tensor([numerical_features], dtype=torch.float32)

        logger.info("Running prediction and generating explanations...")
        prediction, shap_values = explainer.explain_prediction(processed_inputs)
        probability = float(prediction)
        logger.info(f"Prediction completed: {probability:.4f}")

        # Largest absolute contribution first.
        sorted_shap = sorted(shap_values.items(), key=lambda item: abs(item[1]), reverse=True)

        result = {
            "success_probability": probability,
            "predicted_outcome": "Success" if probability >= _SUCCESS_THRESHOLD else "Failure",
            "shap_values": {k: float(v) for k, v in sorted_shap},
        }

        if raw_longformer_embedding is not None:
            # NumPy arrays and torch tensors are not JSON-serializable; convert
            # them to (nested) Python lists. Plain lists pass through unchanged.
            if hasattr(raw_longformer_embedding, "tolist"):
                raw_longformer_embedding = raw_longformer_embedding.tolist()
            result["longformer_embedding"] = raw_longformer_embedding

        logger.info("Returning prediction results")
        return JSONResponse(content=result)

    except HTTPException:
        # Already carries the intended status code (e.g. 400 above).
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
|
|
|
def preprocess_raw_data(campaign_data):
    """
    Preprocess raw campaign data using CampaignProcessor.

    The heavy lifting is delegated to ``processor.process_campaign``, which,
    per the author's description, handles:
    - Text embeddings generation for description, blurb, and risks
    - Logarithmic transformation of monetary values (funding goals, pledged amounts)
    - Country name standardization (conversion to ISO alpha-2 codes)
    - Category and country encoding
    - Extraction and normalization of numerical features

    Before processing, ``previous_successful_projects`` is derived as
    ``round(previous_success_rate * previous_projects_count)`` when both inputs
    are present and the count is positive. After processing, the caller's raw
    values overwrite the non-transformed numeric fields, and also the monetary
    fields when ``bypass_transformation`` is truthy.

    Args:
        campaign_data (dict): Raw campaign data with text and numerical
            features. It is deep-copied before being handed to the processor,
            so this function does not modify it.

    Returns:
        dict: Processed data with embeddings and numerical features.

    Raises:
        Exception: If preprocessing fails; the underlying error is chained as
            ``__cause__``.
    """
    import copy

    # Numeric fields whose caller-supplied values always replace processed ones.
    non_transformed_fields = [
        'description_length', 'image_count', 'video_count',
        'campaign_duration', 'previous_projects_count', 'previous_success_rate',
    ]
    # Monetary fields the processor transforms; raw values only when bypassed.
    transformed_fields = ['funding_goal', 'previous_funding_goal', 'previous_pledged']
    # Creator-history fields logged before and after processing.
    history_fields = [
        'previous_projects_count', 'previous_success_rate',
        'previous_pledged', 'previous_funding_goal',
    ]

    try:
        logger.info("Processing campaign with CampaignProcessor...")

        country_name = campaign_data.get('raw_country', '')
        if country_name:
            logger.info(f"Found country in input data: '{country_name}' (will be converted to ISO alpha-2 code)")

        # Work on a deep copy so derived fields never leak into the caller's dict.
        prepared_data = copy.deepcopy(campaign_data)

        for field in history_fields:
            logger.info(f"Input {field}: {prepared_data.get(field, 'N/A')}")

        if 'previous_success_rate' in campaign_data and 'previous_projects_count' in campaign_data:
            success_rate = float(campaign_data['previous_success_rate'])
            projects_count = int(campaign_data['previous_projects_count'])

            if projects_count > 0:
                prepared_data['previous_successful_projects'] = round(success_rate * projects_count)
                logger.info(
                    f"Calculated previous_successful_projects: {prepared_data['previous_successful_projects']} "
                    f"from success rate: {success_rate} and count: {projects_count}"
                )

        processed_data = processor.process_campaign(prepared_data, idx=0)

        for field in non_transformed_fields:
            if field in campaign_data:
                processed_data[field] = campaign_data[field]
                logger.info(f"Using provided value for {field}: {campaign_data[field]}")

        bypass_transformation = campaign_data.get('bypass_transformation', False)
        for field in transformed_fields:
            if field not in campaign_data:
                continue
            if bypass_transformation:
                processed_data[field] = campaign_data[field]
                logger.warning(
                    f"Bypassing logarithmic transformation for {field} as requested. "
                    "This may affect model performance."
                )
            else:
                # Keep the processor's transformed value.
                logger.info(f"Using logarithmically transformed {field} value for better model performance.")

        for field in history_fields:
            logger.info(f"Final {field}: {processed_data.get(field, 'N/A')}")

        logger.info("Preprocessing complete with numerical transformations applied")
        return processed_data

    except Exception as e:
        logger.error(f"Error preprocessing raw data: {str(e)}", exc_info=True)
        raise Exception(f"Error preprocessing raw data: {str(e)}") from e
|
|
|
@app.get("/debug")
def debug():
    """
    Debug endpoint for checking API status and component health.

    Reports device/model/explainer/processor state, relevant environment
    variables, huggingface.co connectivity, whether the Longformer tokenizer
    can be loaded, and /tmp disk usage.

    Declared with plain ``def`` (not ``async def``) on purpose. The checks make
    blocking calls: an HTTP request with a 5 s timeout, and a tokenizer load
    that may download files. FastAPI runs sync endpoints in a worker
    threadpool, so these calls no longer stall the event loop and every
    in-flight request.

    NOTE(review): this endpoint is unauthenticated, exposes environment
    details, and each call may trigger a tokenizer download. Consider
    restricting or disabling it in production.

    Returns:
        JSONResponse: Comprehensive diagnostic information
    """
    # Outbound connectivity check against huggingface.co.
    try:
        import requests
        response = requests.get("https://huggingface.co", timeout=5)
        reachable = response.status_code == 200
        internet_check = {
            "status": "connected" if reachable else "error",
            "status_code": response.status_code,
            # Previously claimed success even for non-200 responses.
            "message": (
                "Successfully connected to huggingface.co"
                if reachable
                else f"Unexpected HTTP status {response.status_code} from huggingface.co"
            ),
        }
    except Exception as e:
        internet_check = {"status": "error", "message": f"Error connecting to internet: {str(e)}"}

    # Can the Longformer tokenizer be loaded (from cache or by download)?
    try:
        from transformers import AutoTokenizer
        cache_dir = "/tmp/model_cache"
        os.makedirs(cache_dir, exist_ok=True)
        test_model_name = "allenai/longformer-base-4096"
        AutoTokenizer.from_pretrained(test_model_name, cache_dir=cache_dir)
        tokenizer_check = {"status": "success", "message": f"Successfully loaded {test_model_name} tokenizer"}
    except Exception as e:
        tokenizer_check = {"status": "error", "message": f"Error loading tokenizer: {str(e)}"}

    # Disk usage of /tmp, where the model cache lives.
    try:
        import shutil
        total, used, free = shutil.disk_usage("/tmp")
        gib = 1024 ** 3
        disk_space = {
            "status": "ok",
            "total_gb": total / gib,
            "used_gb": used / gib,
            "free_gb": free / gib,
            "percent_used": (used / total) * 100,
        }
    except Exception as e:
        disk_space = {"status": "error", "message": f"Error checking disk space: {str(e)}"}

    debug_info = {
        "api_status": "running",
        "device": str(device),
        "model_loaded": model is not None,
        "explainer_loaded": explainer is not None,
        "processor_loaded": processor is not None,
        "cuda_available": torch.cuda.is_available(),
        "environment_variables": {
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE", "Not set"),
            "HF_HOME": os.environ.get("HF_HOME", "Not set"),
        },
        "model_cache_exists": os.path.exists("/tmp/model_cache"),
        "model_file_exists": os.path.exists("best_model.pth"),
        "tmp_directory_writable": os.access("/tmp", os.W_OK),
        "internet_connectivity": internet_check,
        "tokenizer_test": tokenizer_check,
        "disk_space": disk_space,
    }

    return JSONResponse(content=debug_info)
|
|
|
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on every network interface,
    # port 7860.
    from uvicorn import run as serve_app

    serve_app(app=app, host="0.0.0.0", port=7860)