#!/usr/bin/env python3
"""
Demo script showing how MLM probability affects encoder model analysis
"""

import warnings

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

warnings.filterwarnings("ignore")


def demo_mlm_probability_effect():
    """Demonstrate how MLM probability affects the analysis"""
    print("🎭 MLM Probability Effect Demo")
    print("=" * 60)

    # Load a BERT model
    model_name = "distilbert-base-uncased"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    # Test text
    text = "The capital of France is Paris and it is beautiful."
    print(f"📝 Text: {text}")

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    print(f"🔤 Tokens: {tokens}")
    print()

    # Test different MLM probabilities
    mlm_probs = [0.1, 0.15, 0.3, 0.5, 0.8]

    for mlm_prob in mlm_probs:
        print(f"🎯 MLM Probability: {mlm_prob}")

        # Simulate the analysis process
        seq_length = input_ids.size(1)
        special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

        # Count how many tokens would be analyzed
        analyzed_count = 0
        analyzed_tokens = []

        torch.manual_seed(42)  # For reproducible results
        for i in range(seq_length):
            token = tokens[i]
            if input_ids[0, i].item() not in special_token_ids:
                if torch.rand(1).item() < mlm_prob:
                    analyzed_count += 1
                    analyzed_tokens.append(f"'{token}'")

        total_content_tokens = sum(
            1 for i in range(seq_length)
            if input_ids[0, i].item() not in special_token_ids
        )

        print(f"   📊 Analyzed: {analyzed_count}/{total_content_tokens} content tokens "
              f"({analyzed_count / total_content_tokens * 100:.1f}%)")
        print(f"   🎯 Analyzed tokens: {', '.join(analyzed_tokens[:5])}"
              + (f" + {len(analyzed_tokens) - 5} more" if len(analyzed_tokens) > 5 else ""))
        print()


def simulate_perplexity_calculation():
    """Simulate how different MLM probabilities affect perplexity calculation"""
    print("🧮 Perplexity Calculation Simulation")
    print("=" * 60)

    # Load model
    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    text = "Machine learning is transforming artificial intelligence rapidly."
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids

    print(f"📝 Text: {text}")
    print(f"🔤 Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
    print()

    mlm_probs = [0.15, 0.3, 0.5]

    for mlm_prob in mlm_probs:
        print(f"🎭 MLM Probability: {mlm_prob}")

        # Simulate multiple iterations
        iteration_results = []

        for iteration in range(3):
            # Simulate masking
            masked_input_ids = input_ids.clone()
            original_tokens = input_ids.clone()
            seq_length = input_ids.size(1)
            mask_indices = []
            special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

            torch.manual_seed(42 + iteration)  # Different seed per iteration
            for i in range(seq_length):
                if input_ids[0, i].item() not in special_token_ids:
                    if torch.rand(1).item() < mlm_prob:
                        mask_indices.append(i)
                        masked_input_ids[0, i] = tokenizer.mask_token_id

            if not mask_indices:
                # Ensure at least one token is masked
                non_special_indices = [
                    i for i in range(seq_length)
                    if input_ids[0, i].item() not in special_token_ids
                ]
                if non_special_indices:
                    mask_idx = torch.randint(0, len(non_special_indices), (1,)).item()
                    mask_indices = [non_special_indices[mask_idx]]
                    masked_input_ids[0, mask_indices[0]] = tokenizer.mask_token_id

            # Calculate pseudo-perplexity for masked tokens
            with torch.no_grad():
                outputs = model(masked_input_ids)
                predictions = outputs.logits

            masked_token_losses = []
            masked_tokens = []

            for idx in mask_indices:
                target_id = original_tokens[0, idx].item()
                pred_scores = predictions[0, idx]
                prob = torch.softmax(pred_scores, dim=-1)[target_id]
                loss = -torch.log(prob + 1e-10)
                masked_token_losses.append(loss.item())

                token = tokenizer.convert_ids_to_tokens([target_id])[0]
                masked_tokens.append(token)

            if masked_token_losses:
                avg_loss = sum(masked_token_losses) / len(masked_token_losses)
                perplexity = torch.exp(torch.tensor(avg_loss)).item()
                iteration_results.append(perplexity)

                print(f"   Iteration {iteration + 1}: {len(mask_indices)} tokens masked")
                print(f"      Masked: {', '.join(masked_tokens[:3])}"
                      + (f" + {len(masked_tokens) - 3} more" if len(masked_tokens) > 3 else ""))
                print(f"      Pseudo-perplexity: {perplexity:.2f}")

        if iteration_results:
            avg_perplexity = sum(iteration_results) / len(iteration_results)
            print(f"   📊 Average pseudo-perplexity: {avg_perplexity:.2f}")
        print()


def explain_mlm_probability():
    """Explain what MLM probability actually does"""
    print("💡 Understanding MLM Probability")
    print("=" * 60)

    print("""
🎭 **What is MLM Probability?**

MLM (Masked Language Modeling) probability controls what fraction of tokens
get randomly selected for detailed perplexity analysis.

📊 **How it works:**
• Low MLM prob (0.15): Analyzes ~15% of tokens randomly
• High MLM prob (0.5): Analyzes ~50% of tokens randomly
• This affects both the average perplexity AND the visualization

🎯 **Why it matters:**
• Higher MLM prob = More tokens analyzed = More complete picture
• Lower MLM prob = Fewer tokens analyzed = Faster but less comprehensive
• The randomness simulates real MLM training conditions

🌈 **Visual Effect:**
• Analyzed tokens: Colored by their actual perplexity
• Non-analyzed tokens: Shown in gray (baseline)
• Try 0.15 vs 0.5 to see the difference!

⚖️ **Trade-offs:**
• MLM 0.15: Fast, matches BERT training, but sparse analysis
• MLM 0.5: Slower, more comprehensive, but artificial
• MLM 0.8: Very slow, nearly complete, but unrealistic
""")


def main():
    """Run MLM probability demonstration"""
    try:
        explain_mlm_probability()
        demo_mlm_probability_effect()
        simulate_perplexity_calculation()

        print("🎉 MLM Probability Demo Complete!")
        print("💡 Now try the app with different MLM probabilities:")
        print("   • Use 0.15 for standard analysis")
        print("   • Use 0.5 for more comprehensive analysis")
        print("   • Watch how the visualization changes!")

    except Exception as e:
        print(f"❌ Demo failed: {e}")
        print("💡 Make sure you have transformers installed: pip install transformers")


if __name__ == "__main__":
    main()
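
# ---------------------------------------------------------------------------
# Optional extra (illustrative sketch, not called by main() above).
# The demo samples tokens at a given MLM probability, so results vary from
# run to run. The deterministic alternative is the full pseudo-perplexity
# (pseudo-log-likelihood) score: mask every content token one at a time and
# average the losses. The helper below is a minimal sketch of that idea under
# the same tokenizer/model setup used elsewhere in this script; the name
# `full_pseudo_perplexity` is our own, not a transformers API.
def full_pseudo_perplexity(text: str, tokenizer, model) -> float:
    """Mask each non-special token in turn and exponentiate the mean loss."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids
    special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

    losses = []
    for i in range(input_ids.size(1)):
        if input_ids[0, i].item() in special_token_ids:
            continue
        masked = input_ids.clone()
        masked[0, i] = tokenizer.mask_token_id  # mask exactly one token per forward pass
        with torch.no_grad():
            logits = model(masked).logits
        # Negative log-probability of the original token at the masked position
        log_probs = torch.log_softmax(logits[0, i], dim=-1)
        losses.append(-log_probs[input_ids[0, i]].item())

    if not losses:
        return float("nan")
    # exp(mean negative log-likelihood) over all content tokens
    return torch.exp(torch.tensor(sum(losses) / len(losses))).item()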