#!/usr/bin/env python3
"""
Example usage of the quantized ONNX LazarusNLP IndoBERT model.
Demonstrates basic inference, batch processing, and similarity computation.
"""

import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
import time
from sklearn.metrics.pairwise import cosine_similarity


def load_model(model_path="./"):
    """Load the quantized ONNX model and tokenizer."""
    print("Loading quantized ONNX model...")

    # Load ONNX session
    session = ort.InferenceSession(f"{model_path}/model.onnx")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print("✓ Model loaded successfully")
    print(f"✓ Tokenizer max length: {tokenizer.model_max_length}")

    return session, tokenizer


def get_embeddings(session, tokenizer, texts, pool_strategy="mean"):
    """
    Get embeddings for texts using the ONNX model.

    Args:
        session: ONNX inference session
        tokenizer: HuggingFace tokenizer
        texts: List of texts or single text
        pool_strategy: Pooling strategy ('mean', 'cls', 'max')

    Returns:
        numpy array of embeddings
    """
    if isinstance(texts, str):
        texts = [texts]

    # Tokenize
    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True)

    # Run inference
    outputs = session.run(None, {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask']
    })

    # Extract embeddings
    hidden_states = outputs[0]  # Shape: [batch_size, seq_len, hidden_size]
    attention_mask = inputs['attention_mask']

    if pool_strategy == "mean":
        # Mean pooling with attention mask
        mask_expanded = np.expand_dims(attention_mask, axis=-1)
        masked_embeddings = hidden_states * mask_expanded
        sum_embeddings = np.sum(masked_embeddings, axis=1)
        sum_mask = np.sum(mask_expanded, axis=1)
        embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
    elif pool_strategy == "cls":
        # Use [CLS] token embedding
        embeddings = hidden_states[:, 0, :]
    elif pool_strategy == "max":
        # Max pooling
        embeddings = np.max(hidden_states, axis=1)
    else:
        raise ValueError(f"Unknown pooling strategy: {pool_strategy}")

    return embeddings
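
# Optional, hedged variant: depending on how the checkpoint was exported, the ONNX
# graph may also declare a 'token_type_ids' input that get_embeddings() above does
# not pass. This hypothetical helper (not part of the upstream example) builds the
# feed dict from whatever inputs the exported graph actually declares, using
# onnxruntime's session.get_inputs().
def build_onnx_feed(session, inputs):
    """Match tokenizer outputs to the input names declared by the ONNX graph."""
    graph_input_names = {node.name for node in session.get_inputs()}
    return {name: array for name, array in inputs.items() if name in graph_input_names}
# Usage sketch inside get_embeddings:
#     outputs = session.run(None, build_onnx_feed(session, inputs))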
] print(f"Processing {len(texts)} texts...") start_time = time.time() embeddings = get_embeddings(session, tokenizer, texts) batch_time = time.time() - start_time print(f"Batch embedding shape: {embeddings.shape}") print(f"Batch processing time: {batch_time:.4f}s") print(f"Average time per text: {batch_time/len(texts):.4f}s") return embeddings, texts def example_similarity_search(): """Similarity search example.""" print("\n" + "="*50) print("SIMILARITY SEARCH EXAMPLE") print("="*50) # Load model session, tokenizer = load_model() # Documents for similarity search documents = [ "AI dan machine learning mengubah cara kerja industri teknologi.", "Kecerdasan buatan membantu otomatisasi proses bisnis modern.", "Nasi rendang adalah makanan tradisional Indonesia yang lezat.", "Kuliner Indonesia memiliki cita rasa yang unik dan beragam.", "Deep learning adalah subset dari machine learning yang powerful.", "Pantai Lombok menawarkan pemandangan yang menakjubkan.", ] query = "Teknologi AI untuk bisnis" print(f"Query: {query}") print(f"Searching in {len(documents)} documents...") # Get embeddings query_embedding = get_embeddings(session, tokenizer, query) doc_embeddings = get_embeddings(session, tokenizer, documents) # Calculate similarities similarities = cosine_similarity(query_embedding, doc_embeddings)[0] # Sort by similarity ranked_docs = sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True) print("\nTop 3 most similar documents:") for i, (doc, sim) in enumerate(ranked_docs[:3]): print(f"{i+1}. Similarity: {sim:.4f}") print(f" Document: {doc}") def example_long_text_processing(): """Long text processing example.""" print("\n" + "="*50) print("LONG TEXT PROCESSING EXAMPLE") print("="*50) # Load model session, tokenizer = load_model() # Create long text long_text = """ Perkembangan teknologi artificial intelligence di Indonesia menunjukkan tren yang sangat positif dengan banyaknya startup dan perusahaan teknologi yang mulai mengadopsi solusi berbasis AI untuk meningkatkan efisiensi operasional, customer experience, dan inovasi produk. Industri fintech, e-commerce, dan healthcare menjadi sektor yang paling aktif dalam implementasi AI. Pemerintah Indonesia juga mendukung ekosistem AI melalui berbagai program dan kebijakan yang mendorong transformasi digital. Universitas dan institusi penelitian berkontribusi dalam pengembangan talenta AI berkualitas. Tantangan yang dihadapi meliputi ketersediaan data berkualitas, infrastruktur teknologi, dan regulasi yang mendukung inovasi namun tetap melindungi privasi dan keamanan data. Kolaborasi antara pemerintah, industri, dan akademisi menjadi kunci sukses pengembangan AI di Indonesia untuk mencapai visi Indonesia 2045 sebagai negara maju. 
""" print(f"Processing long text ({len(long_text)} characters)...") # Process with different pooling strategies strategies = ["mean", "cls", "max"] for strategy in strategies: start_time = time.time() embeddings = get_embeddings(session, tokenizer, long_text.strip(), pool_strategy=strategy) process_time = time.time() - start_time print(f"Pooling: {strategy:4s} | Shape: {embeddings.shape} | Time: {process_time:.4f}s") def example_performance_benchmark(): """Performance benchmark example.""" print("\n" + "="*50) print("PERFORMANCE BENCHMARK") print("="*50) # Load model session, tokenizer = load_model() # Test texts of different lengths test_cases = [ ("Short", "Halo dunia!"), ("Medium", "Teknologi AI berkembang sangat pesat dan mengubah berbagai industri di seluruh dunia."), ("Long", " ".join(["Kalimat panjang dengan banyak kata untuk menguji performa model."] * 20)) ] print("Benchmarking different text lengths...") for name, text in test_cases: times = [] # Warm up get_embeddings(session, tokenizer, text) # Benchmark for _ in range(10): start_time = time.time() embeddings = get_embeddings(session, tokenizer, text) times.append(time.time() - start_time) avg_time = np.mean(times) std_time = np.std(times) token_count = len(tokenizer.encode(text)) print(f"{name:6s} ({token_count:3d} tokens): {avg_time:.4f}s ± {std_time:.4f}s") def validate_model(): """Validate model functionality.""" print("\n" + "="*50) print("MODEL VALIDATION") print("="*50) try: # Load model session, tokenizer = load_model() # Test basic functionality test_text = "Tes validasi model ONNX." embeddings = get_embeddings(session, tokenizer, test_text) # Validation checks assert embeddings.shape[0] == 1, "Batch size should be 1" assert embeddings.shape[1] == 768, "Hidden size should be 768" assert not np.isnan(embeddings).any(), "No NaN values allowed" assert not np.isinf(embeddings).any(), "No Inf values allowed" print("✅ Model validation passed!") print(f"✅ Output shape: {embeddings.shape}") print(f"✅ Output range: [{embeddings.min():.4f}, {embeddings.max():.4f}]") except Exception as e: print(f"❌ Model validation failed: {e}") raise def main(): """Run all examples.""" print("🚀 LazarusNLP IndoBERT ONNX - Example Usage") # Validate model first validate_model() # Run examples example_basic_usage() example_batch_processing() example_similarity_search() example_long_text_processing() example_performance_benchmark() print("\n" + "="*50) print("🎉 All examples completed successfully!") print("="*50) if __name__ == "__main__": main()