import joblib
import onnxruntime as ort
import numpy as np
from pathlib import Path
from typing import Dict, Any, List
import logging
import re
import warnings

# Suppress sklearn warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*sklearn.*")

logger = logging.getLogger(__name__)


class MLManager:
    """Centralized ML model manager for SafeSpace threat detection"""

    def __init__(self, models_dir: str = "models"):
        self.models_dir = Path(models_dir)
        self.models_loaded = False

        # Model instances
        self.threat_model = None
        self.sentiment_model = None
        self.onnx_session = None
        self.threat_vectorizer = None
        self.sentiment_vectorizer = None

        # Model paths
        self.model_paths = {
            "threat": self.models_dir / "Threat.pkl",
            "sentiment": self.models_dir / "sentiment.pkl",
            "context": self.models_dir / "contextClassifier.onnx"
        }

        # Initialize models
        self._load_models()

    def _load_models(self) -> bool:
        """Load all ML models"""
        try:
            logger.info("Loading ML models...")

            # Load threat detection model
            if self.model_paths["threat"].exists():
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        self.threat_model = joblib.load(self.model_paths["threat"])
                    logger.info("✅ Threat model loaded successfully")
                except Exception as e:
                    logger.warning(f"⚠️ Failed to load threat model: {e}")
                    self.threat_model = None
            else:
                logger.error(f"❌ Threat model not found: {self.model_paths['threat']}")

            # Load sentiment analysis model
            if self.model_paths["sentiment"].exists():
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        self.sentiment_model = joblib.load(self.model_paths["sentiment"])
                    logger.info("✅ Sentiment model loaded successfully")
                except Exception as e:
                    logger.warning(f"⚠️ Failed to load sentiment model: {e}")
                    self.sentiment_model = None
            else:
                logger.error(f"❌ Sentiment model not found: {self.model_paths['sentiment']}")

            # Load ONNX context classifier
            if self.model_paths["context"].exists():
                try:
                    self.onnx_session = ort.InferenceSession(
                        str(self.model_paths["context"]),
                        providers=['CPUExecutionProvider']  # Specify CPU provider
                    )
                    logger.info("✅ ONNX context classifier loaded successfully")
                except Exception as e:
                    logger.warning(f"⚠️ Failed to load ONNX model: {e}")
                    self.onnx_session = None
            else:
                logger.error(f"❌ ONNX model not found: {self.model_paths['context']}")

            # Check if models are loaded
            models_available = [
                self.threat_model is not None,
                self.sentiment_model is not None,
                self.onnx_session is not None
            ]
            self.models_loaded = any(models_available)

            if self.models_loaded:
                logger.info(f"✅ ML Manager initialized with {sum(models_available)}/3 models")
            else:
                logger.warning("⚠️ No models loaded, falling back to rule-based detection")

            return self.models_loaded

        except Exception as e:
            logger.error(f"❌ Error loading models: {e}")
            self.models_loaded = False
            return False

    def _preprocess_text(self, text: str) -> str:
        """Preprocess text for model input"""
        if not text:
            return ""

        # Convert to lowercase
        text = text.lower()

        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Remove special characters but keep basic punctuation
        text = re.sub(r'[^\w\s\.,!?-]', '', text)

        return text
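
    # Example (illustrative): _preprocess_text("  Hello,   WORLD!!  ") lowercases,
    # collapses whitespace, and drops special characters while keeping .,!?- ,
    # returning "hello, world!!".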

    def predict_threat(self, text: str) -> Dict[str, Any]:
        """Main threat prediction using ensemble of models"""
        try:
            processed_text = self._preprocess_text(text)
            if not processed_text:
                return self._create_empty_prediction()

            predictions = {}
            confidence_scores = []
            models_used = []
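            # Ensemble weighting used below: threat classifier 50%, ONNX context
            # classifier 30%, negative sentiment up to 20%; an aviation-keyword
            # boost of +0.1 is applied to the combined score at the end.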

            # 1. Threat Detection Model
            threat_confidence = 0.0
            threat_prediction = 0

            if self.threat_model is not None:
                try:
                    # Ensure we have clean text input for threat detection
                    threat_input = processed_text if isinstance(processed_text, str) else str(processed_text)

                    # Handle different model prediction formats
                    raw_prediction = self.threat_model.predict([threat_input])

                    # Extract prediction value - handle both single values and arrays
                    if isinstance(raw_prediction, (list, np.ndarray)):
                        if len(raw_prediction) > 0:
                            pred_val = raw_prediction[0]
                            if isinstance(pred_val, (list, np.ndarray)) and len(pred_val) > 0:
                                threat_prediction = int(pred_val[0])
                            elif isinstance(pred_val, (int, float, np.integer, np.floating)):
                                threat_prediction = int(pred_val)
                            else:
                                logger.warning(f"Unexpected threat prediction format: {type(pred_val)} - {pred_val}")
                                threat_prediction = 0
                        else:
                            threat_prediction = 0
                    elif isinstance(raw_prediction, (int, float, np.integer, np.floating)):
                        threat_prediction = int(raw_prediction)
                    else:
                        logger.warning(f"Unexpected threat prediction type: {type(raw_prediction)} - {raw_prediction}")
                        threat_prediction = 0

                    # Get confidence if available
                    if hasattr(self.threat_model, 'predict_proba'):
                        threat_proba = self.threat_model.predict_proba([threat_input])[0]
                        threat_confidence = float(max(threat_proba))
                    else:
                        threat_confidence = 0.8 if threat_prediction == 1 else 0.2

                    predictions["threat"] = {
                        "prediction": threat_prediction,
                        "confidence": threat_confidence
                    }
                    confidence_scores.append(threat_confidence * 0.5)  # 50% weight
                    models_used.append("threat_classifier")

                except Exception as e:
                    logger.error(f"Threat model prediction failed: {e}")
                    # Provide fallback threat detection
                    threat_keywords = ['attack', 'violence', 'emergency', 'fire', 'accident', 'threat', 'danger', 'killed', 'death']
                    fallback_threat = 1 if any(word in processed_text for word in threat_keywords) else 0
                    fallback_confidence = 0.8 if fallback_threat == 1 else 0.2
                    predictions["threat"] = {
                        "prediction": fallback_threat,
                        "confidence": fallback_confidence
                    }
                    confidence_scores.append(fallback_confidence * 0.5)
                    models_used.append("fallback_threat")

            # 2. Sentiment Analysis Model
            sentiment_confidence = 0.0
            sentiment_prediction = 0

            if self.sentiment_model is not None:
                try:
                    # Ensure we have clean text input for sentiment analysis
                    sentiment_input = processed_text if isinstance(processed_text, str) else str(processed_text)

                    # Handle different model prediction formats
                    raw_prediction = self.sentiment_model.predict([sentiment_input])

                    # Label supplied directly by dict-style predictions, if any
                    sentiment_label = None

                    # Extract prediction value - handle both single values and arrays
                    if isinstance(raw_prediction, (list, np.ndarray)):
                        if len(raw_prediction) > 0:
                            pred_val = raw_prediction[0]
                            if isinstance(pred_val, (list, np.ndarray)) and len(pred_val) > 0:
                                # Handle numeric prediction values safely
                                try:
                                    sentiment_prediction = int(pred_val[0])
                                except (ValueError, TypeError):
                                    # Handle non-numeric predictions gracefully
                                    logger.debug(f"Non-numeric prediction value: {pred_val[0]}, using default")
                                    sentiment_prediction = 0
                            elif isinstance(pred_val, (int, float, np.integer, np.floating)):
                                # Handle numeric prediction values safely
                                try:
                                    sentiment_prediction = int(pred_val)
                                except (ValueError, TypeError):
                                    # Handle non-numeric predictions gracefully
                                    logger.debug(f"Non-numeric prediction value: {pred_val}, using default")
                                    sentiment_prediction = 0
                            elif isinstance(pred_val, dict):
                                # Handle dictionary prediction format (common with transformers models)
                                label = pred_val.get("label", "").lower()
                                score = pred_val.get("score", 0.0)

                                # Map emotions to binary sentiment (0=negative, 1=positive)
                                negative_emotions = ["fear", "anger", "sadness", "disgust"]
                                positive_emotions = ["joy", "surprise", "love", "happiness"]

                                if label in negative_emotions:
                                    sentiment_prediction = 0  # Negative
                                elif label in positive_emotions:
                                    sentiment_prediction = 1  # Positive
                                else:
                                    # Default handling for unknown labels
                                    sentiment_prediction = 0 if score < 0.5 else 1

                                # Use the score and label from the prediction
                                sentiment_confidence = float(score)
                                sentiment_label = label
                                logger.debug(f"Processed emotion '{label}' -> sentiment: {sentiment_prediction} (confidence: {sentiment_confidence})")
                            else:
                                logger.warning(f"Unexpected sentiment prediction format: {type(pred_val)} - {pred_val}")
                                sentiment_prediction = 0
                        else:
                            sentiment_prediction = 0
                    elif isinstance(raw_prediction, (int, float, np.integer, np.floating)):
                        # Handle single numeric prediction values safely
                        try:
                            sentiment_prediction = int(raw_prediction)
                        except (ValueError, TypeError):
                            # Handle non-numeric predictions gracefully
                            logger.debug(f"Non-numeric raw prediction: {raw_prediction}, using default")
                            sentiment_prediction = 0
                    else:
                        logger.warning(f"Unexpected sentiment prediction type: {type(raw_prediction)} - {raw_prediction}")
                        sentiment_prediction = 0

                    # Get confidence unless the dict-style prediction already supplied one
                    if sentiment_confidence == 0.0:
                        if hasattr(self.sentiment_model, 'predict_proba'):
                            sentiment_proba = self.sentiment_model.predict_proba([sentiment_input])[0]
                            sentiment_confidence = float(max(sentiment_proba))
                        else:
                            sentiment_confidence = 0.7 if sentiment_prediction == 0 else 0.3  # Negative sentiment = higher threat

                    # Determine sentiment label unless the prediction already carried one
                    if sentiment_label is None:
                        sentiment_label = "negative" if sentiment_prediction == 0 else "positive"

                    predictions["sentiment"] = {
                        "prediction": sentiment_prediction,
                        "confidence": sentiment_confidence,
                        "label": sentiment_label
                    }

                    # Negative sentiment contributes to threat score
                    sentiment_threat_score = (1 - sentiment_prediction) * sentiment_confidence * 0.2  # 20% weight
                    confidence_scores.append(sentiment_threat_score)
                    models_used.append("sentiment_classifier")

                except Exception as e:
                    logger.error(f"Sentiment model prediction failed: {e}")
                    # Provide fallback sentiment analysis
                    negative_words = ['attack', 'violence', 'death', 'killed', 'emergency', 'fire', 'accident', 'threat']
                    fallback_sentiment = 0 if any(word in processed_text for word in negative_words) else 1
                    predictions["sentiment"] = {
                        "prediction": fallback_sentiment,
                        "confidence": 0.6,
                        "label": "negative" if fallback_sentiment == 0 else "positive"
                    }
                    sentiment_threat_score = (1 - fallback_sentiment) * 0.6 * 0.2
                    confidence_scores.append(sentiment_threat_score)
                    models_used.append("fallback_sentiment")

            # 3. ONNX Context Classifier
            onnx_confidence = 0.0
            onnx_prediction = 0

            if self.onnx_session is not None:
                try:
                    # Check what inputs the ONNX model expects
                    input_names = [inp.name for inp in self.onnx_session.get_inputs()]

                    if 'input_ids' in input_names and 'attention_mask' in input_names:
                        # This is likely a transformer model (BERT-like)
                        # Create simple tokenized input (basic approach)
                        tokens = processed_text.split()[:50]  # Limit to 50 tokens

                        # Simple word-to-ID mapping (this is a fallback approach)
                        input_ids = [hash(word) % 1000 + 1 for word in tokens]  # Simple hash-based IDs
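                        # NOTE: hash-based IDs will not match the vocabulary the
                        # transformer was trained with, so this branch yields rough
                        # signals at best; the model's real tokenizer would be
                        # needed for faithful predictions.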

                        # Pad or truncate to fixed length
                        max_length = 128
                        if len(input_ids) < max_length:
                            input_ids.extend([0] * (max_length - len(input_ids)))
                        else:
                            input_ids = input_ids[:max_length]

                        attention_mask = [1 if i != 0 else 0 for i in input_ids]

                        # Convert to numpy arrays with correct shape
                        input_ids_array = np.array([input_ids], dtype=np.int64)
                        attention_mask_array = np.array([attention_mask], dtype=np.int64)

                        inputs = {
                            'input_ids': input_ids_array,
                            'attention_mask': attention_mask_array
                        }
                        onnx_output = self.onnx_session.run(None, inputs)

                        # Extract prediction from output
                        if len(onnx_output) > 0 and len(onnx_output[0]) > 0:
                            # Handle different output formats
                            output = onnx_output[0][0]
                            if isinstance(output, (list, np.ndarray)) and len(output) > 1:
                                # Probability output
                                probs = output
                                onnx_prediction = int(np.argmax(probs))
                                onnx_confidence = float(max(probs))
                            else:
                                # Single value output
                                onnx_prediction = int(output > 0.5)
                                onnx_confidence = float(abs(output))
                    else:
                        # Use the original simple feature approach
                        input_name = input_names[0] if input_names else 'input'
                        text_features = self._text_to_features(processed_text)
                        onnx_output = self.onnx_session.run(None, {input_name: text_features})
                        onnx_prediction = int(onnx_output[0][0]) if len(onnx_output[0]) > 0 else 0
                        onnx_confidence = float(onnx_output[1][0][1]) if len(onnx_output) > 1 else 0.5

                    predictions["onnx"] = {
                        "prediction": onnx_prediction,
                        "confidence": onnx_confidence
                    }
                    confidence_scores.append(onnx_confidence * 0.3)  # 30% weight
                    models_used.append("context_classifier")

                except Exception as e:
                    logger.error(f"ONNX model prediction failed: {e}")
                    # Provide fallback based on keyword analysis
                    threat_keywords = ['emergency', 'attack', 'violence', 'fire', 'accident', 'threat', 'danger']
                    fallback_confidence = len([w for w in threat_keywords if w in processed_text]) / len(threat_keywords)
                    fallback_prediction = 1 if fallback_confidence > 0.3 else 0
                    predictions["onnx"] = {
                        "prediction": fallback_prediction,
                        "confidence": fallback_confidence
                    }
                    confidence_scores.append(fallback_confidence * 0.3)
                    models_used.append("fallback_context")

            # Calculate final confidence score
            final_confidence = sum(confidence_scores) if confidence_scores else 0.0

            # Apply aviation content boost (as mentioned in your demo)
            aviation_keywords = ['flight', 'aircraft', 'aviation', 'airline', 'pilot', 'crash', 'airport']
            if any(keyword in processed_text for keyword in aviation_keywords):
                final_confidence = min(final_confidence + 0.1, 1.0)  # +10% boost

            # Determine if it's a threat
            is_threat = final_confidence >= 0.6 or threat_prediction == 1

            return {
                "is_threat": is_threat,
                "final_confidence": final_confidence,
                "threat_prediction": threat_prediction,
                "sentiment_analysis": predictions.get("sentiment"),
                "onnx_prediction": predictions.get("onnx"),
                "models_used": models_used,
                "raw_predictions": predictions
            }

        except Exception as e:
            logger.error(f"Error in threat prediction: {e}")
            return self._create_empty_prediction()

    def _text_to_features(self, text: str) -> np.ndarray:
        """Convert text to numerical features for ONNX model"""
        try:
            # Simple feature extraction - you may need to adjust based on your ONNX model requirements
            # This is a basic approach; you might need to match your training preprocessing

            # Basic text statistics
            features = [
                len(text),           # text length
                len(text.split()),   # word count
                text.count('!'),     # exclamation marks
                text.count('?'),     # question marks
                text.count('.'),     # periods
            ]

            # Add more features as needed for your specific ONNX model
            # You might need to use the same vectorizer that was used during training
            return np.array([features], dtype=np.float32)

        except Exception as e:
            logger.error(f"Error creating features: {e}")
            return np.array([[0.0, 0.0, 0.0, 0.0, 0.0]], dtype=np.float32)
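
    # Example (illustrative): _text_to_features("fire downtown! evacuate now.")
    # -> array([[28., 4., 1., 0., 1.]], dtype=float32)
    #    (text length, word count, '!' count, '?' count, '.' count)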

    def _create_empty_prediction(self) -> Dict[str, Any]:
        """Create empty prediction result"""
        return {
            "is_threat": False,
            "final_confidence": 0.0,
            "threat_prediction": 0,
            "sentiment_analysis": None,
            "onnx_prediction": None,
            "models_used": [],
            "raw_predictions": {}
        }

    def get_status(self) -> Dict[str, Any]:
        """Get status of all models"""
        return {
            "models_loaded": self.models_loaded,
            "threat_model": self.threat_model is not None,
            "sentiment_model": self.sentiment_model is not None,
            "onnx_model": self.onnx_session is not None,
            "models_dir": str(self.models_dir),
            "model_files": {
                name: path.exists() for name, path in self.model_paths.items()
            }
        }

    def analyze_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Analyze multiple texts in batch"""
        return [self.predict_threat(text) for text in texts]
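

# Minimal usage sketch (illustrative): assumes Threat.pkl, sentiment.pkl, and
# contextClassifier.onnx sit under ./models as configured above; missing files
# simply reduce the ensemble to its keyword fallbacks.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = MLManager(models_dir="models")
    print(manager.get_status())

    result = manager.predict_threat("Emergency: fire reported near the airport runway")
    print(result["is_threat"], round(result["final_confidence"], 3), result["models_used"])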