# LazarusNLP-indobert-onnx / example_usage.py
# NOTE: the following lines are Hugging Face Hub page metadata captured with
# the file (uploader: asmud; commit e9e7e23, verified; "Upload folder using
# huggingface_hub") — commented out so the script remains valid Python.
#!/usr/bin/env python3
"""
Example usage of the quantized ONNX LazarusNLP IndoBERT model.
Demonstrates basic inference, batch processing, and similarity computation.
"""
import os
import time

import numpy as np
import onnxruntime as ort
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer
def load_model(model_path="./"):
    """Load the quantized ONNX model and its tokenizer.

    Args:
        model_path: Directory containing ``model.onnx`` and the HuggingFace
            tokenizer files. Defaults to the current directory.

    Returns:
        Tuple ``(session, tokenizer)`` — an ``onnxruntime.InferenceSession``
        and the matching ``AutoTokenizer``.
    """
    print("Loading quantized ONNX model...")
    # os.path.join instead of f-string concatenation: avoids the double
    # slash produced by the default "./" and is robust to trailing slashes.
    session = ort.InferenceSession(os.path.join(model_path, "model.onnx"))
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Plain string — the original used an f-string with no placeholders.
    print("βœ“ Model loaded successfully")
    print(f"βœ“ Tokenizer max length: {tokenizer.model_max_length}")
    return session, tokenizer
def get_embeddings(session, tokenizer, texts, pool_strategy="mean"):
    """
    Encode text into fixed-size vectors with the ONNX model.

    Args:
        session: ONNX inference session
        tokenizer: HuggingFace tokenizer
        texts: List of texts or single text
        pool_strategy: Pooling strategy ('mean', 'cls', 'max')

    Returns:
        numpy array of embeddings, one row per input text

    Raises:
        ValueError: if pool_strategy is not a supported name
    """
    batch = [texts] if isinstance(texts, str) else texts

    # Tokenize the whole batch at once (padded to the longest sequence).
    encoded = tokenizer(batch, return_tensors="np", padding=True, truncation=True)

    # Forward pass through the ONNX graph.
    model_out = session.run(None, {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
    })
    token_vectors = model_out[0]  # [batch_size, seq_len, hidden_size]
    mask = encoded['attention_mask']

    if pool_strategy == "mean":
        # Average token vectors, ignoring padding positions via the mask.
        mask3 = mask[:, :, np.newaxis]
        summed = (token_vectors * mask3).sum(axis=1)
        counts = mask3.sum(axis=1)
        return summed / np.maximum(counts, 1e-9)
    if pool_strategy == "cls":
        # First token ([CLS]) as the sentence representation.
        return token_vectors[:, 0, :]
    if pool_strategy == "max":
        # Element-wise maximum over the sequence dimension.
        return token_vectors.max(axis=1)
    raise ValueError(f"Unknown pooling strategy: {pool_strategy}")
def example_basic_usage():
    """Embed a single Indonesian sentence and report timing."""
    print("\n" + "=" * 50)
    print("BASIC USAGE EXAMPLE")
    print("=" * 50)

    session, tokenizer = load_model()

    sample = "Teknologi kecerdasan buatan berkembang sangat pesat di Indonesia."
    started = time.time()
    vectors = get_embeddings(session, tokenizer, sample)
    elapsed = time.time() - started

    print(f"Input text: {sample}")
    print(f"Embedding shape: {vectors.shape}")
    print(f"Inference time: {elapsed:.4f}s")
    print(f"Sample embedding values: {vectors[0][:5]}")
def example_batch_processing():
    """Embed several sentences in one batched call and report timing."""
    print("\n" + "=" * 50)
    print("BATCH PROCESSING EXAMPLE")
    print("=" * 50)

    session, tokenizer = load_model()

    texts = [
        "Saya suka makan nasi gudeg.",
        "Artificial intelligence adalah teknologi masa depan.",
        "Indonesia memiliki kebudayaan yang sangat beragam.",
        "Machine learning membantu menganalisis data besar.",
        "Pantai Bali sangat indah untuk berlibur.",
    ]
    print(f"Processing {len(texts)} texts...")

    started = time.time()
    embeddings = get_embeddings(session, tokenizer, texts)
    elapsed = time.time() - started

    print(f"Batch embedding shape: {embeddings.shape}")
    print(f"Batch processing time: {elapsed:.4f}s")
    print(f"Average time per text: {elapsed/len(texts):.4f}s")
    return embeddings, texts
def example_similarity_search():
    """Rank documents by cosine similarity against a short query."""
    print("\n" + "=" * 50)
    print("SIMILARITY SEARCH EXAMPLE")
    print("=" * 50)

    session, tokenizer = load_model()

    documents = [
        "AI dan machine learning mengubah cara kerja industri teknologi.",
        "Kecerdasan buatan membantu otomatisasi proses bisnis modern.",
        "Nasi rendang adalah makanan tradisional Indonesia yang lezat.",
        "Kuliner Indonesia memiliki cita rasa yang unik dan beragam.",
        "Deep learning adalah subset dari machine learning yang powerful.",
        "Pantai Lombok menawarkan pemandangan yang menakjubkan.",
    ]
    query = "Teknologi AI untuk bisnis"
    print(f"Query: {query}")
    print(f"Searching in {len(documents)} documents...")

    # Embed the query and all candidate documents.
    query_vec = get_embeddings(session, tokenizer, query)
    doc_vecs = get_embeddings(session, tokenizer, documents)

    # One similarity score per document, highest first (negated key keeps
    # tie order identical to reverse=True).
    scores = cosine_similarity(query_vec, doc_vecs)[0]
    ranked = sorted(zip(documents, scores), key=lambda pair: -pair[1])

    print("\nTop 3 most similar documents:")
    for rank, (doc, score) in enumerate(ranked[:3]):
        print(f"{rank+1}. Similarity: {score:.4f}")
        print(f" Document: {doc}")
