---
license: mit
tags:
- generative
- language-model
- sanskrit
- devanagari
- flashattention
- micro-llm
language:
- sa
datasets:
- custom
library_name: transformers
pipeline_tag: text-generation
---
# 🧠 MicroGPT-Deva: Lightweight Sanskrit Generative LLM
**MicroGPT-Deva** is a compact decoder-only language model trained on Sanskrit text in **Devanagari script**, optimized for text generation tasks. It uses a custom transformer architecture with **FlashAttention** for efficient GPU utilization and fast decoding.
This model is well suited to:
- Generating Sanskrit sentences or paragraphs
- Educational chatbots or creative writing tools
- Deployment in resource-constrained environments, such as a single GPU
---
## 🛠️ Model Details
| Property | Value |
|--------------------|------------------------------|
| Architecture | Decoder-only Transformer |
| Vocabulary Size | 12,000 (SentencePiece BPE) |
| Hidden Size | 512 |
| Layers | 8 |
| Attention Heads | 8 |
| Sequence Length | 512 tokens |
| Parameters | ~33M |
| FlashAttention | ✅ Yes |
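
The ~33M figure can be roughly cross-checked from the table. The sketch below assumes a GPT-2-style decoder block (fused QKV projection, a 4× MLP, biases, and two LayerNorms per block) plus learned positional embeddings. The card does not document these details, so both the layout and the `estimate_params` helper are assumptions. Under them, the count comes to about 31.6M with tied input/output embeddings and about 37.8M with a separate output projection. These two estimates bracket the reported ~33M, and the exact figure depends on details not stated here.

```python
def estimate_params(vocab_size=12_000, d_model=512, n_layers=8, seq_len=512,
                    mlp_ratio=4, tie_embeddings=True):
    """Rough parameter count for an assumed GPT-2-style decoder (hypothetical helper)."""
    d_ff = mlp_ratio * d_model
    # Attention: fused QKV projection + output projection (weights + biases)
    attn = (d_model * 3 * d_model + 3 * d_model) + (d_model * d_model + d_model)
    # Two-layer MLP with biases
    mlp = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    # Two LayerNorms per block, each with a scale and a shift vector
    norms = 2 * 2 * d_model
    per_block = attn + mlp + norms
    # Token embeddings + learned positional embeddings
    embeddings = vocab_size * d_model + seq_len * d_model
    final_norm = 2 * d_model
    # A separate output projection only exists when embeddings are not tied
    lm_head = 0 if tie_embeddings else vocab_size * d_model
    return n_layers * per_block + embeddings + final_norm + lm_head


print(f"tied:   {estimate_params() / 1e6:.1f}M")                      # tied:   31.6M
print(f"untied: {estimate_params(tie_embeddings=False) / 1e6:.1f}M")  # untied: 37.8M
```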
---
## 📖 Training
- **Data**: Custom Sanskrit dataset of more than 100,000 Devanagari `.txt` files.
- **Tokenizer**: [SentencePiece](https://github.com/google/sentencepiece) BPE model trained with `character_coverage=1.0`.
- **Training Platform**: AWS SageMaker on an NVIDIA Tesla V100 GPU
- **Framework**: PyTorch with custom FlashAttention blocks
- **Training Duration**: ~3 epochs, using dynamic batching over sharded data
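
The tokenizer settings above (BPE with `character_coverage=1.0`, plus the 12,000-token vocabulary from the table) map directly onto SentencePiece trainer options. The following is a minimal sketch under assumptions, not the author's training script. The corpus directory layout and the helper names `tokenizer_trainer_args` and `train_tokenizer` are hypothetical. The `devanagari` prefix is chosen only to match the `devanagari.model` file loaded in the usage example. Setting `character_coverage=1.0` keeps every character seen in the corpus in the vocabulary rather than mapping rare ones to the unknown token.

```python
from pathlib import Path


def tokenizer_trainer_args(corpus_dir, model_prefix="devanagari", vocab_size=12_000):
    """Build SentencePiece trainer options (hypothetical helper, assumed corpus layout)."""
    # Collect every .txt shard under the corpus directory, in a stable order
    files = sorted(str(p) for p in Path(corpus_dir).rglob("*.txt"))
    if not files:
        raise FileNotFoundError(f"no .txt files found under {corpus_dir!r}")
    return {
        "input": ",".join(files),      # SentencePiece accepts comma-separated input files
        "model_prefix": model_prefix,  # writes <prefix>.model and <prefix>.vocab
        "vocab_size": vocab_size,
        "model_type": "bpe",
        "character_coverage": 1.0,     # cover every character seen in the corpus
    }


def train_tokenizer(corpus_dir, **kwargs):
    # Imported lazily so the option-building helper works without sentencepiece installed
    import sentencepiece as spm

    spm.SentencePieceTrainer.train(**tokenizer_trainer_args(corpus_dir, **kwargs))


# Usage (hypothetical path): train_tokenizer("sanskrit_corpus/")
```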
---
## 💬 Usage
### 🧪 In Python
```python
import json

import torch
import sentencepiece as spm
from microgpt_deva import MicroGPT, Config

# Load the SentencePiece tokenizer
sp = spm.SentencePieceProcessor()
sp.load("devanagari.model")

# Load the config and model weights (mapped to CPU so this also works without a GPU)
with open("config.json", encoding="utf-8") as f:
    config = Config(json.load(f))
model = MicroGPT(config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Generate text from a prompt ("Near a certain city ...")
prompt = "कस्मिंश्चिन् नगराभ्याशे "
input_ids = torch.tensor([sp.encode(prompt, out_type=int)], dtype=torch.long)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=30)
print(sp.decode(output[0].tolist()))
```