"""
Financial Fraud Detection System - TechMatrix Solvers
Team Members:
- Abhay Gupta
- Jay Kumar
- Kripanshu Gupta
- Bhumika Patel
A comprehensive fraud detection system using machine learning algorithms.
"""
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import os
import pickle
import time
import warnings
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
roc_auc_score, confusion_matrix, classification_report, roc_curve
)
from imblearn.over_sampling import SMOTE
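# Third-party dependencies assumed available (names as on PyPI): streamlit, pandas,
# numpy, matplotlib, seaborn, plotly, scikit-learn, xgboost, imbalanced-learn.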
# Suppress warnings
warnings.filterwarnings('ignore')
# Set page configuration
st.set_page_config(
page_title="TechMatrix Fraud Detection System",
page_icon="πŸ”’",
layout="wide",
initial_sidebar_state="collapsed"
)
# Custom CSS for better styling
st.markdown("""
<style>
/* Main theme colors */
:root {
--primary: #2E7D32;
--primary-light: #81C784;
--primary-dark: #1B5E20;
--secondary: #1976D2;
--secondary-light: #64B5F6;
--text-on-primary: #FFFFFF;
--text-primary: #212121;
--text-secondary: #757575;
--background: #F5F5F5;
--card-bg: #FFFFFF;
--success: #43A047;
--warning: #FFA000;
--error: #D32F2F;
--info: #1976D2;
}
/* Base styles */
.main-header {
font-size: 2.8rem;
color: var(--primary);
text-align: center;
margin-bottom: 1.5rem;
font-weight: 700;
background: linear-gradient(90deg, var(--primary), var(--secondary));
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
padding: 0.5rem 0;
}
.sub-header {
font-size: 2rem;
color: var(--primary-dark);
margin-top: 2rem;
margin-bottom: 1rem;
font-weight: 600;
border-bottom: 2px solid var(--primary-light);
padding-bottom: 0.5rem;
}
.metric-card {
text-align: center;
padding: 1.2rem;
border-radius: 0.8rem;
background-color: rgba(46, 125, 50, 0.1);
transition: transform 0.3s ease;
border-left: 4px solid var(--primary);
}
.metric-card:hover {
transform: translateY(-5px);
background-color: rgba(46, 125, 50, 0.15);
}
.metric-value {
font-size: 2.5rem;
font-weight: 700;
color: var(--primary);
margin: 0.5rem 0;
}
.metric-label {
font-size: 1rem;
color: var(--text-secondary);
margin-bottom: 0.5rem;
}
div[data-testid="stMetric"] {
background-color: rgba(46, 125, 50, 0.1);
padding: 1rem;
border-radius: 0.8rem;
border-left: 4px solid var(--primary);
transition: transform 0.3s ease;
}
div[data-testid="stMetric"]:hover {
transform: translateY(-5px);
background-color: rgba(46, 125, 50, 0.15);
}
div[data-testid="stMetric"] > div {
gap: 0.2rem;
}
div[data-testid="stMetric"] label {
color: var(--text-secondary) !important;
}
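/* NOTE: .css-1wivap2 below is an auto-generated Streamlit class name; it can
   change between Streamlit releases, so prefer data-testid selectors where possible. */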
div[data-testid="stMetric"] .css-1wivap2 {
color: var(--primary) !important;
}
.stButton > button {
background-color: var(--primary);
color: var(--text-on-primary);
border-radius: 0.5rem;
padding: 0.5rem 1rem;
font-weight: 600;
border: none;
transition: all 0.3s ease;
}
.stButton > button:hover {
background-color: var(--primary-dark);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
transform: translateY(-2px);
}
.stProgress > div > div > div {
background-color: var(--primary);
background-image: linear-gradient(45deg,
rgba(255,255,255,.15) 25%,
transparent 25%,
transparent 50%,
rgba(255,255,255,.15) 50%,
rgba(255,255,255,.15) 75%,
transparent 75%,
transparent
);
background-size: 1rem 1rem;
animation: progress-animation 1s linear infinite;
}
@keyframes progress-animation {
0% { background-position: 0 0; }
100% { background-position: 1rem 0; }
}
.success-text {
color: var(--success);
font-weight: bold;
}
.warning-text {
color: var(--warning);
font-weight: bold;
}
.error-text {
color: var(--error);
font-weight: bold;
}
.info-text {
color: var(--info);
font-weight: bold;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
.animate-fade-in {
animation: fadeIn 0.8s ease-in-out;
}
[data-testid="stSidebarNav"] ul li:nth-child(2) {
display: none;
}
.dataframe {
border-collapse: collapse;
border: none;
font-size: 0.9rem;
}
.dataframe th {
background-color: var(--primary-light);
color: var(--text-primary);
padding: 0.5rem;
text-align: left;
}
.dataframe td {
padding: 0.5rem;
border-bottom: 1px solid #eee;
}
.dataframe tr:hover {
background-color: #f5f5f5;
}
.stSlider > div > div {
background-color: var(--primary-light);
}
.stSelectbox > div > div {
background-color: var(--card-bg);
border-radius: 0.5rem;
border: 1px solid var(--primary-light);
}
@keyframes pulse {
0% { opacity: 0.6; }
50% { opacity: 1; }
100% { opacity: 0.6; }
}
.loading-pulse {
animation: pulse 1.5s infinite ease-in-out;
}
</style>
""", unsafe_allow_html=True)
# Create necessary directories
os.makedirs("data", exist_ok=True)
os.makedirs("models", exist_ok=True)
# Initialize session state
if 'current_page' not in st.session_state:
st.session_state['current_page'] = 'home'
if 'data' not in st.session_state:
st.session_state['data'] = None
if 'preprocessed_data' not in st.session_state:
st.session_state['preprocessed_data'] = None
if 'engineered_data' not in st.session_state:
st.session_state['engineered_data'] = None
if 'target_col' not in st.session_state:
st.session_state['target_col'] = 'Class'
if 'trained_models' not in st.session_state:
st.session_state['trained_models'] = {}
if 'predictions' not in st.session_state:
st.session_state['predictions'] = None
if 'progress' not in st.session_state:
st.session_state['progress'] = 0
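# Page navigation is a simple state machine: each page sets
# st.session_state['current_page'] and calls st.rerun(), and the top-level
# if/elif chain below renders whichever page is currently active.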
# Main title
st.markdown("<div class='animate-fade-in'><h1 class='main-header'>TechMatrix Fraud Detection System</h1></div>", unsafe_allow_html=True)
# Team information
st.markdown("""
<div style='text-align: center; margin-bottom: 2rem;'>
<h3>Team TechMatrix Solvers</h3>
<p>Abhay Gupta | Jay Kumar | Kripanshu Gupta | Bhumika Patel</p>
</div>
""", unsafe_allow_html=True)
# Home Page
if st.session_state['current_page'] == 'home':
# Introduction section
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>Welcome to TechMatrix Fraud Detection System</h2></div>", unsafe_allow_html=True)
col1, col2 = st.columns([2, 1])
with col1:
st.markdown("""
Our advanced fraud detection system leverages cutting-edge machine learning algorithms to identify and prevent fraudulent transactions in real time.
### Understanding Financial Fraud
Financial fraud encompasses various deceptive practices aimed at unauthorized acquisition of funds or assets.
Our system specifically addresses:
- Credit card transaction fraud
- Identity theft incidents
- Account compromise attempts
- Suspicious transaction patterns
### Machine Learning Implementation
Our system employs sophisticated machine learning models that analyze transaction patterns and behavioral data.
The models are trained on historical fraud data and continuously updated to adapt to emerging fraud patterns.
### System Advantages:
- **Real-time Monitoring**: Instant detection of suspicious activities
- **Scalable Processing**: Efficient handling of large transaction volumes
- **Pattern Recognition**: Advanced detection of complex fraud patterns
- **Risk Assessment**: Probability-based fraud scoring system
""")
with col2:
# Create a unique visualization of the fraud detection process
fig = go.Figure()
# Create a hexagonal flow diagram
angles = np.linspace(0, 2*np.pi, 6, endpoint=False)
x = 0.5 + 0.4 * np.cos(angles)
y = 0.5 + 0.4 * np.sin(angles)
# Add connecting lines with gradient effect
for i in range(len(angles)):
next_i = (i + 1) % len(angles)
fig.add_trace(go.Scatter(
x=[x[i], x[next_i]],
y=[y[i], y[next_i]],
mode='lines',
line=dict(
color='rgba(46, 125, 50, 0.5)',
width=2,
dash='dot'
),
showlegend=False
))
# Add nodes with updated colors and labels
node_labels = ['Input Data', 'Validation', 'Processing', 'Analysis', 'Detection', 'Action']
node_colors = ['#2E7D32', '#43A047', '#81C784', '#1976D2', '#64B5F6', '#D32F2F']
for i in range(len(angles)):
fig.add_trace(go.Scatter(
x=[x[i]],
y=[y[i]],
mode='markers+text',
marker=dict(
size=30,
color=node_colors[i],
symbol='hexagon'
),
text=node_labels[i],
textposition="middle center",
textfont=dict(color='white', size=12),
showlegend=False
))
# Add title in the center with updated styling
fig.add_trace(go.Scatter(
x=[0.5],
y=[0.5],
mode='text',
text='Fraud<br>Detection<br>Pipeline',
textposition="middle center",
textfont=dict(
color='#212121',
size=14,
family='Arial, bold'
),
showlegend=False
))
fig.update_layout(
height=400,
width=400,
margin=dict(l=0, r=0, t=0, b=0),
xaxis=dict(
showgrid=False,
zeroline=False,
showticklabels=False,
range=[0, 1]
),
yaxis=dict(
showgrid=False,
zeroline=False,
showticklabels=False,
range=[0, 1]
),
plot_bgcolor='rgba(0,0,0,0)'
)
st.plotly_chart(fig)
# Workflow section
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>System Workflow</h2></div>", unsafe_allow_html=True)
col1, col2, col3, col4 = st.columns(4)
with col1:
st.markdown("### 1. Data Ingestion")
st.markdown("Secure upload and validation of transaction data in CSV format.")
st.image("https://cdn-icons-png.flaticon.com/512/4208/4208479.png", width=100)
with col2:
st.markdown("### 2. Data Processing")
st.markdown("Advanced data cleaning and preparation for analysis.")
st.image("https://cdn-icons-png.flaticon.com/512/1875/1875627.png", width=100)
with col3:
st.markdown("### 3. Feature Extraction")
st.markdown("Intelligent feature engineering and pattern recognition.")
st.image("https://cdn-icons-png.flaticon.com/512/2103/2103633.png", width=100)
with col4:
st.markdown("### 4. Model Deployment")
st.markdown("Real-time fraud detection and risk assessment.")
st.image("https://cdn-icons-png.flaticon.com/512/2103/2103658.png", width=100)
# Sample visualizations section
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>System Analytics</h2></div>", unsafe_allow_html=True)
col1, col2 = st.columns(2)
with col1:
# Sample ROC curve with improved styling
fig = go.Figure()
fpr = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
tpr_lr = [0, 0.4, 0.55, 0.68, 0.75, 0.8, 0.85, 0.9, 0.94, 0.98, 1.0]
tpr_rf = [0, 0.5, 0.65, 0.78, 0.85, 0.88, 0.91, 0.95, 0.97, 0.99, 1.0]
tpr_xgb = [0, 0.55, 0.7, 0.8, 0.87, 0.9, 0.93, 0.96, 0.98, 0.99, 1.0]
fig.add_trace(go.Scatter(
x=fpr,
y=tpr_lr,
mode='lines',
name='Logistic Regression (AUC = 0.85)',
line=dict(color='#2E7D32', width=3)
))
fig.add_trace(go.Scatter(
x=fpr,
y=tpr_rf,
mode='lines',
name='Random Forest (AUC = 0.92)',
line=dict(color='#1976D2', width=3)
))
fig.add_trace(go.Scatter(
x=fpr,
y=tpr_xgb,
mode='lines',
name='XGBoost (AUC = 0.94)',
line=dict(color='#D32F2F', width=3)
))
fig.add_trace(go.Scatter(
x=[0, 1],
y=[0, 1],
mode='lines',
name='Random',
line=dict(dash='dash', color='#757575', width=2)
))
fig.update_layout(
title='Model Performance Comparison',
xaxis_title='False Positive Rate',
yaxis_title='True Positive Rate',
legend=dict(x=0.01, y=0.99),
width=600,
height=400,
template='plotly_white',
margin=dict(l=40, r=40, t=40, b=40)
)
st.plotly_chart(fig)
with col2:
# Sample feature importance with improved styling
features = ['Transaction Amount', 'Time of Day', 'Merchant Category', 'Location', 'Transaction Frequency',
'Device Used', 'IP Address', 'Account Age', 'Previous Fraud Flag', 'Transaction Type']
importance = [0.23, 0.18, 0.15, 0.12, 0.09, 0.08, 0.06, 0.04, 0.03, 0.02]
fig = px.bar(
x=importance,
y=features,
orientation='h',
title='Feature Importance Analysis',
labels={'x': 'Importance Score', 'y': 'Feature'},
color=importance,
color_continuous_scale=['#2E7D32', '#43A047', '#81C784']
)
fig.update_layout(
width=600,
height=400,
template='plotly_white',
margin=dict(l=40, r=40, t=40, b=40)
)
st.plotly_chart(fig)
# Get started button
st.markdown("<div style='text-align: center; margin-top: 2rem;'>", unsafe_allow_html=True)
if st.button("Get Started", key="get_started", help="Begin the fraud detection process"):
st.session_state['current_page'] = 'upload'
st.rerun()
st.markdown("</div>", unsafe_allow_html=True)
# Data Upload Page
elif st.session_state['current_page'] == 'upload':
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>Step 1: Data Ingestion</h2></div>", unsafe_allow_html=True)
# File uploader with size limit warning
st.markdown("""
### Secure Data Upload
Upload your transaction data securely in CSV format. The system supports the following:
- Transaction details (amount, timestamp, location, etc.)
- Target column for fraud classification (default: 'Class' with 0 for normal, 1 for fraud)
- **Maximum file size: 200 MB**
For testing purposes, you can use the [Credit Card Fraud Detection dataset](https://www.kaggle.com/mlg-ulb/creditcardfraud) from Kaggle.
### Data Requirements:
- CSV format with UTF-8 encoding
- No missing values in critical fields
- Proper date/time formatting
- Numeric values for transaction amounts
""")
uploaded_file = st.file_uploader(
"Upload transaction data (CSV file)",
type="csv",
help="Maximum file size: 200 MB"
)
if uploaded_file is not None:
# Check file size (200 MB limit)
file_details = {"FileName": uploaded_file.name, "FileType": uploaded_file.type}
# Read the file into a buffer to check its size
file_buffer = uploaded_file.getvalue()
file_size_mb = len(file_buffer) / (1024 * 1024)
if file_size_mb > 200:
st.error(f"File size exceeds the 200 MB limit. Your file is {file_size_mb:.2f} MB. Please upload a smaller file.")
st.stop()
else:
st.info(f"File size: {file_size_mb:.2f} MB")
# Load data with progress bar
progress_bar = st.progress(0)
status_text = st.empty()
status_text.text("Initializing data ingestion...")
progress_bar.progress(25)
time.sleep(0.3)
try:
# Use BytesIO to avoid loading the file twice
from io import BytesIO
df = pd.read_csv(BytesIO(file_buffer))
st.session_state['data'] = df
progress_bar.progress(50)
status_text.text("Validating data structure...")
time.sleep(0.3)
progress_bar.progress(75)
status_text.text("Preparing data preview...")
time.sleep(0.3)
progress_bar.progress(100)
status_text.text("Data ingestion completed!")
time.sleep(0.3)
status_text.empty()
progress_bar.empty()
# Show basic data info
st.success(f"Data ingested successfully! Shape: {df.shape[0]} rows and {df.shape[1]} columns")
col1, col2 = st.columns(2)
with col1:
st.subheader("Data Preview")
st.dataframe(df.head())
with col2:
st.subheader("Data Structure")
# Display data types and missing values
data_info = pd.DataFrame({
'Data Type': df.dtypes,
'Non-Null Count': df.count(),
'Missing Values': df.isnull().sum(),
'Unique Values': [df[col].nunique() for col in df.columns]
})
st.dataframe(data_info)
# Check for target column
if 'Class' in df.columns:
fraud_count = df['Class'].sum()
total_count = len(df)
fraud_percentage = (fraud_count / total_count) * 100
st.info(f"Target column 'Class' detected with {fraud_count} fraud cases ({fraud_percentage:.2f}% of data)")
else:
st.warning("No 'Class' column detected. You'll need to specify the target column in the next step.")
except Exception as e:
st.error(f"Error during data ingestion: {str(e)}")
st.info("Please ensure the file is a valid CSV with proper formatting.")
# Navigation buttons
col1, col2 = st.columns([1, 5])
with col1:
if st.button("← Back to Home", key="back_to_home"):
st.session_state['current_page'] = 'home'
st.rerun()
with col2:
if st.session_state['data'] is not None:
if st.button("Continue to Data Processing β†’", key="to_preprocess"):
st.session_state['current_page'] = 'preprocess'
st.rerun()
# Data Preprocessing Page
elif st.session_state['current_page'] == 'preprocess':
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>Step 2: Data Processing</h2></div>", unsafe_allow_html=True)
if st.session_state['data'] is None:
st.error("No data found. Please upload data first.")
if st.button("Go back to Data Ingestion"):
st.session_state['current_page'] = 'upload'
st.rerun()
else:
df = st.session_state['data']
st.markdown("""
### Advanced Data Processing
Enhance your data quality through our comprehensive processing pipeline. The system will:
- Handle missing values intelligently
- Remove statistical outliers
- Normalize numerical features
- Balance class distribution
Select the processing options below to customize the pipeline.
""")
# Target column selection
if 'Class' in df.columns:
target_col = 'Class'
st.info(f"Target column 'Class' detected with values: {df[target_col].unique()}")
else:
target_col = st.selectbox("Select the target column (fraud indicator)", df.columns)
st.session_state['target_col'] = target_col
# Preprocessing options
st.subheader("Processing Options")
col1, col2 = st.columns(2)
with col1:
handle_missing = st.checkbox("Handle Missing Values", value=True,
help="Fill missing numerical values with mean and categorical values with mode")
remove_outliers = st.checkbox("Remove Outliers", value=False,
help="Remove extreme values that might affect model performance")
with col2:
    # These two options take effect later, in the model-training step
    # (StandardScaler and SMOTE respectively); they are surfaced here for visibility.
    normalize_data = st.checkbox("Normalize Data", value=True,
                                 help="Scale numerical features to have zero mean and unit variance")
    balance_classes = st.checkbox("Balance Classes", value=True,
                                  help="Handle class imbalance using SMOTE in the training phase")
# Handle missing values
if st.button("Process Data"):
with st.spinner("Processing data..."):
# Create a copy of the dataframe
df_processed = df.copy()
# Progress bar
progress_bar = st.progress(0)
status_text = st.empty()
# Handle missing values
if handle_missing:
status_text.text("Processing missing values...")
progress_bar.progress(25)
time.sleep(0.3)
for col in df_processed.columns:
if df_processed[col].dtype in ['int64', 'float64']:
df_processed[col] = df_processed[col].fillna(df_processed[col].mean())
else:
df_processed[col] = df_processed[col].fillna(df_processed[col].mode()[0])
# Remove outliers if selected
if remove_outliers:
status_text.text("Processing outliers...")
progress_bar.progress(50)
time.sleep(0.3)
# Only apply to numerical columns
num_cols = df_processed.select_dtypes(include=['int64', 'float64']).columns
for col in num_cols:
if col != target_col: # Don't remove outliers from target column
Q1 = df_processed[col].quantile(0.25)
Q3 = df_processed[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 3 * IQR
upper_bound = Q3 + 3 * IQR
df_processed = df_processed[(df_processed[col] >= lower_bound) &
(df_processed[col] <= upper_bound)]
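# Note: 3 * IQR is a deliberately conservative fence (the classic Tukey rule
# uses 1.5 * IQR), so only the most extreme values are dropped per column.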
# Store the processed data
status_text.text("Finalizing data processing...")
progress_bar.progress(100)
time.sleep(0.3)
st.session_state['preprocessed_data'] = df_processed
status_text.empty()
progress_bar.empty()
st.success("Data processing completed!")
# Show class distribution
if target_col in df_processed.columns:
st.subheader("Class Distribution After Processing")
col1, col2 = st.columns(2)
with col1:
# Create pie chart with improved styling
labels = ['Normal', 'Fraud']
values = [len(df_processed[df_processed[target_col] == 0]),
len(df_processed[df_processed[target_col] == 1])]
fig = px.pie(
values=values,
names=labels,
title='Transaction Distribution',
color_discrete_sequence=['#2E7D32', '#D32F2F'],
hole=0.4
)
fig.update_traces(textposition='inside', textinfo='percent+label')
fig.update_layout(
template='plotly_white',
margin=dict(l=20, r=20, t=30, b=20)
)
st.plotly_chart(fig)
with col2:
# Calculate statistics
fraud_count = df_processed[target_col].sum()
total_count = len(df_processed)
fraud_percentage = (fraud_count / total_count) * 100
st.metric("Total Transactions", f"{total_count:,}")
st.metric("Fraud Transactions", f"{fraud_count:,}")
st.metric("Fraud Percentage", f"{fraud_percentage:.2f}%")
if fraud_percentage < 1:
st.warning("Your dataset is highly imbalanced. Class balancing will be applied during model training.")
# Navigation buttons
col1, col2 = st.columns([1, 5])
with col1:
if st.button("← Back to Upload", key="back_to_upload"):
st.session_state['current_page'] = 'upload'
st.rerun()
with col2:
if st.session_state['preprocessed_data'] is not None:
if st.button("Continue to Feature Extraction β†’", key="to_feature_eng"):
st.session_state['current_page'] = 'feature_engineering'
st.rerun()
# Feature Engineering Page
elif st.session_state['current_page'] == 'feature_engineering':
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>Step 3: Feature Extraction</h2></div>", unsafe_allow_html=True)
if st.session_state['preprocessed_data'] is None:
st.error("No processed data found. Please complete data processing first.")
if st.button("Go back to Data Processing"):
st.session_state['current_page'] = 'preprocess'
st.rerun()
else:
df_processed = st.session_state['preprocessed_data']
target_col = st.session_state['target_col']
st.markdown("""
### Intelligent Feature Extraction
Enhance your fraud detection capabilities through advanced feature engineering. Our system provides:
- Time-based pattern analysis
- Transaction amount profiling
- Behavioral feature extraction
- Cross-feature interaction analysis
Select the features to extract below to optimize your model's performance.
""")
# Feature engineering options
st.subheader("Feature Extraction Options")
col1, col2 = st.columns(2)
with col1:
create_time_features = st.checkbox("Time-based Features", value=True,
help="Extract temporal patterns and behavioral indicators")
create_amount_features = st.checkbox("Amount-based Features", value=True,
help="Generate transaction amount profiles and risk indicators")
with col2:
create_aggregations = st.checkbox("Aggregation Features", value=False,
help="Create aggregated metrics for transaction patterns")
create_interactions = st.checkbox("Interaction Features", value=False,
help="Generate cross-feature interactions for complex pattern detection")
# Apply feature engineering
if st.button("Extract Features"):
with st.spinner("Extracting features..."):
# Create a copy of the dataframe
df_engineered = df_processed.copy()
# Progress bar
progress_bar = st.progress(0)
status_text = st.empty()
# Time-based features
if create_time_features and 'Time' in df_engineered.columns:
status_text.text("Extracting temporal features...")
progress_bar.progress(25)
time.sleep(0.3)
# Hour of day
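# ('Time' in the Kaggle credit-card dataset is seconds elapsed since the
# first transaction, so dividing by 3600 and wrapping at 24 gives an
# hour-of-day proxy)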
df_engineered['Hour'] = (df_engineered['Time'] / 3600) % 24
# Flag for transactions during odd hours (midnight to 5 AM)
df_engineered['Odd_Hour'] = ((df_engineered['Hour'] >= 0) & (df_engineered['Hour'] < 5)).astype(int)
# Part of day (include_lowest=True so Hour == 0.0 lands in 'Night' instead of NaN)
df_engineered['Part_of_Day'] = pd.cut(
    df_engineered['Hour'],
    bins=[0, 6, 12, 18, 24],
    labels=['Night', 'Morning', 'Afternoon', 'Evening'],
    include_lowest=True
)
# Amount-based features
if create_amount_features and 'Amount' in df_engineered.columns:
status_text.text("Extracting amount-based features...")
progress_bar.progress(50)
time.sleep(0.3)
# Log transform for amount (to handle skewed distribution)
df_engineered['Log_Amount'] = np.log1p(df_engineered['Amount'])
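# (np.log1p computes log(1 + x), so zero-amount transactions map to 0 rather than -inf)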
# Flag for high-value transactions (top 5%)
threshold = df_engineered['Amount'].quantile(0.95)
df_engineered['High_Value'] = (df_engineered['Amount'] > threshold).astype(int)
# Amount bins (rank first so duplicate quantile edges on heavily skewed
# amounts don't make qcut fail with fewer bins than labels)
df_engineered['Amount_Bin'] = pd.qcut(
    df_engineered['Amount'].rank(method='first'),
    q=5,
    labels=['Very Low', 'Low', 'Medium', 'High', 'Very High']
)
# Aggregation features
if create_aggregations:
status_text.text("Generating aggregation features...")
progress_bar.progress(75)
time.sleep(0.3)
# Check if there's a card ID or similar column
potential_id_cols = [col for col in df_engineered.columns if 'id' in col.lower() or 'card' in col.lower()]
if potential_id_cols:
id_col = potential_id_cols[0]
# Number of transactions per card
tx_count = df_engineered.groupby(id_col).size().reset_index(name='Tx_Count')
df_engineered = df_engineered.merge(tx_count, on=id_col, how='left')
# Average transaction amount per card
if 'Amount' in df_engineered.columns:
avg_amount = df_engineered.groupby(id_col)['Amount'].mean().reset_index(name='Avg_Amount')
df_engineered = df_engineered.merge(avg_amount, on=id_col, how='left')
# Transaction amount deviation from average
df_engineered['Amount_Deviation'] = df_engineered['Amount'] - df_engineered['Avg_Amount']
# Interaction features
if create_interactions:
status_text.text("Generating interaction features...")
progress_bar.progress(90)
time.sleep(0.3)
# Only create interactions between numerical features
num_cols = df_engineered.select_dtypes(include=['int64', 'float64']).columns
num_cols = [col for col in num_cols if col != target_col and 'id' not in col.lower()]
# Limit to a few important features to avoid explosion of features
if len(num_cols) > 3:
num_cols = num_cols[:3]
# Create interactions
for i in range(len(num_cols)):
for j in range(i+1, len(num_cols)):
col1_name = num_cols[i]
col2_name = num_cols[j]
df_engineered[f'{col1_name}_x_{col2_name}'] = df_engineered[col1_name] * df_engineered[col2_name]
# Convert categorical columns to one-hot encoding
cat_cols = df_engineered.select_dtypes(include=['object', 'category']).columns
for col in cat_cols:
dummies = pd.get_dummies(df_engineered[col], prefix=col, drop_first=True)
df_engineered = pd.concat([df_engineered, dummies], axis=1)
df_engineered.drop(columns=[col], inplace=True)
# Store the engineered data
status_text.text("Finalizing feature extraction...")
progress_bar.progress(100)
time.sleep(0.3)
st.session_state['engineered_data'] = df_engineered
status_text.empty()
progress_bar.empty()
st.success("Feature extraction completed!")
# Show correlation with target
if target_col in df_engineered.columns:
st.subheader("Feature Correlation Analysis")
# Get correlation with target
corr_with_target = df_engineered.corr()[target_col].sort_values(ascending=False)
# Remove target's correlation with itself
corr_with_target = corr_with_target.drop(target_col)
# Get top 10 positive and negative correlations
top_pos = corr_with_target.head(10)
top_neg = corr_with_target.tail(10).iloc[::-1] # Reverse to show strongest negative first
col1, col2 = st.columns(2)
with col1:
# Plot top positive correlations with improved styling
fig = px.bar(
x=top_pos.values,
y=top_pos.index,
orientation='h',
title='Top Positive Correlations with Fraud',
labels={'x': 'Correlation', 'y': 'Feature'},
color=top_pos.values,
color_continuous_scale=['#2E7D32', '#43A047', '#81C784']
)
fig.update_layout(
height=400,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
with col2:
# Plot top negative correlations with improved styling
fig = px.bar(
x=top_neg.values,
y=top_neg.index,
orientation='h',
title='Top Negative Correlations with Fraud',
labels={'x': 'Correlation', 'y': 'Feature'},
color=top_neg.values,
color_continuous_scale=['#81C784', '#43A047', '#2E7D32']
)
fig.update_layout(
height=400,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Correlation heatmap
st.subheader("Feature Correlation Matrix")
# Get top correlated features
corr_matrix = df_engineered.corr()
top_corr_features = corr_with_target.abs().sort_values(ascending=False).head(15).index
# Create heatmap with selected features
top_corr_matrix = corr_matrix.loc[top_corr_features, top_corr_features]
fig = px.imshow(
top_corr_matrix,
text_auto='.2f',
color_continuous_scale=['#2E7D32', 'white', '#1976D2'],
title='Feature Correlation Matrix'
)
fig.update_layout(
height=600,
width=800,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Feature distributions
st.subheader("Feature Distribution Analysis")
# Select a feature to visualize
numeric_cols = df_engineered.select_dtypes(include=['int64', 'float64']).columns
numeric_cols = [col for col in numeric_cols if col != target_col]
selected_feature = st.selectbox("Select feature to analyze", numeric_cols)
# Create distribution plot with improved styling
fig = px.histogram(
df_engineered,
x=selected_feature,
color=target_col,
marginal="box",
opacity=0.7,
barmode="overlay",
color_discrete_map={0: "#2E7D32", 1: "#D32F2F"},
labels={target_col: "Class", "0": "Normal", "1": "Fraud"}
)
fig.update_layout(
title=f"Distribution Analysis of {selected_feature}",
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Navigation buttons
col1, col2 = st.columns([1, 5])
with col1:
if st.button("← Back to Processing", key="back_to_preprocess"):
st.session_state['current_page'] = 'preprocess'
st.rerun()
with col2:
if st.session_state['engineered_data'] is not None:
if st.button("Continue to Model Training β†’", key="to_model_training"):
st.session_state['current_page'] = 'model_training'
st.rerun()
# Model Training Page
elif st.session_state['current_page'] == 'model_training':
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>Step 4: Model Training</h2></div>", unsafe_allow_html=True)
if st.session_state['engineered_data'] is None:
st.error("No engineered data found. Please complete feature extraction first.")
if st.button("Go back to Feature Extraction"):
st.session_state['current_page'] = 'feature_engineering'
st.rerun()
else:
df_engineered = st.session_state['engineered_data']
target_col = st.session_state['target_col']
st.markdown("""
### Advanced Model Training
Train sophisticated machine learning models for fraud detection. Our system provides:
- Multiple model architectures
- Automated hyperparameter optimization
- Cross-validation for robust evaluation
- Performance metrics visualization
Select your preferred models and training parameters below.
""")
# Training options
st.subheader("Training Configuration")
col1, col2 = st.columns(2)
with col1:
# Data sampling for faster training - default to a smaller sample for speed
use_sample = st.checkbox("Use Data Sample for Faster Training", value=True,
help="Use a sample of the data to speed up training (recommended for large datasets)")
if use_sample:
sample_size = st.slider("Sample Size (%)", min_value=10, max_value=100, value=20,
help="Percentage of data to use for training")
# Test size
test_size = st.slider("Test Set Size (%)", min_value=10, max_value=50, value=20,
help="Percentage of data to use for testing")
# Class balancing
use_smote = st.checkbox("Apply SMOTE for Class Balancing", value=True,
help="Use SMOTE to handle class imbalance")
with col2:
# Model selection
st.write("Select Models to Train:")
train_lr = st.checkbox("Logistic Regression", value=True)
train_rf = st.checkbox("Random Forest", value=True)
train_xgb = st.checkbox("XGBoost", value=True)
# Advanced options - reduced default values for faster training
show_advanced = st.checkbox("Show Advanced Options", value=False)
if show_advanced:
# Number of estimators for tree models - reduced for speed
n_estimators = st.slider("Number of Estimators", min_value=10, max_value=200, value=50,
help="Number of trees for Random Forest and XGBoost (higher = more accurate but slower)")
# Max depth for tree models
max_depth = st.slider("Max Tree Depth", min_value=3, max_value=15, value=6,
help="Maximum depth of trees (higher = more complex model)")
# Start training
if st.button("Train Models"):
with st.spinner("Training models..."):
status_container = st.empty()
status_container.markdown(
'<div class="loading-pulse">Training in progress... This may take a few minutes.</div>',
unsafe_allow_html=True
)
# Prepare data for training
X = df_engineered.drop(columns=[target_col])
y = df_engineered[target_col]
# Use sample if selected
if use_sample and sample_size < 100:
sample_frac = sample_size / 100
# Stratified sampling to maintain class distribution
rng = np.random.RandomState(42)  # fixed seed so the sample is reproducible
X_sample = pd.DataFrame()
y_sample = pd.Series(dtype=y.dtype)
for class_value in y.unique():
    X_class = X[y == class_value]
    y_class = y[y == class_value]
    n_samples = int(len(X_class) * sample_frac)
    indices = rng.choice(X_class.index, size=n_samples, replace=False)
    X_sample = pd.concat([X_sample, X_class.loc[indices]])
    y_sample = pd.concat([y_sample, y_class.loc[indices]])
X = X_sample
y = y_sample
# Progress bar
progress_bar = st.progress(0)
status_text = st.empty()
status_text.text("Preparing training data...")
progress_bar.progress(10)
# Split data
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=test_size/100, random_state=42, stratify=y
)
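# stratify=y keeps the fraud ratio identical across the train and test splits,
# which matters when the positive class is well under 1% of the data.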
status_text.text("Scaling features...")
progress_bar.progress(20)
# Scale features
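# (fit on the training split only, so the test set never leaks into the transform)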
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Handle class imbalance with SMOTE if selected
if use_smote:
status_text.text("Applying SMOTE for class balancing...")
progress_bar.progress(30)
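# SMOTE synthesizes new minority-class samples on the training split only;
# the test split keeps its natural imbalance so evaluation stays realistic.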
smote = SMOTE(random_state=42)
X_train_resampled, y_train_resampled = smote.fit_resample(X_train_scaled, y_train)
else:
X_train_resampled, y_train_resampled = X_train_scaled, y_train
# Save preprocessor
with open("models/scaler.pkl", "wb") as f:
pickle.dump(scaler, f)
# Save feature columns
with open("models/feature_columns.pkl", "wb") as f:
pickle.dump(X.columns.tolist(), f)
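# Together, the saved scaler and column list let a separate inference script
# rebuild the exact training-time feature matrix before calling predict().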
# Initialize results list
results = []
trained_models = {}
# Train selected models
if train_lr:
status_text.text("Training Logistic Regression...")
progress_bar.progress(40)
# Train Logistic Regression
lr_model = LogisticRegression(max_iter=1000, class_weight='balanced')
lr_model.fit(X_train_resampled, y_train_resampled)
# Make predictions
y_pred = lr_model.predict(X_test_scaled)
y_pred_proba = lr_model.predict_proba(X_test_scaled)[:, 1]
# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_pred_proba)
cm = confusion_matrix(y_test, y_pred)
# Store results
lr_results = {
'model_name': 'Logistic Regression',
'model': lr_model,
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1_score': f1,
'auc': auc,
'confusion_matrix': cm,
'y_test': y_test,
'y_pred_proba': y_pred_proba
}
results.append(lr_results)
trained_models['lr'] = lr_model
# Save model
with open("models/logistic_regression.pkl", "wb") as f:
pickle.dump(lr_model, f)
if train_rf:
status_text.text("Training Random Forest...")
progress_bar.progress(60)
# Get parameters - use smaller values for speed
n_est = n_estimators if show_advanced else 50
m_depth = max_depth if show_advanced else 6
# Train Random Forest
rf_model = RandomForestClassifier(
n_estimators=n_est,
max_depth=m_depth,
class_weight='balanced',
random_state=42
)
rf_model.fit(X_train_resampled, y_train_resampled)
# Make predictions
y_pred = rf_model.predict(X_test_scaled)
y_pred_proba = rf_model.predict_proba(X_test_scaled)[:, 1]
# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_pred_proba)
cm = confusion_matrix(y_test, y_pred)
# Store results
rf_results = {
'model_name': 'Random Forest',
'model': rf_model,
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1_score': f1,
'auc': auc,
'confusion_matrix': cm,
'y_test': y_test,
'y_pred_proba': y_pred_proba
}
results.append(rf_results)
trained_models['rf'] = rf_model
# Save model
with open("models/random_forest.pkl", "wb") as f:
pickle.dump(rf_model, f)
if train_xgb:
status_text.text("Training XGBoost...")
progress_bar.progress(80)
# Get parameters - use smaller values for speed
n_est = n_estimators if show_advanced else 50
m_depth = max_depth if show_advanced else 6
# Train XGBoost (use_label_encoder was deprecated and later removed from
# XGBoost, so it is no longer passed; scale_pos_weight only re-weights the
# positive class when SMOTE has not already balanced the training data)
xgb_model = XGBClassifier(
    n_estimators=n_est,
    max_depth=m_depth,
    scale_pos_weight=1 if use_smote else 10,
    random_state=42,
    eval_metric='logloss'
)
xgb_model.fit(X_train_resampled, y_train_resampled)
# Make predictions
y_pred = xgb_model.predict(X_test_scaled)
y_pred_proba = xgb_model.predict_proba(X_test_scaled)[:, 1]
# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_pred_proba)
cm = confusion_matrix(y_test, y_pred)
# Store results
xgb_results = {
'model_name': 'XGBoost',
'model': xgb_model,
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1_score': f1,
'auc': auc,
'confusion_matrix': cm,
'y_test': y_test,
'y_pred_proba': y_pred_proba
}
results.append(xgb_results)
trained_models['xgb'] = xgb_model
# Save model
with open("models/xgboost.pkl", "wb") as f:
pickle.dump(xgb_model, f)
# Save test data
with open("models/test_data.pkl", "wb") as f:
pickle.dump({"X_test": X_test_scaled, "y_test": y_test}, f)
st.session_state['trained_models'] = trained_models
# Automatically make predictions on the original dataset
status_text.text("Generating predictions...")
progress_bar.progress(90)
# Find the best model based on F1 score (good for imbalanced data)
best_model = None
best_f1 = -1
best_model_name = ""
for result in results:
if result['f1_score'] > best_f1:
best_f1 = result['f1_score']
best_model = result['model']
best_model_name = result['model_name']
if best_model is not None:
# Prepare full dataset for prediction
X_full = df_engineered.drop(columns=[target_col])
# Scale the data
X_full_scaled = scaler.transform(X_full)
# Make predictions
y_pred = best_model.predict(X_full_scaled)
y_pred_proba = best_model.predict_proba(X_full_scaled)[:, 1]
# Add predictions to the dataframe
df_with_predictions = df_engineered.copy()
df_with_predictions['Fraud_Probability'] = y_pred_proba
df_with_predictions['Predicted_Fraud'] = y_pred
# Store predictions
st.session_state['predictions'] = {
'df': df_with_predictions,
'model_name': best_model_name,
'results': results
}
status_text.text("Training completed!")
progress_bar.progress(100)
time.sleep(0.3)
status_text.empty()
progress_bar.empty()
st.success("Models trained successfully!")
# Display comparison of results
if results:
st.subheader("Model Performance Analysis")
# Create comparison table
comparison_df = pd.DataFrame([
{
'Model': r['model_name'],
'Accuracy': r['accuracy'],
'Precision': r['precision'],
'Recall': r['recall'],
'F1 Score': r['f1_score'],
'AUC': r['auc']
} for r in results
])
st.dataframe(comparison_df.style.highlight_max(axis=0, color='#81C784'))
# Plot metrics comparison with improved styling
fig = px.bar(
comparison_df.melt(id_vars=['Model'], var_name='Metric', value_name='Value'),
x='Model',
y='Value',
color='Metric',
barmode='group',
title='Model Performance Comparison',
labels={'Value': 'Score', 'Model': 'Model'},
color_discrete_sequence=['#2E7D32', '#43A047', '#81C784', '#1976D2', '#D32F2F']
)
fig.update_layout(
height=500,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Plot ROC curves with improved styling
st.subheader("ROC Curve Analysis")
fig = go.Figure()
colors = ['#2E7D32', '#1976D2', '#D32F2F']
for i, result in enumerate(results):
model_name = result['model_name']
y_test = result['y_test']
y_pred_proba = result['y_pred_proba']
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
auc = result['auc']
fig.add_trace(go.Scatter(
x=fpr,
y=tpr,
mode='lines',
name=f'{model_name} (AUC = {auc:.3f})',
line=dict(color=colors[i % len(colors)], width=3)
))
fig.add_trace(go.Scatter(
x=[0, 1],
y=[0, 1],
mode='lines',
name='Random',
line=dict(dash='dash', color='#757575', width=2)
))
fig.update_layout(
title='ROC Curve Analysis',
xaxis_title='False Positive Rate',
yaxis_title='True Positive Rate',
legend=dict(x=0.01, y=0.99),
height=500,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Show confusion matrices with improved styling
st.subheader("Confusion Matrix Analysis")
cols = st.columns(len(results))
for i, result in enumerate(results):
with cols[i]:
model_name = result['model_name']
cm = result['confusion_matrix']
# Calculate percentages
cm_percent = cm / cm.sum()
# Create annotation text (row/col loop variables named r and c so they
# don't shadow the enumerate index used by the surrounding column layout)
annotations = []
for r in range(cm.shape[0]):
    for c in range(cm.shape[1]):
        annotations.append({
            'x': c,
            'y': r,
            'text': f"{cm[r, c]}<br>({cm_percent[r, c]:.1%})",
            'showarrow': False,
            'font': {'color': 'white' if cm_percent[r, c] > 0.5 else 'black'}
        })
# Create heatmap
fig = go.Figure(data=go.Heatmap(
z=cm,
x=['Predicted Normal', 'Predicted Fraud'],
y=['Actual Normal', 'Actual Fraud'],
colorscale=[[0, '#81C784'], [1, '#2E7D32']],
showscale=False
))
fig.update_layout(
title=f"{model_name}",
annotations=annotations,
height=300,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Feature importance for tree-based models with improved styling
st.subheader("Feature Importance Analysis")
for result in results:
model_name = result['model_name']
model = result['model']
if model_name in ['Random Forest', 'XGBoost']:
# Get feature importance
if hasattr(model, 'feature_importances_'):
importances = model.feature_importances_
feature_names = X.columns
# Sort by importance
indices = np.argsort(importances)[::-1]
top_indices = indices[:10] # Show top 10 features for speed
# Create bar chart
fig = px.bar(
x=importances[top_indices],
y=[feature_names[i] for i in top_indices],
orientation='h',
title=f'Top Features - {model_name}',
labels={'x': 'Importance', 'y': 'Feature'},
color=importances[top_indices],
color_continuous_scale=['#81C784', '#43A047', '#2E7D32']
)
fig.update_layout(
height=400,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Navigation buttons
col1, col2 = st.columns([1, 5])
with col1:
if st.button("← Back to Feature Extraction", key="back_to_feature_eng"):
st.session_state['current_page'] = 'feature_engineering'
st.rerun()
with col2:
if st.session_state['predictions'] is not None:
if st.button("Continue to Results β†’", key="to_results"):
st.session_state['current_page'] = 'results'
st.rerun()
# Fraud Detection Results Page
elif st.session_state['current_page'] == 'results':
st.markdown("<div class='animate-fade-in'><h2 class='sub-header'>Step 5: Fraud Detection Results</h2></div>", unsafe_allow_html=True)
if st.session_state['predictions'] is None:
st.error("No predictions found. Please complete model training first.")
if st.button("Go back to Model Training"):
st.session_state['current_page'] = 'model_training'
st.rerun()
else:
predictions = st.session_state['predictions']
df_with_predictions = predictions['df']
model_name = predictions['model_name']
st.markdown(f"<h3 class='sub-header'>Fraud Detection Results using {model_name}</h3>", unsafe_allow_html=True)
# Summary of predictions
fraud_count = df_with_predictions['Predicted_Fraud'].sum()
total_count = len(df_with_predictions)
fraud_percentage = (fraud_count / total_count) * 100
# Create metrics display with improved styling
col1, col2, col3 = st.columns(3)
with col1:
st.metric(
label="Total Transactions",
value=f"{total_count:,}",
delta=None
)
with col2:
st.metric(
label="Detected Frauds",
value=f"{fraud_count:,}",
delta=None
)
with col3:
st.metric(
label="Fraud Percentage",
value=f"{fraud_percentage:.2f}%",
delta=None
)
# Visualization of fraud distribution with improved styling
st.subheader("Fraud Probability Distribution")
fig = px.histogram(
df_with_predictions,
x='Fraud_Probability',
nbins=50,
color='Predicted_Fraud',
color_discrete_map={0: "#2E7D32", 1: "#D32F2F"},
labels={'Predicted_Fraud': 'Prediction', '0': 'Normal', '1': 'Fraud'},
title='Distribution of Fraud Probabilities',
marginal='box'
)
fig.update_layout(
height=500,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Show high probability transactions
st.subheader("High Fraud Probability Transactions")
# Slider for probability threshold
threshold = st.slider(
"Fraud Probability Threshold",
min_value=0.5,
max_value=0.95,
value=0.7,
step=0.05,
help="Transactions with fraud probability above this threshold will be shown"
)
high_prob_df = df_with_predictions[df_with_predictions['Fraud_Probability'] > threshold]
if len(high_prob_df) > 0:
st.write(f"Found {len(high_prob_df)} transactions with fraud probability > {threshold}")
# Sort by probability
high_prob_df = high_prob_df.sort_values('Fraud_Probability', ascending=False)
# Select columns to display
display_cols = ['Fraud_Probability', 'Predicted_Fraud']
# Add original features
if 'Amount' in high_prob_df.columns:
display_cols.insert(0, 'Amount')
if 'Time' in high_prob_df.columns:
display_cols.insert(0, 'Time')
# Add target column if it exists
if st.session_state['target_col'] in high_prob_df.columns:
display_cols.append(st.session_state['target_col'])
# Display dataframe
st.dataframe(high_prob_df[display_cols])
# Download button
csv = high_prob_df.to_csv(index=False)
st.download_button(
label="Download High Risk Transactions",
data=csv,
file_name="high_risk_transactions.csv",
mime="text/csv"
)
else:
st.info(f"No transactions found with fraud probability > {threshold}")
# Show top 10 highest probability transactions instead
st.write("Top 10 highest fraud probability transactions:")
st.dataframe(df_with_predictions.sort_values('Fraud_Probability', ascending=False).head(10))
# Compare actual vs predicted (if actual labels exist)
target_col = st.session_state['target_col']
if target_col in df_with_predictions.columns:
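# Caveat: the full dataset scored here includes the rows the model was trained
# on, so the metrics below will look more optimistic than the held-out test
# scores reported on the training page.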
st.subheader("Actual vs Predicted Fraud")
# Confusion matrix with improved styling
cm = confusion_matrix(df_with_predictions[target_col], df_with_predictions['Predicted_Fraud'])
# Calculate percentages
cm_percent = cm / cm.sum()
# Create annotation text
annotations = []
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
annotations.append({
'x': j,
'y': i,
'text': f"{cm[i, j]}<br>({cm_percent[i, j]:.1%})",
'showarrow': False,
'font': {'color': 'white' if cm_percent[i, j] > 0.5 else 'black'}
})
# Create heatmap
fig = go.Figure(data=go.Heatmap(
z=cm,
x=['Predicted Normal', 'Predicted Fraud'],
y=['Actual Normal', 'Actual Fraud'],
colorscale=[[0, '#81C784'], [1, '#2E7D32']],
showscale=False
))
fig.update_layout(
title=f"Confusion Matrix - {model_name}",
annotations=annotations,
height=400,
template='plotly_white',
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig)
# Calculate metrics
accuracy = accuracy_score(df_with_predictions[target_col], df_with_predictions['Predicted_Fraud'])
# Calculate metrics
precision = precision_score(df_with_predictions[target_col], df_with_predictions['Predicted_Fraud'])
recall = recall_score(df_with_predictions[target_col], df_with_predictions['Predicted_Fraud'])
f1 = f1_score(df_with_predictions[target_col], df_with_predictions['Predicted_Fraud'])
# Display metrics with improved styling
st.subheader("Performance Metrics on Full Dataset")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric(
label="Accuracy",
value=f"{accuracy:.4f}",
delta=None
)
with col2:
st.metric(
label="Precision",
value=f"{precision:.4f}",
delta=None
)
with col3:
st.metric(
label="Recall",
value=f"{recall:.4f}",
delta=None
)
with col4:
st.metric(
label="F1 Score",
value=f"{f1:.4f}",
delta=None
)
# Download all predictions
st.subheader("Download Results")
csv = df_with_predictions.to_csv(index=False)
st.download_button(
label="Download All Predictions as CSV",
data=csv,
file_name="fraud_predictions.csv",
mime="text/csv"
)
# Navigation buttons
col1, col2 = st.columns([1, 5])
with col1:
if st.button("← Back to Model Training", key="back_to_model_training"):
st.session_state['current_page'] = 'model_training'
st.rerun()
with col2:
if st.button("Start Over", key="start_over"):
# Reset session state
for key in list(st.session_state.keys()):
del st.session_state[key]
st.session_state['current_page'] = 'home'
st.rerun()
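# ----------------------------------------------------------------------------
# Illustrative helper (not wired into the UI): a minimal sketch of how a
# separate batch-scoring step could reuse the artifacts saved in models/ above.
# The CSV path and the choice of xgboost.pkl are placeholder assumptions.
def score_transactions_from_artifacts(csv_path="new_transactions.csv"):
    """Load the saved scaler, feature list, and a saved model, then score a CSV
    whose columns match the engineered training features."""
    with open("models/scaler.pkl", "rb") as f:
        scaler = pickle.load(f)
    with open("models/feature_columns.pkl", "rb") as f:
        feature_columns = pickle.load(f)
    with open("models/xgboost.pkl", "rb") as f:
        model = pickle.load(f)
    new_tx = pd.read_csv(csv_path)
    # Align columns to the training-time order before scaling
    X_new = scaler.transform(new_tx[feature_columns])
    new_tx['Fraud_Probability'] = model.predict_proba(X_new)[:, 1]
    return new_tx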