def example_long_text_processing():
    """Embed one long paragraph with each available pooling strategy."""
    print("\n" + "=" * 50)
    print("LONG TEXT PROCESSING EXAMPLE")
    print("=" * 50)

    session, tokenizer = load_model()

    # Long Indonesian paragraph about AI adoption in Indonesia.
    long_text = """
Perkembangan teknologi artificial intelligence di Indonesia menunjukkan tren yang sangat positif
dengan banyaknya startup dan perusahaan teknologi yang mulai mengadopsi solusi berbasis AI untuk
meningkatkan efisiensi operasional, customer experience, dan inovasi produk. Industri fintech,
e-commerce, dan healthcare menjadi sektor yang paling aktif dalam implementasi AI. Pemerintah
Indonesia juga mendukung ekosistem AI melalui berbagai program dan kebijakan yang mendorong
transformasi digital. Universitas dan institusi penelitian berkontribusi dalam pengembangan
talenta AI berkualitas. Tantangan yang dihadapi meliputi ketersediaan data berkualitas,
infrastruktur teknologi, dan regulasi yang mendukung inovasi namun tetap melindungi privasi
dan keamanan data. Kolaborasi antara pemerintah, industri, dan akademisi menjadi kunci sukses
pengembangan AI di Indonesia untuk mencapai visi Indonesia 2045 sebagai negara maju.
"""
    print(f"Processing long text ({len(long_text)} characters)...")

    # Compare all pooling strategies on the same (stripped) text.
    for strategy in ("mean", "cls", "max"):
        t0 = time.time()
        vecs = get_embeddings(session, tokenizer, long_text.strip(), pool_strategy=strategy)
        elapsed = time.time() - t0
        print(f"Pooling: {strategy:4s} | Shape: {vecs.shape} | Time: {elapsed:.4f}s")
def example_performance_benchmark():
    """Measure mean/std inference latency for short, medium, and long texts."""
    print("\n" + "=" * 50)
    print("PERFORMANCE BENCHMARK")
    print("=" * 50)

    session, tokenizer = load_model()

    test_cases = [
        ("Short", "Halo dunia!"),
        ("Medium", "Teknologi AI berkembang sangat pesat dan mengubah berbagai industri di seluruh dunia."),
        ("Long", " ".join(["Kalimat panjang dengan banyak kata untuk menguji performa model."] * 20)),
    ]
    print("Benchmarking different text lengths...")

    for name, text in test_cases:
        # One untimed warm-up call before measuring.
        get_embeddings(session, tokenizer, text)

        times = []
        for _ in range(10):
            t0 = time.time()
            get_embeddings(session, tokenizer, text)
            times.append(time.time() - t0)

        avg_time = np.mean(times)
        std_time = np.std(times)
        token_count = len(tokenizer.encode(text))
        print(f"{name:6s} ({token_count:3d} tokens): {avg_time:.4f}s Β± {std_time:.4f}s")
def validate_model():
    """Validate model functionality.

    Loads the model, embeds a short test sentence, and checks the output
    shape and numeric sanity. On failure prints a ❌ line and re-raises.

    Raises:
        AssertionError: if any validation check fails.
    """
    print("\n" + "="*50)
    print("MODEL VALIDATION")
    print("="*50)
    try:
        session, tokenizer = load_model()

        test_text = "Tes validasi model ONNX."
        embeddings = get_embeddings(session, tokenizer, test_text)

        # Explicit checks instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently skip the validation.
        # AssertionError is kept so any existing callers see the same type.
        if embeddings.shape[0] != 1:
            raise AssertionError("Batch size should be 1")
        if embeddings.shape[1] != 768:
            raise AssertionError("Hidden size should be 768")
        if np.isnan(embeddings).any():
            raise AssertionError("No NaN values allowed")
        if np.isinf(embeddings).any():
            raise AssertionError("No Inf values allowed")

        print("βœ… Model validation passed!")
        print(f"βœ… Output shape: {embeddings.shape}")
        print(f"βœ… Output range: [{embeddings.min():.4f}, {embeddings.max():.4f}]")
    except Exception as e:
        print(f"❌ Model validation failed: {e}")
        raise
def main():
    """Run all examples."""
    print("πŸš€ LazarusNLP IndoBERT ONNX - Example Usage")

    # Fail fast if the model or tokenizer is broken.
    validate_model()

    # Run each demo in turn.
    for example in (
        example_basic_usage,
        example_batch_processing,
        example_similarity_search,
        example_long_text_processing,
        example_performance_benchmark,
    ):
        example()

    print("\n" + "=" * 50)
    print("πŸŽ‰ All examples completed successfully!")
    print("=" * 50)


if __name__ == "__main__":
    main()