angusfung's picture
Initial setup with Longformer embedding feature
7812756
raw
history blame
11.1 kB
import os
import json
import torch
import numpy as np
import logging
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from src.model import KickstarterModel
from src.explainer import KickstarterExplainer
from src.ProcessOneSingleCampaign import CampaignProcessor
# Configure logging: INFO and above, timestamped, emitted through a
# StreamHandler (stderr by default). A module-level logger is shared below.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
# Allow numpy.core.multiarray.scalar to be unpickled by torch.load when it runs
# in weights_only mode. Failure is non-fatal: e.g. torch releases without
# torch.serialization.add_safe_globals end up in the except branch and only
# log a warning.
try:
    import numpy.core.multiarray
    torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])
except Exception as exc:
    logger.warning(f"Failed to add safe globals: {str(exc)}")
else:
    logger.info("Added numpy.core.multiarray.scalar to safe globals")
# Constants
# Scalar inputs, in the order they are packed into the model's
# ``numerical_features`` tensor by the /predict handler.
NUMERICAL_FIELDS = [
    'description_length',
    'funding_goal',
    'image_count',
    'video_count',
    'campaign_duration',
    'previous_projects_count',
    'previous_success_rate',
    'previous_pledged',
    'previous_funding_goal',
]
# Embedding inputs fed to the model; /predict substitutes zero vectors for
# any the processor does not produce.
EMBEDDING_NAMES = [
    'description_embedding',
    'blurb_embedding',
    'risk_embedding',
    'subcategory_embedding',
    'category_embedding',
    'country_embedding',
]
# Shared application state, populated by lifespan() at startup.
model = explainer = processor = device = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the processor, model and explainer once for the app's lifetime.

    Code before ``yield`` runs at startup and populates the module-level
    ``model``, ``explainer``, ``processor`` and ``device`` globals used by the
    request handlers; code after ``yield`` runs at shutdown.

    A missing or unreadable weights file is not fatal: the app keeps the
    model's freshly initialised weights and logs a warning/error instead.
    Either way the model is moved to ``device`` and put in eval mode.
    """
    global model, explainer, processor, device
    logger.info("Starting application initialization...")

    # Prefer the GPU when torch can see one.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Create cache directories in /tmp which is writable
    cache_dir = "/tmp/model_cache"
    os.makedirs(cache_dir, exist_ok=True)
    logger.info(f"Created cache directory at {cache_dir}")

    # Point Hugging Face caching at the writable directory.
    # NOTE(review): these are set at runtime; a library that already read
    # them when it was imported may ignore the new values — confirm.
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir

    # Load the CampaignProcessor with lazy loading
    logger.info("Initializing CampaignProcessor...")
    processor = CampaignProcessor(data=[], lazy_load=True)

    # Load model with default parameters
    model_path = "best_model.pth"
    hidden_dim = 256
    logger.info(f"Initializing KickstarterModel with hidden_dim={hidden_dim}...")
    model = KickstarterModel(hidden_dim=hidden_dim)

    if os.path.exists(model_path):
        logger.info(f"Loading model weights from {model_path}...")
        try:
            # SECURITY: weights_only=False unpickles arbitrary objects. This is
            # only acceptable because best_model.pth is a local, trusted file;
            # never point this at user-supplied checkpoints.
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
            # Accept both {"model_state_dict": ...} checkpoints and bare state dicts.
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model weights: {str(e)}")
            logger.info("Continuing with uninitialized model weights.")
    else:
        logger.warning(f"Model file not found: {model_path}")
        logger.info("Continuing with uninitialized model weights.")

    # Done unconditionally (previously only after a successful load) so the
    # weights always live on the same device the explainer is given, and any
    # dropout/batch-norm layers run in inference mode.
    model.to(device)
    model.eval()

    # Initialize explainer
    logger.info("Initializing KickstarterExplainer...")
    explainer = KickstarterExplainer(model, device)
    logger.info("Application initialization completed successfully!")

    yield

    # Clean up resources on shutdown
    logger.info("Cleaning up resources...")
# The ASGI application; startup/shutdown work is delegated to lifespan().
app = FastAPI(
    lifespan=lifespan,
    title="Kickstarter Success Prediction API",
    description="API for predicting the success of Kickstarter campaigns",
    version="1.0.0",
)

# Add CORS middleware: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site issue credentialed requests; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/")
async def root():
    """Describe the API and point clients at the /predict endpoint."""
    payload = dict(
        message="Kickstarter Success Prediction API",
        description="Send a POST request to /predict with campaign data to get a prediction",
    )
    return payload
# Width of the zero vector substituted for an embedding the processor did not
# produce; names not listed fall back to 15 (matching the original fallback).
_DEFAULT_EMBEDDING_DIMS = {
    'description_embedding': 768,
    'blurb_embedding': 384,
    'risk_embedding': 384,
    'subcategory_embedding': 100,
    'country_embedding': 100,
    'category_embedding': 15,
}


def _build_model_inputs(processed_data):
    """Assemble the batch-of-one float32 tensors passed to the explainer.

    Raises:
        HTTPException: 400 if a numerical field cannot be converted to float.
    """
    inputs = {}
    for embedding_name in EMBEDDING_NAMES:
        if embedding_name in processed_data:
            inputs[embedding_name] = torch.tensor(
                processed_data[embedding_name], dtype=torch.float32
            ).unsqueeze(0)
        else:
            dim = _DEFAULT_EMBEDDING_DIMS.get(embedding_name, 15)
            inputs[embedding_name] = torch.zeros((1, dim), dtype=torch.float32)
            logger.warning(f"Using zero tensor for missing embedding: {embedding_name}")

    numerical_features = []
    for field in NUMERICAL_FIELDS:
        value = processed_data.get(field, 0)
        if value is None:
            # An explicit JSON null is treated like a missing field.
            value = 0
        try:
            # float() also accepts numeric strings, which torch.tensor rejects.
            numerical_features.append(float(value))
        except (TypeError, ValueError) as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid numeric value for '{field}': {value!r}",
            ) from exc
    inputs['numerical_features'] = torch.tensor([numerical_features], dtype=torch.float32)
    return inputs


def _to_json_compatible(value):
    """Return *value* in a form JSONResponse can serialise.

    numpy arrays and torch tensors expose ``tolist()``, which yields nested
    builtin lists/numbers; anything else (e.g. a plain list) is returned as-is.
    """
    # NOTE(review): the processor's embedding type is not visible here; this
    # guards against it being an ndarray/tensor rather than a list.
    tolist = getattr(value, "tolist", None)
    return tolist() if callable(tolist) else value


@app.post("/predict")
async def predict(request: Request):
    """Score one campaign and explain the score with SHAP values.

    The JSON request body is passed to ``preprocess_raw_data``; the resulting
    features are turned into tensors and run through the global explainer.

    Returns:
        JSON with ``success_probability``, ``predicted_outcome`` ("Success"
        when the probability is >= 0.5, else "Failure"), ``shap_values``
        ordered by descending absolute value and, when present, the raw
        description embedding as ``longformer_embedding``.

    Raises:
        HTTPException: 400 when the body is not valid JSON or a numerical
            field is not convertible to float; 500 for any other failure.
    """
    try:
        logger.info("Received prediction request")
        try:
            campaign_data = await request.json()
        except ValueError as e:
            # json.JSONDecodeError / UnicodeDecodeError: the client's fault.
            raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}") from e
        logger.info(f"Campaign data received: {json.dumps(campaign_data)[:100]}...")

        logger.info("Processing campaign data...")
        processed_data = preprocess_raw_data(campaign_data)
        logger.info("Campaign data processed successfully")

        # Keep the raw description (Longformer) embedding to echo back.
        raw_longformer_embedding = processed_data.get('description_embedding')

        logger.info("Preparing inputs for model...")
        processed_inputs = _build_model_inputs(processed_data)

        logger.info("Running prediction and generating explanations...")
        prediction, shap_values = explainer.explain_prediction(processed_inputs)
        logger.info(f"Prediction completed: {float(prediction):.4f}")

        # Most influential features first.
        sorted_shap = sorted(shap_values.items(), key=lambda item: abs(item[1]), reverse=True)
        result = {
            "success_probability": float(prediction),
            "predicted_outcome": "Success" if prediction >= 0.5 else "Failure",
            "shap_values": {k: float(v) for k, v in sorted_shap},
        }
        if raw_longformer_embedding is not None:
            result["longformer_embedding"] = _to_json_compatible(raw_longformer_embedding)

        logger.info("Returning prediction results")
        return JSONResponse(content=result)
    except HTTPException:
        # Client errors raised above keep their own status code.
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
def preprocess_raw_data(campaign_data):
    """Run one raw campaign through the global CampaignProcessor.

    Numerical fields present in *campaign_data* overwrite whatever the
    processor produced for them, so callers can supply exact values.

    Args:
        campaign_data: The decoded JSON body of a /predict request.

    Returns:
        The processor's output mapping, with caller-supplied numerical
        fields copied over.

    Raises:
        RuntimeError: If processing fails; the underlying exception is
            chained as ``__cause__`` so the original traceback is kept.
    """
    try:
        logger.info("Processing campaign with CampaignProcessor...")
        processed_data = processor.process_campaign(campaign_data, idx=0)
        # Preserve existing numerical values from input if present
        for field in NUMERICAL_FIELDS:
            if field in campaign_data:
                processed_data[field] = campaign_data[field]
                logger.info(f"Using provided value for {field}: {campaign_data[field]}")
        return processed_data
    except Exception as e:
        logger.error(f"Error preprocessing raw data: {str(e)}", exc_info=True)
        # RuntimeError is still an Exception, so existing handlers keep working.
        raise RuntimeError(f"Error preprocessing raw data: {str(e)}") from e
# Debugging endpoint to check the environment and loaded resources
# NOTE(review): this reports environment details to any caller; consider
# restricting or disabling it in production.


def _check_internet():
    """Probe huggingface.co and report whether it answered with HTTP 200."""
    try:
        import requests
        response = requests.get("https://huggingface.co", timeout=5)
    except Exception as e:
        return {"status": "error", "message": f"Error connecting to internet: {str(e)}"}
    ok = response.status_code == 200
    return {
        "status": "connected" if ok else "error",
        "status_code": response.status_code,
        # Only claim success when the probe actually returned 200.
        "message": (
            "Successfully connected to huggingface.co"
            if ok
            else f"huggingface.co responded with HTTP {response.status_code}"
        ),
    }


def _check_tokenizer():
    """Try to load the Longformer tokenizer into the /tmp model cache."""
    try:
        from transformers import AutoTokenizer
        cache_dir = "/tmp/model_cache"
        os.makedirs(cache_dir, exist_ok=True)
        test_model_name = "allenai/longformer-base-4096"
        AutoTokenizer.from_pretrained(test_model_name, cache_dir=cache_dir)
        return {"status": "success", "message": f"Successfully loaded {test_model_name} tokenizer"}
    except Exception as e:
        return {"status": "error", "message": f"Error loading tokenizer: {str(e)}"}


def _check_disk_space():
    """Report /tmp capacity in GiB plus the percentage used."""
    try:
        import shutil
        total, used, free = shutil.disk_usage("/tmp")
        return {
            "status": "ok",
            "total_gb": total / (1024**3),
            "used_gb": used / (1024**3),
            "free_gb": free / (1024**3),
            "percent_used": (used / total) * 100,
        }
    except Exception as e:
        return {"status": "error", "message": f"Error checking disk space: {str(e)}"}


def _collect_debug_info():
    """Gather every diagnostic synchronously; may block on network and disk I/O."""
    internet_check = _check_internet()
    tokenizer_check = _check_tokenizer()
    disk_space = _check_disk_space()
    return {
        "api_status": "running",
        "device": str(device),
        "model_loaded": model is not None,
        "explainer_loaded": explainer is not None,
        "processor_loaded": processor is not None,
        "cuda_available": torch.cuda.is_available(),
        "environment_variables": {
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE", "Not set"),
            "HF_HOME": os.environ.get("HF_HOME", "Not set"),
        },
        "model_cache_exists": os.path.exists("/tmp/model_cache"),
        "model_file_exists": os.path.exists("best_model.pth"),
        "tmp_directory_writable": os.access("/tmp", os.W_OK),
        "internet_connectivity": internet_check,
        "tokenizer_test": tokenizer_check,
        "disk_space": disk_space,
    }


@app.get("/debug")
async def debug():
    """Endpoint for checking the status of the API and its components.

    The checks make blocking calls (an HTTP request with a 5 s timeout and a
    possible tokenizer download), so they run in the default thread-pool
    executor instead of stalling the event loop that serves other requests.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    debug_info = await loop.run_in_executor(None, _collect_debug_info)
    return JSONResponse(content=debug_info)
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    # uvicorn is imported here because it is only needed when run directly.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")