import io
import sys
import json
import time
import logging
import contextlib
import subprocess
from pathlib import Path
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

import requests
import pandas as pd
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Add root to sys.path for imports
sys.path.append(str(Path(__file__).resolve().parent.parent))
# Try to import trainer directly for better progress tracking
try:
from model.train import RobustModelTrainer, estimate_training_time
DIRECT_TRAINING_AVAILABLE = True
except ImportError:
RobustModelTrainer = None
estimate_training_time = None
DIRECT_TRAINING_AVAILABLE = False
logger.warning("Direct training import failed, using subprocess fallback")
class StreamlitAppManager:
"""Manages Streamlit application state and functionality"""
def __init__(self):
self.setup_config()
self.setup_paths()
self.setup_api_client()
self.initialize_session_state()
def setup_config(self):
"""Setup application configuration"""
self.config = {
'api_url': "http://localhost:8000",
'max_upload_size': 10 * 1024 * 1024, # 10MB
'supported_file_types': ['csv', 'txt', 'json'],
'max_text_length': 10000,
'prediction_timeout': 30,
'refresh_interval': 60,
'max_batch_size': 10
}
def setup_paths(self):
"""Setup file paths"""
self.paths = {
'custom_data': Path("/tmp/custom_upload.csv"),
'metadata': Path("/tmp/metadata.json"),
'activity_log': Path("/tmp/activity_log.json"),
'drift_log': Path("/tmp/logs/monitoring_log.json"),
'prediction_log': Path("/tmp/prediction_log.json"),
'scheduler_log': Path("/tmp/logs/scheduler_execution.json"),
'error_log': Path("/tmp/logs/scheduler_errors.json")
}
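        # Note: /tmp is writable in typical container deployments (e.g. a
        # Hugging Face Space) but is not persistent across restarts, so the
        # models and logs stored here are ephemeral by design.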
    def setup_api_client(self):
        """Setup API client with error handling"""
        self.session = requests.Session()
        # requests.Session has no global timeout setting; the value in
        # self.config['prediction_timeout'] is passed on each request instead
        # (see make_prediction_request).
        # Test API connection at startup
        self.api_available = self.test_api_connection()
    def test_api_connection(self) -> bool:
        """Test API connection"""
        try:
            response = self.session.get(
                f"{self.config['api_url']}/health", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False
def initialize_session_state(self):
"""Initialize Streamlit session state"""
if 'prediction_history' not in st.session_state:
st.session_state.prediction_history = []
if 'upload_history' not in st.session_state:
st.session_state.upload_history = []
if 'last_refresh' not in st.session_state:
st.session_state.last_refresh = datetime.now()
if 'auto_refresh' not in st.session_state:
st.session_state.auto_refresh = False
# Initialize app manager
app_manager = StreamlitAppManager()
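# Note: Streamlit re-executes this module on every interaction, so the manager
# (and its startup health check) is rebuilt on each rerun. If that overhead
# matters, a cached factory is one option (sketch; assumes a Streamlit version
# with st.cache_resource):
#
#   @st.cache_resource
#   def get_app_manager():
#       return StreamlitAppManager()
#
#   app_manager = get_app_manager()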
# Page configuration
st.set_page_config(
page_title="Fake News Detection System",
    page_icon="📰",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS for better styling
st.markdown("""
<style>
.main-header {
font-size: 3rem;
font-weight: bold;
text-align: center;
color: #1f77b4;
margin-bottom: 2rem;
}
.metric-card {
background-color: #f0f2f6;
padding: 1rem;
border-radius: 0.5rem;
border-left: 4px solid #1f77b4;
}
.success-message {
background-color: #d4edda;
color: #155724;
padding: 1rem;
border-radius: 0.5rem;
border: 1px solid #c3e6cb;
}
.warning-message {
background-color: #fff3cd;
color: #856404;
padding: 1rem;
border-radius: 0.5rem;
border: 1px solid #ffeaa7;
}
.error-message {
background-color: #f8d7da;
color: #721c24;
padding: 1rem;
border-radius: 0.5rem;
border: 1px solid #f5c6cb;
}
</style>
""", unsafe_allow_html=True)
def load_json_file(file_path: Path, default: Any = None) -> Any:
    """Safely load JSON file with error handling"""
    try:
        if file_path.exists():
            with open(file_path, 'r') as f:
                return json.load(f)
        # `default if default is not None else {}` rather than `default or {}`,
        # so a falsy default such as [] is returned intact (the activity log
        # below relies on this).
        return default if default is not None else {}
    except Exception as e:
        logger.error(f"Failed to load {file_path}: {e}")
        return default if default is not None else {}
def save_prediction_to_history(text: str, prediction: str, confidence: float):
"""Save prediction to session history"""
prediction_entry = {
'timestamp': datetime.now().isoformat(),
'text': text[:100] + "..." if len(text) > 100 else text,
'prediction': prediction,
'confidence': confidence,
'text_length': len(text)
}
st.session_state.prediction_history.append(prediction_entry)
# Keep only last 50 predictions
if len(st.session_state.prediction_history) > 50:
st.session_state.prediction_history = st.session_state.prediction_history[-50:]
def make_prediction_request(text: str) -> Dict[str, Any]:
"""Make prediction request to API"""
try:
if not app_manager.api_available:
return {'error': 'API is not available'}
response = app_manager.session.post(
f"{app_manager.config['api_url']}/predict",
json={"text": text},
timeout=app_manager.config['prediction_timeout']
)
if response.status_code == 200:
return response.json()
else:
return {'error': f'API Error: {response.status_code} - {response.text}'}
except requests.exceptions.Timeout:
return {'error': 'Request timed out. Please try again.'}
except requests.exceptions.ConnectionError:
return {'error': 'Cannot connect to prediction service.'}
except Exception as e:
return {'error': f'Unexpected error: {str(e)}'}
def validate_text_input(text: str) -> Tuple[bool, str]:
"""Validate text input"""
if not text or not text.strip():
return False, "Please enter some text to analyze."
if len(text) < 10:
return False, "Text must be at least 10 characters long."
if len(text) > app_manager.config['max_text_length']:
return False, f"Text must be less than {app_manager.config['max_text_length']} characters."
# Check for suspicious content
suspicious_patterns = ['<script', 'javascript:', 'data:']
if any(pattern in text.lower() for pattern in suspicious_patterns):
return False, "Text contains suspicious content."
return True, "Valid"
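# Illustrative behaviour (not executed):
#   validate_text_input("")          -> (False, "Please enter some text to analyze.")
#   validate_text_input("too short") -> (False, "Text must be at least 10 characters long.")
#   validate_text_input("A long enough article snippet ...") -> (True, "Valid")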
def create_confidence_gauge(confidence: float, prediction: str):
"""Create confidence gauge visualization"""
fig = go.Figure(go.Indicator(
mode="gauge+number+delta",
value=confidence * 100,
domain={'x': [0, 1], 'y': [0, 1]},
title={'text': f"Confidence: {prediction}"},
delta={'reference': 50},
gauge={
'axis': {'range': [None, 100]},
'bar': {'color': "red" if prediction == "Fake" else "green"},
'steps': [
{'range': [0, 50], 'color': "lightgray"},
{'range': [50, 80], 'color': "yellow"},
{'range': [80, 100], 'color': "lightgreen"}
],
'threshold': {
'line': {'color': "black", 'width': 4},
'thickness': 0.75,
'value': 90
}
}
))
fig.update_layout(height=300)
return fig
def create_prediction_history_chart():
"""Create prediction history visualization"""
if not st.session_state.prediction_history:
return None
df = pd.DataFrame(st.session_state.prediction_history)
df['timestamp'] = pd.to_datetime(df['timestamp'])
df['confidence_percent'] = df['confidence'] * 100
fig = px.scatter(
df,
x='timestamp',
y='confidence_percent',
color='prediction',
size='text_length',
hover_data=['text'],
title="Prediction History",
labels={'confidence_percent': 'Confidence (%)', 'timestamp': 'Time'}
)
fig.update_layout(height=400)
return fig
def estimate_detailed_training_time(dataset_size: int, enable_tuning: bool, cv_folds: int, num_models: int, max_features: int) -> str:
"""Estimate training time based on detailed parameters"""
# Base time per sample (in seconds)
base_time_per_sample = 0.01
# Feature complexity multiplier
feature_multiplier = max_features / 5000 # Normalized to 5000 features
# Cross-validation multiplier
cv_multiplier = cv_folds
# Hyperparameter tuning multiplier
tuning_multiplier = 8 if enable_tuning else 1
# Model count multiplier
model_multiplier = num_models
# Calculate total time
total_seconds = (
dataset_size *
base_time_per_sample *
feature_multiplier *
cv_multiplier *
tuning_multiplier *
model_multiplier
)
# Add base overhead
total_seconds += 10 # Base overhead
    # Format time
    if total_seconds < 60:
        return f"{int(total_seconds)} seconds"
    elif total_seconds < 3600:
        minutes = int(total_seconds // 60)
        seconds = int(total_seconds % 60)
        return f"{minutes}:{seconds:02d}"
    else:
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = int(total_seconds % 60)
        return f"{hours}:{minutes:02d}:{seconds:02d}"
def estimate_training_time_streamlit(dataset_size: int) -> dict:
"""Estimate training time for Streamlit display"""
if estimate_training_time:
# Use the imported function
detailed_estimate = estimate_training_time(dataset_size, enable_tuning=True, cv_folds=3)
return {
'detailed': detailed_estimate,
'simple_range': f"{int(detailed_estimate['total_seconds']//60)}:{int(detailed_estimate['total_seconds']%60):02d}",
'category': 'small' if dataset_size < 100 else 'medium' if dataset_size < 1000 else 'large'
}
else:
# Fallback estimation
if dataset_size < 100:
return {'simple_range': '0:30-1:00', 'category': 'small'}
elif dataset_size < 1000:
return {'simple_range': '1:00-3:00', 'category': 'medium'}
else:
return {'simple_range': '3:00+', 'category': 'large'}
def render_enhanced_training_section(df_train):
"""Enhanced training section with progress tracking"""
st.header("Custom Model Training")
st.info("Upload your own dataset to retrain the model with custom data.")
# Show dataset info and time estimate
dataset_size = len(df_train)
time_estimate = estimate_training_time_streamlit(dataset_size)
# Training information display
st.markdown("### πŸ“Š Training Information")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Dataset Size", f"{dataset_size} samples")
with col2:
if 'detailed' in time_estimate:
est_time = time_estimate['detailed']['total_formatted']
else:
est_time = time_estimate['simple_range']
st.metric("Estimated Time", est_time)
with col3:
st.metric("Category", time_estimate['category'].title())
with col4:
training_method = "Full Pipeline" if dataset_size >= 50 else "Simplified"
st.metric("Training Mode", training_method)
# Dataset preview
with st.expander("πŸ‘€ Dataset Preview"):
st.dataframe(df_train.head(10))
# Dataset statistics
label_counts = df_train['label'].value_counts()
col1, col2 = st.columns(2)
with col1:
st.subheader("Class Distribution")
st.write(f"Real news (0): {label_counts.get(0, 0)}")
st.write(f"Fake news (1): {label_counts.get(1, 0)}")
with col2:
            # Label distribution chart; reindex so the counts line up with the
            # hard-coded Real/Fake names regardless of frequency order
            ordered_counts = label_counts.reindex([0, 1], fill_value=0)
            fig_labels = px.pie(
                values=ordered_counts.values,
                names=['Real', 'Fake'],
                title="Label Distribution"
            )
            st.plotly_chart(fig_labels, use_container_width=True)
# Training configuration
with st.expander("βš™οΈ Training Configuration", expanded=True):
st.markdown("**Configure your training parameters:**")
col1, col2 = st.columns(2)
with col1:
st.markdown("##### Core Settings")
# Test size slider
test_size = st.slider(
"Test Set Size (%)",
min_value=10,
max_value=50,
value=20,
step=5,
help="Percentage of data reserved for testing"
)
# Cross-validation folds
cv_folds = st.slider(
"Cross-Validation Folds",
min_value=2,
max_value=10,
value=3 if dataset_size < 100 else 5,
step=1,
help="Number of folds for cross-validation"
)
# Hyperparameter tuning toggle
enable_tuning = st.checkbox(
"Enable Hyperparameter Tuning",
value=dataset_size >= 50,
help="Enable grid search for optimal parameters (recommended for 50+ samples)"
)
with col2:
st.markdown("##### Advanced Options")
# Model selection
available_models = st.multiselect(
"Models to Train",
options=["Logistic Regression", "Random Forest"],
default=["Logistic Regression"] if dataset_size < 50 else ["Logistic Regression", "Random Forest"],
help="Select which models to train and compare"
)
# Feature engineering options
max_features = st.selectbox(
"Max TF-IDF Features",
options=[1000, 2000, 5000, 10000, 20000],
index=2 if dataset_size >= 100 else 1,
help="Maximum number of TF-IDF features to extract"
)
# N-gram range
ngram_option = st.selectbox(
"N-gram Range",
options=["Unigrams (1,1)", "Unigrams + Bigrams (1,2)", "Unigrams + Bigrams + Trigrams (1,3)"],
index=1,
help="Range of n-grams to include in feature extraction"
)
# Convert selections to parameters
ngram_map = {
"Unigrams (1,1)": (1, 1),
"Unigrams + Bigrams (1,2)": (1, 2),
"Unigrams + Bigrams + Trigrams (1,3)": (1, 3)
}
ngram_range = ngram_map[ngram_option]
model_map = {
"Logistic Regression": "logistic_regression",
"Random Forest": "random_forest"
}
selected_models = [model_map[model] for model in available_models]
# Training summary
st.markdown("---")
st.markdown("##### πŸ“‹ Training Summary")
summary_col1, summary_col2, summary_col3 = st.columns(3)
with summary_col1:
st.info(f"**Data Split:** {100-test_size}% train, {test_size}% test")
st.info(f"**Cross-Validation:** {cv_folds} folds")
with summary_col2:
tuning_status = "βœ… Enabled" if enable_tuning else "❌ Disabled"
st.info(f"**Hyperparameter Tuning:** {tuning_status}")
st.info(f"**Models:** {len(selected_models)} selected")
with summary_col3:
st.info(f"**Max Features:** {max_features:,}")
st.info(f"**N-grams:** {ngram_range}")
# Warnings and recommendations
    if dataset_size < 20:
        st.warning("⚠️ **Very small dataset detected:**")
        st.warning("• Hyperparameter tuning defaults to off at this size")
        st.warning("• Results may be unreliable")
        st.warning("• Consider using more data for better performance")
    elif dataset_size < 50:
        if enable_tuning:
            st.warning("⚠️ **Small dataset with hyperparameter tuning:**")
            st.warning("• Training may take longer")
            st.warning("• Risk of overfitting")
        else:
            st.info("ℹ️ **Small dataset - good configuration**")
    else:
        if not enable_tuning:
            st.info("ℹ️ **Large dataset without hyperparameter tuning:**")
            st.info("• Training will be faster")
            st.info("• Consider enabling tuning for better performance")
        else:
            st.success("✅ **Optimal configuration for your dataset size**")
# Estimated training time with new parameters
estimated_time = estimate_detailed_training_time(
dataset_size, enable_tuning, cv_folds, len(selected_models), max_features
)
st.markdown("---")
st.markdown(f"##### ⏱️ **Estimated Training Time: {estimated_time}**")
# Training button and execution
if st.button("πŸƒβ€β™‚οΈ Start Training", type="primary", use_container_width=True):
# Validate configuration
if not selected_models:
st.error("❌ Please select at least one model to train!")
return
if dataset_size < 6:
st.error("❌ Dataset too small! Minimum 6 samples required.")
return
# Save training data with metadata
app_manager.paths['custom_data'].parent.mkdir(parents=True, exist_ok=True)
df_train.to_csv(app_manager.paths['custom_data'], index=False)
# Save training configuration
training_config = {
'test_size': test_size / 100, # Convert percentage to decimal
'cv_folds': cv_folds,
'enable_tuning': enable_tuning,
'selected_models': selected_models,
'max_features': max_features,
'ngram_range': ngram_range,
'dataset_size': dataset_size
}
config_path = Path("/tmp/training_config.json")
with open(config_path, 'w') as f:
json.dump(training_config, f, indent=2)
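        # Assumption: model/train.py reads this JSON when invoked with
        # --config_path (the subprocess fallback below passes it); the direct
        # path applies the same settings to the trainer in-process instead.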
st.markdown("---")
st.markdown("### πŸ”„ Training Progress")
# Show final configuration
st.info(f"🎯 **Training Configuration:** {len(selected_models)} model(s), "
f"{test_size}% test split, {cv_folds}-fold CV, "
f"{'with' if enable_tuning else 'without'} hyperparameter tuning")
# Progress containers
progress_col1, progress_col2 = st.columns([3, 1])
with progress_col1:
progress_bar = st.progress(0)
status_text = st.empty()
with progress_col2:
time_display = st.empty()
# Start training
start_time = time.time()
if DIRECT_TRAINING_AVAILABLE:
# Method 1: Direct function call (shows progress in real-time)
status_text.text("Status: Initializing training with custom config...")
progress_bar.progress(5)
try:
# Create output capture
output_buffer = io.StringIO()
with st.spinner("Training model with custom configuration..."):
# Create trainer with custom config
trainer = RobustModelTrainer()
# Apply custom configuration
trainer.test_size = training_config['test_size']
trainer.cv_folds = training_config['cv_folds']
trainer.max_features = training_config['max_features']
trainer.ngram_range = training_config['ngram_range']
# Filter models based on selection
if len(selected_models) < len(trainer.models):
all_models = trainer.models.copy()
trainer.models = {k: v for k, v in all_models.items() if k in selected_models}
# Redirect stdout to capture progress
with contextlib.redirect_stdout(output_buffer):
success, message = trainer.train_model(
data_path=str(app_manager.paths['custom_data'])
)
elapsed_time = time.time() - start_time
time_display.text(f"Elapsed: {timedelta(seconds=int(elapsed_time))}")
# Show final progress
progress_bar.progress(100)
status_text.text("Status: Training completed!")
# Get captured output
captured_output = output_buffer.getvalue()
if success:
st.success("πŸŽ‰ **Training Completed Successfully!**")
st.info(f"πŸ“Š **{message}**")
# Show configuration used
with st.expander("βš™οΈ Configuration Used"):
st.json(training_config)
# Show captured progress if available
if captured_output:
with st.expander("πŸ“ˆ Training Progress Details"):
st.code(captured_output)
else:
st.error(f"❌ **Training Failed:** {message}")
if captured_output:
with st.expander("πŸ” Debug Output"):
st.code(captured_output)
except Exception as e:
st.error(f"❌ **Training Error:** {str(e)}")
else:
# Method 2: Subprocess with progress simulation
status_text.text("Status: Starting subprocess training...")
progress_bar.progress(10)
try:
                # Progress steps shown while the subprocess runs
progress_steps = [
(20, "Loading and validating data..."),
(30, f"Configuring {len(selected_models)} model(s)..."),
(50, f"Training with {cv_folds}-fold cross-validation..."),
(70, "Performing hyperparameter tuning..." if enable_tuning else "Training models..."),
(85, "Evaluating performance..."),
(95, "Saving model artifacts...")
]
# Start subprocess with config
process = subprocess.Popen(
[sys.executable, "model/train.py",
"--data_path", str(app_manager.paths['custom_data']),
"--config_path", str(config_path)],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True
)
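                # Caveat: stdout is piped but only drained via communicate()
                # after the poll loop, so a very verbose trainer could fill
                # the OS pipe buffer and stall; fine for modest logs, but a
                # background reader thread would be more robust.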
# Simulate progress while waiting
step_idx = 0
while process.poll() is None:
elapsed = time.time() - start_time
time_display.text(f"Elapsed: {timedelta(seconds=int(elapsed))}")
# Update progress based on elapsed time and configuration
if step_idx < len(progress_steps):
expected_time = dataset_size * 0.05 * (2 if enable_tuning else 1)
if elapsed > expected_time * (step_idx + 1) / len(progress_steps):
progress, status = progress_steps[step_idx]
progress_bar.progress(progress)
status_text.text(f"Status: {status}")
step_idx += 1
time.sleep(1)
# Get final output
stdout, _ = process.communicate()
# Final progress
progress_bar.progress(100)
status_text.text("Status: Training completed!")
elapsed_time = time.time() - start_time
time_display.text(f"Completed: {timedelta(seconds=int(elapsed_time))}")
if process.returncode == 0:
st.success("πŸŽ‰ **Training Completed Successfully!**")
# Show configuration used
with st.expander("βš™οΈ Configuration Used"):
st.json(training_config)
# Extract performance info from output
if stdout:
lines = stdout.strip().split('\n')
for line in lines[-10:]: # Check last 10 lines
if 'Best model:' in line:
st.info(f"πŸ“Š **{line}**")
elif any(keyword in line.lower() for keyword in ['accuracy', 'f1']):
if line.strip():
st.info(f"πŸ“ˆ **Performance:** {line}")
# Show full output in expander
with st.expander("πŸ“‹ Complete Training Log"):
st.code(stdout)
else:
st.error("❌ **Training Failed**")
with st.expander("πŸ” Error Details"):
st.code(stdout)
except Exception as e:
st.error(f"❌ **Training Error:** {str(e)}")
# Try to reload model in API regardless of training method
if app_manager.api_available:
try:
with st.spinner("Reloading model in API..."):
reload_response = app_manager.session.post(
f"{app_manager.config['api_url']}/model/reload",
timeout=30
)
if reload_response.status_code == 200:
st.success("βœ… **Model reloaded in API successfully!**")
else:
st.warning("⚠️ Model trained but API reload failed")
except Exception as e:
st.warning(f"⚠️ Model trained but API reload failed: {str(e)}")
# Training tips
st.markdown("---")
st.markdown("### πŸ’‘ Training Tips")
st.info("βœ“ **Model saved successfully** - You can now test predictions")
st.info("βœ“ **Try different datasets** to improve performance")
st.info("βœ“ **Larger datasets** (50+ samples) enable full hyperparameter tuning")
# Main application
def main():
"""Main Streamlit application"""
# Header
    st.markdown('<h1 class="main-header">📰 Fake News Detection System</h1>',
                unsafe_allow_html=True)
# API Status indicator
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
if app_manager.api_available:
            st.markdown(
                '<div class="success-message">🟢 API Service: Online</div>', unsafe_allow_html=True)
        else:
            st.markdown(
                '<div class="error-message">🔴 API Service: Offline</div>', unsafe_allow_html=True)
# Main content area
tab1, tab2, tab3, tab4, tab5 = st.tabs([
"πŸ” Prediction",
"πŸ“Š Batch Analysis",
"πŸ“ˆ Analytics",
"🎯 Model Training",
"βš™οΈ System Status"
])
# Tab 1: Individual Prediction
with tab1:
st.header("Single Text Analysis")
# Input methods
input_method = st.radio(
"Choose input method:",
["Type Text", "Upload File"],
horizontal=True
)
user_text = ""
if input_method == "Type Text":
user_text = st.text_area(
"Enter news article text:",
height=200,
placeholder="Paste or type the news article you want to analyze..."
)
else: # Upload File
uploaded_file = st.file_uploader(
"Upload text file:",
type=['txt', 'csv'],
help="Upload a text file containing the article to analyze"
)
        if uploaded_file:
            try:
                if uploaded_file.type == "text/plain":
                    user_text = str(uploaded_file.read(), "utf-8")
                elif uploaded_file.type == "text/csv":
                    df = pd.read_csv(uploaded_file)
                    if 'text' in df.columns:
                        user_text = df['text'].iloc[0] if len(df) > 0 else ""
                    else:
                        st.error("CSV file must contain a 'text' column")
                # Only report success once text was actually extracted
                if user_text:
                    st.success(
                        f"File uploaded successfully! ({len(user_text)} characters)")
            except Exception as e:
                st.error(f"Error reading file: {e}")
# Prediction section
col1, col2 = st.columns([3, 1])
with col1:
if st.button("🧠 Analyze Text", type="primary", use_container_width=True):
if user_text:
# Validate input
is_valid, validation_message = validate_text_input(
user_text)
if not is_valid:
st.error(validation_message)
else:
# Show progress
with st.spinner("Analyzing text..."):
result = make_prediction_request(user_text)
if 'error' in result:
st.error(f"❌ {result['error']}")
else:
# Display results
prediction = result['prediction']
confidence = result['confidence']
# Save to history
save_prediction_to_history(
user_text, prediction, confidence)
# Results display
col_result1, col_result2 = st.columns(2)
with col_result1:
if prediction == "Fake":
st.markdown(f"""
<div class="error-message">
<h3>🚨 Prediction: FAKE NEWS</h3>
<p>Confidence: {confidence:.2%}</p>
</div>
""", unsafe_allow_html=True)
else:
st.markdown(f"""
<div class="success-message">
                                        <h3>✅ Prediction: REAL NEWS</h3>
<p>Confidence: {confidence:.2%}</p>
</div>
""", unsafe_allow_html=True)
with col_result2:
# Confidence gauge
fig_gauge = create_confidence_gauge(
confidence, prediction)
st.plotly_chart(
fig_gauge, use_container_width=True)
# Additional information
with st.expander("πŸ“‹ Analysis Details"):
st.json({
"model_version": result.get('model_version', 'Unknown'),
"processing_time": f"{result.get('processing_time', 0):.3f} seconds",
"timestamp": result.get('timestamp', ''),
"text_length": len(user_text),
"word_count": len(user_text.split())
})
else:
st.warning("Please enter text to analyze.")
        with col2:
            if st.button("🔄 Clear Text", use_container_width=True):
                # Note: a bare rerun does not reset the text_area contents,
                # since the widget keeps its state; clearing it would require
                # giving the widget a key and resetting that key here.
                st.rerun()
# Tab 2: Batch Analysis
with tab2:
st.header("Batch Text Analysis")
# File upload for batch processing
batch_file = st.file_uploader(
"Upload CSV file for batch analysis:",
type=['csv'],
help="CSV file should contain a 'text' column with articles to analyze"
)
if batch_file:
try:
df = pd.read_csv(batch_file)
if 'text' not in df.columns:
st.error("CSV file must contain a 'text' column")
else:
st.success(f"File loaded: {len(df)} articles found")
# Preview data
st.subheader("Data Preview")
st.dataframe(df.head(10))
# Batch processing
if st.button("πŸš€ Process Batch", type="primary"):
if len(df) > app_manager.config['max_batch_size']:
st.warning(
f"Only processing first {app_manager.config['max_batch_size']} articles")
df = df.head(app_manager.config['max_batch_size'])
progress_bar = st.progress(0)
status_text = st.empty()
results = []
for i, row in df.iterrows():
status_text.text(
f"Processing article {i+1}/{len(df)}...")
progress_bar.progress((i + 1) / len(df))
result = make_prediction_request(row['text'])
                        # Truncate long texts once for display in results
                        display_text = row['text'][:100] + "..." if len(row['text']) > 100 else row['text']
                        if 'error' not in result:
                            results.append({
                                'text': display_text,
                                'prediction': result['prediction'],
                                'confidence': result['confidence'],
                                'processing_time': result.get('processing_time', 0)
                            })
                        else:
                            results.append({
                                'text': display_text,
                                'prediction': 'Error',
                                'confidence': 0,
                                'processing_time': 0
                            })
# Display results
results_df = pd.DataFrame(results)
# Summary statistics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Processed", len(results_df))
with col2:
fake_count = len(
results_df[results_df['prediction'] == 'Fake'])
st.metric("Fake News", fake_count)
with col3:
real_count = len(
results_df[results_df['prediction'] == 'Real'])
st.metric("Real News", real_count)
with col4:
avg_confidence = results_df['confidence'].mean()
st.metric("Avg Confidence",
f"{avg_confidence:.2%}")
# Results visualization
if len(results_df) > 0:
fig = px.histogram(
results_df,
x='prediction',
color='prediction',
title="Batch Analysis Results"
)
st.plotly_chart(fig, use_container_width=True)
# Download results
csv_buffer = io.StringIO()
results_df.to_csv(csv_buffer, index=False)
st.download_button(
label="πŸ“₯ Download Results",
data=csv_buffer.getvalue(),
file_name=f"batch_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
mime="text/csv"
)
except Exception as e:
st.error(f"Error processing file: {e}")
# Tab 3: Analytics
with tab3:
st.header("System Analytics")
# Prediction history
if st.session_state.prediction_history:
st.subheader("Recent Predictions")
# History chart
fig_history = create_prediction_history_chart()
if fig_history:
st.plotly_chart(fig_history, use_container_width=True)
# History table
history_df = pd.DataFrame(st.session_state.prediction_history)
st.dataframe(history_df.tail(20), use_container_width=True)
else:
st.info(
"No prediction history available. Make some predictions to see analytics.")
# System metrics
st.subheader("System Metrics")
# Load various log files for analytics
try:
# API health check
if app_manager.api_available:
response = app_manager.session.get(
f"{app_manager.config['api_url']}/metrics")
if response.status_code == 200:
metrics = response.json()
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total API Requests",
metrics.get('total_requests', 0))
with col2:
st.metric("Unique Clients", metrics.get(
'unique_clients', 0))
with col3:
st.metric("Model Version", metrics.get(
'model_version', 'Unknown'))
with col4:
status = metrics.get('model_health', 'unknown')
st.metric("Model Status", status)
except Exception as e:
st.warning(f"Could not load API metrics: {e}")
# Tab 4: Model Training
with tab4:
# File upload for training
training_file = st.file_uploader(
"Upload training dataset (CSV):",
type=['csv'],
help="CSV file should contain 'text' and 'label' columns (label: 0=Real, 1=Fake)"
)
if training_file:
try:
df_train = pd.read_csv(training_file)
required_columns = ['text', 'label']
missing_columns = [
col for col in required_columns if col not in df_train.columns]
if missing_columns:
st.error(f"Missing required columns: {missing_columns}")
else:
st.success(
f"Training file loaded: {len(df_train)} samples")
# Enhanced training section
render_enhanced_training_section(df_train)
except Exception as e:
st.error(f"Error loading training file: {e}")
# Tab 5: System Status
with tab5:
render_system_status()
def render_system_status():
"""Render system status tab"""
st.header("System Status & Monitoring")
# Auto-refresh toggle
col1, col2 = st.columns([1, 4])
with col1:
st.session_state.auto_refresh = st.checkbox(
"Auto Refresh", value=st.session_state.auto_refresh)
with col2:
if st.button("πŸ”„ Refresh Now"):
st.session_state.last_refresh = datetime.now()
st.rerun()
# System health overview
st.subheader("πŸ₯ System Health")
if app_manager.api_available:
try:
health_response = app_manager.session.get(
f"{app_manager.config['api_url']}/health")
if health_response.status_code == 200:
health_data = health_response.json()
# Overall status
overall_status = health_data.get('status', 'unknown')
                if overall_status == 'healthy':
                    st.success("🟢 System Status: Healthy")
                else:
                    st.error("🔴 System Status: Unhealthy")
# Detailed health metrics
col1, col2, col3 = st.columns(3)
with col1:
st.subheader("πŸ€– Model Health")
model_health = health_data.get('model_health', {})
for key, value in model_health.items():
if key != 'test_prediction':
st.write(
f"**{key.replace('_', ' ').title()}:** {value}")
with col2:
st.subheader("πŸ’» System Resources")
system_health = health_data.get('system_health', {})
for key, value in system_health.items():
if isinstance(value, (int, float)):
st.metric(key.replace('_', ' ').title(),
f"{value:.1f}%")
with col3:
st.subheader("πŸ”— API Health")
api_health = health_data.get('api_health', {})
for key, value in api_health.items():
st.write(
f"**{key.replace('_', ' ').title()}:** {value}")
except Exception as e:
st.error(f"Failed to get health status: {e}")
else:
st.error("πŸ”΄ API Service is not available")
# Model information
st.subheader("🎯 Model Information")
metadata = load_json_file(app_manager.paths['metadata'], {})
if metadata:
col1, col2 = st.columns(2)
with col1:
for key in ['model_version', 'test_accuracy', 'test_f1', 'model_type']:
if key in metadata:
display_key = key.replace('_', ' ').title()
value = metadata[key]
if isinstance(value, float):
st.metric(display_key, f"{value:.4f}")
else:
st.metric(display_key, str(value))
with col2:
for key in ['train_size', 'timestamp', 'data_version']:
if key in metadata:
display_key = key.replace('_', ' ').title()
value = metadata[key]
if key == 'timestamp':
try:
dt = datetime.fromisoformat(
value.replace('Z', '+00:00'))
value = dt.strftime('%Y-%m-%d %H:%M:%S')
                    except (ValueError, AttributeError):
                        pass
st.write(f"**{display_key}:** {value}")
else:
st.warning("No model metadata available")
# Recent activity
st.subheader("πŸ“œ Recent Activity")
activity_log = load_json_file(app_manager.paths['activity_log'], [])
if activity_log:
recent_activities = activity_log[-10:] if len(
activity_log) > 10 else activity_log
for entry in reversed(recent_activities):
timestamp = entry.get('timestamp', 'Unknown')
event = entry.get('event', 'Unknown event')
level = entry.get('level', 'INFO')
            if level == 'ERROR':
                st.error(f"🔴 {timestamp} - {event}")
            elif level == 'WARNING':
                st.warning(f"🟡 {timestamp} - {event}")
            else:
                st.info(f"🔵 {timestamp} - {event}")
else:
st.info("No recent activity logs found")
# File system status
st.subheader("πŸ“ File System Status")
critical_files = [
("/tmp/pipeline.pkl", "Pipeline Model"),
("/tmp/model.pkl", "Model Component"),
("/tmp/vectorizer.pkl", "Vectorizer"),
("/tmp/metadata.json", "Model Metadata"),
("/tmp/data/combined_dataset.csv", "Training Dataset")
]
col1, col2 = st.columns(2)
with col1:
st.write("**Critical Files:**")
for file_path, description in critical_files:
if Path(file_path).exists():
st.success(f"βœ… {description}")
else:
st.error(f"❌ {description}")
with col2:
# Disk usage information
try:
import shutil
total, used, free = shutil.disk_usage("/tmp")
st.write("**Disk Usage (/tmp):**")
st.write(f"Total: {total // (1024**3)} GB")
st.write(f"Used: {used // (1024**3)} GB")
st.write(f"Free: {free // (1024**3)} GB")
usage_percent = (used / total) * 100
if usage_percent > 90:
st.error(f"⚠️ Disk usage: {usage_percent:.1f}%")
elif usage_percent > 75:
st.warning(f"⚠️ Disk usage: {usage_percent:.1f}%")
else:
st.success(f"βœ… Disk usage: {usage_percent:.1f}%")
except Exception as e:
st.error(f"Cannot check disk usage: {e}")
# System actions
st.subheader("πŸ”§ System Actions")
col1, col2, col3 = st.columns(3)
with col1:
# Initialize system button
if st.button("πŸ”§ Initialize System", help="Run system initialization if components are missing"):
with st.spinner("Running system initialization..."):
try:
result = subprocess.run(
[sys.executable, "/app/initialize_system.py"],
capture_output=True,
text=True,
timeout=300
)
if result.returncode == 0:
                        st.success(
                            "✅ System initialization completed successfully!")
                        with st.expander("📋 Initialization Output"):
st.code(result.stdout)
time.sleep(2)
st.rerun()
else:
st.error("❌ System initialization failed")
st.code(result.stderr)
except subprocess.TimeoutExpired:
st.error("⏰ Initialization timed out")
except Exception as e:
st.error(f"❌ Initialization error: {e}")
with col2:
# Reload API model
if st.button("πŸ”„ Reload API Model", help="Reload the model in the API service"):
if app_manager.api_available:
try:
with st.spinner("Reloading model in API..."):
reload_response = app_manager.session.post(
f"{app_manager.config['api_url']}/model/reload",
timeout=30
)
if reload_response.status_code == 200:
st.success("βœ… Model reloaded successfully!")
st.json(reload_response.json())
else:
st.error(f"❌ Model reload failed: {reload_response.status_code}")
except Exception as e:
st.error(f"❌ Model reload error: {e}")
else:
st.error("❌ API service not available")
with col3:
# Clear cache
if st.button("πŸ—‘οΈ Clear Cache", help="Clear prediction history and temporary data"):
try:
# Clear session state
st.session_state.prediction_history = []
st.session_state.upload_history = []
# Clear temporary files
temp_files = [
"/tmp/custom_upload.csv",
"/tmp/prediction_log.json"
]
cleared_count = 0
for temp_file in temp_files:
if Path(temp_file).exists():
Path(temp_file).unlink()
cleared_count += 1
st.success(f"βœ… Cache cleared! Removed {cleared_count} temporary files")
time.sleep(1)
st.rerun()
except Exception as e:
st.error(f"❌ Cache clear error: {e}")
# Auto-refresh logic
if st.session_state.auto_refresh:
time_since_refresh = datetime.now() - st.session_state.last_refresh
if time_since_refresh > timedelta(seconds=app_manager.config['refresh_interval']):
st.session_state.last_refresh = datetime.now()
st.rerun()
# Footer
st.markdown("---")
st.markdown("""
<div style='text-align: center; color: #666; padding: 20px;'>
        <p>📰 <strong>Fake News Detection System</strong> | Advanced MLOps Pipeline</p>
<p>Built with Streamlit, FastAPI, and Scikit-learn | Production-ready with comprehensive monitoring</p>
</div>
""", unsafe_allow_html=True)
# Run main application
if __name__ == "__main__":
main()