""" |
|
Example usage of the quantized ONNX LazarusNLP IndoBERT model. |
|
Demonstrates basic inference, batch processing, and similarity computation. |
|
""" |
|
|
|
import time

import numpy as np
import onnxruntime as ort
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer
|
|
|
def load_model(model_path="./"):
    """Load the quantized ONNX model and tokenizer."""
    print("Loading quantized ONNX model...")

    # Create an inference session from the exported ONNX graph.
    session = ort.InferenceSession(f"{model_path}/model.onnx")

    # The tokenizer files are stored alongside the model.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print("✓ Model loaded successfully")
    print(f"✓ Tokenizer max length: {tokenizer.model_max_length}")

    return session, tokenizer
|
|
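# Optional variant (a sketch, not part of the original script): onnxruntime
# sessions can be tuned through SessionOptions. The thread count below is an
# illustrative assumption; pick one based on your hardware.
def load_model_tuned(model_path="./", num_threads=4):
    """Load the model with explicit CPU threading and graph optimizations."""
    opts = ort.SessionOptions()
    opts.intra_op_num_threads = num_threads  # threads used within an operator
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = ort.InferenceSession(
        f"{model_path}/model.onnx",
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return session, tokenizer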
|
def get_embeddings(session, tokenizer, texts, pool_strategy="mean"):
    """
    Get embeddings for texts using the ONNX model.

    Args:
        session: ONNX inference session
        tokenizer: HuggingFace tokenizer
        texts: List of texts or single text
        pool_strategy: Pooling strategy ('mean', 'cls', 'max')

    Returns:
        numpy array of embeddings
    """
    if isinstance(texts, str):
        texts = [texts]

    # Tokenize with padding/truncation so every sequence in the batch has
    # the same length.
    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True)

    # Run the ONNX graph; outputs[0] is the last hidden state with shape
    # (batch, seq_len, hidden).
    outputs = session.run(None, {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask']
    })

    hidden_states = outputs[0]
    attention_mask = inputs['attention_mask']

    if pool_strategy == "mean":
        # Mean pooling over real tokens only: zero out padding positions,
        # then divide by the number of unmasked tokens.
        mask_expanded = np.expand_dims(attention_mask, axis=-1)
        masked_embeddings = hidden_states * mask_expanded
        sum_embeddings = np.sum(masked_embeddings, axis=1)
        sum_mask = np.sum(mask_expanded, axis=1)
        embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
    elif pool_strategy == "cls":
        # Use the [CLS] token (position 0) as the sentence embedding.
        embeddings = hidden_states[:, 0, :]
    elif pool_strategy == "max":
        # Element-wise max over real tokens; padding positions are set to
        # -inf so they can never win the max.
        mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(bool)
        embeddings = np.max(np.where(mask_expanded, hidden_states, -np.inf), axis=1)
    else:
        raise ValueError(f"Unknown pooling strategy: {pool_strategy}")

    return embeddings
|
|
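# Small helper (an addition, not in the original script): unit-normalising
# embeddings makes cosine similarity a plain dot product, which simplifies
# downstream search code.
def normalize_embeddings(embeddings):
    """Return L2-normalised embeddings along the last axis."""
    norms = np.linalg.norm(embeddings, axis=-1, keepdims=True)
    return embeddings / np.maximum(norms, 1e-9)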
|
def example_basic_usage():
    """Basic usage example."""
    print("\n" + "="*50)
    print("BASIC USAGE EXAMPLE")
    print("="*50)

    session, tokenizer = load_model()

    text = "Teknologi kecerdasan buatan berkembang sangat pesat di Indonesia."

    start_time = time.time()
    embeddings = get_embeddings(session, tokenizer, text)
    inference_time = time.time() - start_time

    print(f"Input text: {text}")
    print(f"Embedding shape: {embeddings.shape}")
    print(f"Inference time: {inference_time:.4f}s")
    print(f"Sample embedding values: {embeddings[0][:5]}")
|
|
|
def example_batch_processing():
    """Batch processing example."""
    print("\n" + "="*50)
    print("BATCH PROCESSING EXAMPLE")
    print("="*50)

    session, tokenizer = load_model()

    texts = [
        "Saya suka makan nasi gudeg.",
        "Artificial intelligence adalah teknologi masa depan.",
        "Indonesia memiliki kebudayaan yang sangat beragam.",
        "Machine learning membantu menganalisis data besar.",
        "Pantai Bali sangat indah untuk berlibur."
    ]

    print(f"Processing {len(texts)} texts...")

    start_time = time.time()
    embeddings = get_embeddings(session, tokenizer, texts)
    batch_time = time.time() - start_time

    print(f"Batch embedding shape: {embeddings.shape}")
    print(f"Batch processing time: {batch_time:.4f}s")
    print(f"Average time per text: {batch_time/len(texts):.4f}s")

    return embeddings, texts
|
|
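# For corpora that do not fit comfortably in one padded batch, embedding in
# fixed-size batches keeps memory bounded. A minimal sketch; batch_size=32
# is an illustrative choice, not a measured optimum.
def embed_corpus(session, tokenizer, corpus, batch_size=32):
    """Embed a list of texts batch by batch and stack the results."""
    parts = [get_embeddings(session, tokenizer, corpus[i:i + batch_size])
             for i in range(0, len(corpus), batch_size)]
    return np.vstack(parts)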
|
def example_similarity_search():
    """Similarity search example."""
    print("\n" + "="*50)
    print("SIMILARITY SEARCH EXAMPLE")
    print("="*50)

    session, tokenizer = load_model()

    documents = [
        "AI dan machine learning mengubah cara kerja industri teknologi.",
        "Kecerdasan buatan membantu otomatisasi proses bisnis modern.",
        "Nasi rendang adalah makanan tradisional Indonesia yang lezat.",
        "Kuliner Indonesia memiliki cita rasa yang unik dan beragam.",
        "Deep learning adalah subset dari machine learning yang powerful.",
        "Pantai Lombok menawarkan pemandangan yang menakjubkan.",
    ]

    query = "Teknologi AI untuk bisnis"

    print(f"Query: {query}")
    print(f"Searching in {len(documents)} documents...")

    # Embed the query and the documents with the same pooling strategy.
    query_embedding = get_embeddings(session, tokenizer, query)
    doc_embeddings = get_embeddings(session, tokenizer, documents)

    # Rank documents by cosine similarity to the query.
    similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
    ranked_docs = sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True)

    print("\nTop 3 most similar documents:")
    for i, (doc, sim) in enumerate(ranked_docs[:3]):
        print(f"{i+1}. Similarity: {sim:.4f}")
        print(f"   Document: {doc}")
|
|
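# Dot-product ranking variant (an addition, not in the original script).
# Relies on normalize_embeddings defined above: on unit vectors, cosine
# similarity is just a matrix product, so no sklearn is needed.
def top_k_similar(query_embedding, doc_embeddings, k=3):
    """Return (indices, scores) of the k most similar documents."""
    q = normalize_embeddings(query_embedding)
    d = normalize_embeddings(doc_embeddings)
    sims = (d @ q.T).ravel()
    top = np.argsort(-sims)[:k]
    return top, sims[top]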
|
def example_long_text_processing():
    """Long text processing example."""
    print("\n" + "="*50)
    print("LONG TEXT PROCESSING EXAMPLE")
    print("="*50)

    session, tokenizer = load_model()

    long_text = """
    Perkembangan teknologi artificial intelligence di Indonesia menunjukkan tren yang sangat positif
    dengan banyaknya startup dan perusahaan teknologi yang mulai mengadopsi solusi berbasis AI untuk
    meningkatkan efisiensi operasional, customer experience, dan inovasi produk. Industri fintech,
    e-commerce, dan healthcare menjadi sektor yang paling aktif dalam implementasi AI. Pemerintah
    Indonesia juga mendukung ekosistem AI melalui berbagai program dan kebijakan yang mendorong
    transformasi digital. Universitas dan institusi penelitian berkontribusi dalam pengembangan
    talenta AI berkualitas. Tantangan yang dihadapi meliputi ketersediaan data berkualitas,
    infrastruktur teknologi, dan regulasi yang mendukung inovasi namun tetap melindungi privasi
    dan keamanan data. Kolaborasi antara pemerintah, industri, dan akademisi menjadi kunci sukses
    pengembangan AI di Indonesia untuk mencapai visi Indonesia 2045 sebagai negara maju.
    """

    print(f"Processing long text ({len(long_text)} characters)...")

    # Anything beyond the tokenizer's model_max_length is truncated inside
    # get_embeddings, so each strategy sees at most one full window.
    strategies = ["mean", "cls", "max"]

    for strategy in strategies:
        start_time = time.time()
        embeddings = get_embeddings(session, tokenizer, long_text.strip(), pool_strategy=strategy)
        process_time = time.time() - start_time

        print(f"Pooling: {strategy:4s} | Shape: {embeddings.shape} | Time: {process_time:.4f}s")
|
|
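# Sliding-window variant (a sketch, not in the original script): for texts
# longer than the model window, embed overlapping token chunks and average
# them. chunk_tokens/overlap are illustrative values. Decoding chunks back to
# text before re-tokenising is a simplification that may shift token
# boundaries slightly.
def embed_long_text(session, tokenizer, text, chunk_tokens=256, overlap=64):
    """Average the embeddings of overlapping chunks of a long text."""
    ids = tokenizer.encode(text, add_special_tokens=False)
    step = chunk_tokens - overlap
    chunks = [tokenizer.decode(ids[i:i + chunk_tokens])
              for i in range(0, len(ids), step)] or [text]
    chunk_embeddings = get_embeddings(session, tokenizer, chunks)
    return chunk_embeddings.mean(axis=0, keepdims=True)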
|
def example_performance_benchmark():
    """Performance benchmark example."""
    print("\n" + "="*50)
    print("PERFORMANCE BENCHMARK")
    print("="*50)

    session, tokenizer = load_model()

    test_cases = [
        ("Short", "Halo dunia!"),
        ("Medium", "Teknologi AI berkembang sangat pesat dan mengubah berbagai industri di seluruh dunia."),
        ("Long", " ".join(["Kalimat panjang dengan banyak kata untuk menguji performa model."] * 20))
    ]

    print("Benchmarking different text lengths...")

    for name, text in test_cases:
        times = []

        # Warm-up run so one-time initialization costs are excluded.
        get_embeddings(session, tokenizer, text)

        # Timed runs.
        for _ in range(10):
            start_time = time.time()
            get_embeddings(session, tokenizer, text)
            times.append(time.time() - start_time)

        avg_time = np.mean(times)
        std_time = np.std(times)
        token_count = len(tokenizer.encode(text))

        print(f"{name:6s} ({token_count:3d} tokens): {avg_time:.4f}s ± {std_time:.4f}s")
|
|
|
def validate_model():
    """Validate model functionality."""
    print("\n" + "="*50)
    print("MODEL VALIDATION")
    print("="*50)

    try:
        session, tokenizer = load_model()

        # Run a single sanity-check inference.
        test_text = "Tes validasi model ONNX."
        embeddings = get_embeddings(session, tokenizer, test_text)

        # Basic shape and numeric sanity checks.
        assert embeddings.shape[0] == 1, "Batch size should be 1"
        assert embeddings.shape[1] == 768, "Hidden size should be 768"
        assert not np.isnan(embeddings).any(), "No NaN values allowed"
        assert not np.isinf(embeddings).any(), "No Inf values allowed"

        print("✅ Model validation passed!")
        print(f"✅ Output shape: {embeddings.shape}")
        print(f"✅ Output range: [{embeddings.min():.4f}, {embeddings.max():.4f}]")

    except Exception as e:
        print(f"❌ Model validation failed: {e}")
        raise
|
|
|
def main():
    """Run all examples."""
    print("🚀 LazarusNLP IndoBERT ONNX - Example Usage")

    # Sanity-check the model before running the examples.
    validate_model()

    example_basic_usage()
    example_batch_processing()
    example_similarity_search()
    example_long_text_processing()
    example_performance_benchmark()

    print("\n" + "="*50)
    print("🎉 All examples completed successfully!")
    print("="*50)


if __name__ == "__main__":
    main()