---
library_name: transformers
tags:
- pruning
- distillation
- sparsity‑2:4
license: apache-2.0
language:
- en
- de
- fr
- es
- it
- pt
base_model:
- doubledsbv/KafkaLM-15B-Base
pipeline_tag: text-generation
---
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/EgsjPDWd37LjAtamiICxk.png" width="480" height="480" alt="image/png">
# Model Description
**KafkaLM‑15B‑Base** is a 15‑billion‑parameter, sparsity‑aware language model distilled from *Mistral‑Small‑24B‑Base‑2501* and further post‑trained (SFT + DPO + GRPO with verifiable rewards).
This experimental model was created in five stages:
| Stage | What we did | Why it matters |
|-------|-------------|----------------|
| **1. SimplePrune** | Applied a hierarchical, hardware‑aware pruning pipeline that combines block‑, channel‑ and 2:4 structured sparsity (≈ 37.5 % parameter reduction) | Slashes memory footprint while minimizing perplexity degradation |
| **2. Teacher calibration** | Briefly fine‑tuned the unpruned 24 B teacher on a 10 B‑token multilingual European corpus on an AMD MI300A cluster | Produces stable logits and hidden states for distillation |
| **3. Knowledge distillation** | Distilled the calibrated teacher into the pruned 15 B student using a **fused loss**:<br/>`L_PooledSquareHead + L_KL + 0.25 * L_CE` | Transfers teacher capabilities effectively with < 15 B tokens **(< 2 epochs)** on 64 MI300A nodes |
| **4. SFT + DPO** | Supervised fine‑tuning followed by Direct Preference Optimization (DPO) on curated open‑source multilingual and multitask datasets | Improves alignment with human preferences while preserving multilingual capabilities |
| **5. RL** | Trained with GRPO as a separate LoRA adapter, so it is easy to serve and optional to use | Enables flexible deployment: the reinforcement‑learning gains are opt‑in and the base model stays unmodified |
**Key capabilities**
* Balanced for **multitask** and multilingual conversation, as well as long‑context handling
* Structured **2:4 sparsity** → runs up to **40 % faster** on sparsity‑aware kernels
* Distilled on a combination of multilingual pretraining and synthetic data
* Training pipeline optimized for unified‑memory APUs (AMD MI300A), but runs on any CUDA / ROCm device
---
### LoRA‑based reasoning capabilities
[Download the adapter from Hugging Face](https://huggingface.co/seedboxai/KafkaLM-15B-GRPO_LoRA_Exp)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Load base model
base_model = AutoModelForCausalLM.from_pretrained("seedboxai/KafkaLM-15B")
tokenizer = AutoTokenizer.from_pretrained("seedboxai/KafkaLM-15B")
# Apply LoRA adapter
model = PeftModel.from_pretrained(
base_model,
"seedboxai/KafkaLM-15B-GRPO_LoRA_Exp",
adapter_name="grpo_lora"
)
```
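
Keeping the GRPO adapter separate also interacts with the 2:4 sparsity. A minimal NumPy sketch, assuming the standard LoRA formulation y = Wx + (α/r)·BAx with toy, hypothetical sizes (not the released adapter's configuration): the low‑rank update BA is dense, so folding it into a 2:4‑sparse weight would generally destroy the 2:4 pattern, whereas applying it as a separate term leaves the sparse base untouched. This illustrates the math under those assumptions; it is not a description of how the authors serve the model.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, rank, alpha = 8, 16, 4, 8  # hypothetical toy sizes, not the real adapter config

# Toy 2:4-sparse base weight: in every group of 4 consecutive inputs, the last 2 are zero
mask = np.tile([True, True, False, False], (d_out, d_in // 4))
W = rng.standard_normal((d_out, d_in)) * mask

# Low-rank LoRA factors (standard LoRA zero-initialises B; random here to mimic a trained adapter)
A = rng.standard_normal((rank, d_in))
B = rng.standard_normal((d_out, rank))
scale = alpha / rank

def forward(x: np.ndarray, use_lora: bool = True) -> np.ndarray:
    """Sparse base projection plus an optional, separately stored LoRA term."""
    y = W @ x
    if use_lora:
        y = y + scale * (B @ (A @ x))
    return y

def nonzeros_per_group(M: np.ndarray) -> np.ndarray:
    """Count non-zero weights in each group of 4 consecutive inputs."""
    return (M.reshape(M.shape[0], -1, 4) != 0).sum(axis=-1)

x = rng.standard_normal(d_in)
W_merged = W + scale * (B @ A)  # what merging the adapter into the base weight would produce

assert np.allclose(forward(x), W_merged @ x)  # same output either way ...
print(nonzeros_per_group(W).max())            # 2: the base keeps the 2:4 pattern
print(nonzeros_per_group(W_merged).max())     # 4: the dense update B @ A fills the zeros
```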
## Pruning Process
**Pruning & Distillation Strategy — SimplePrune**
**Hardware‑aware, hierarchical pipeline**
SimplePrune starts with coarse block‑level pruning, drills down to channel‑ and neuron‑level removals, and finishes with 2:4 structured sparsity. This staged approach turns nominal compression ratios into real memory‑bandwidth and latency gains.
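
The final 2:4 step keeps exactly two non‑zero weights in every group of four consecutive weights, the pattern that sparsity‑aware kernels accelerate. A minimal NumPy sketch, assuming the simple magnitude criterion (keep the two largest |w| per group); SimplePrune itself selects weights using the sensitivity scores described next, and `two_four_mask` is a hypothetical helper name.

```python
import numpy as np

def two_four_mask(weight: np.ndarray) -> np.ndarray:
    """Boolean mask keeping the 2 largest-|w| entries in each group of 4 along the last axis."""
    rows, cols = weight.shape
    if cols % 4 != 0:
        raise ValueError("input dimension must be divisible by 4")
    groups = np.abs(weight).reshape(rows, cols // 4, 4)
    # argsort is ascending, so the last two indices point at the two largest magnitudes
    top2 = np.argsort(groups, axis=-1)[..., 2:]
    mask = np.zeros_like(groups, dtype=bool)
    np.put_along_axis(mask, top2, True, axis=-1)
    return mask.reshape(rows, cols)

w = np.array([[0.1, -0.9, 0.3, 0.05, 2.0, -0.2, 0.0, -1.5]])
mask = two_four_mask(w)
pruned = np.where(mask, w, 0.0)
print(pruned.tolist())  # [[0.0, -0.9, 0.3, 0.0, 2.0, 0.0, 0.0, -1.5]]
```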
**Sensitivity‑guided selection**
Each stage is driven by activation‑magnitude profiles and Hessian‑based importance scores captured asynchronously during training, which lets the framework run within the 512 GB of unified memory of an MI300A node (128 GB per APU) without OOM interruptions.
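
One way to turn these two signals into rankable scores, as a hedged sketch with hypothetical helpers: the activation signal as a Wanda‑style |W_ij|·‖X_j‖₂ score, and the Hessian signal as the Optimal Brain Damage saliency ½·H_jj·W_ij², with the diagonal Hessian approximated by mean squared calibration activations (the Gauss‑Newton diagonal of a linear layer, up to a constant). The exact scores SimplePrune uses are not specified here; both formulas are assumptions. The toy example shows that the two signals can rank channels differently.

```python
import numpy as np

def activation_scores(weight: np.ndarray, acts: np.ndarray) -> np.ndarray:
    """Wanda-style score |W_ij| * ||X_j||_2; acts has shape [num_tokens, in_features]."""
    return np.abs(weight) * np.linalg.norm(acts, axis=0)

def obd_saliency(weight: np.ndarray, acts: np.ndarray) -> np.ndarray:
    """Optimal Brain Damage saliency 0.5 * H_jj * W_ij**2, with the diagonal Hessian
    approximated by the mean squared input activation of each input feature."""
    h_diag = np.mean(acts ** 2, axis=0)
    return 0.5 * h_diag * weight ** 2

# Toy layer: 3 output channels, 2 input features, 2 calibration tokens (hypothetical values)
weight = np.array([[1.0, 0.0],
                   [0.5, 0.5],
                   [0.0, 2.0]])
acts = np.array([[1.0, 2.0],
                 [1.0, 0.0]])

# Aggregate per output channel; channels with the lowest totals are pruning candidates first
act_order = np.argsort(activation_scores(weight, acts).sum(axis=1))
obd_order = np.argsort(obd_saliency(weight, acts).sum(axis=1))
print(act_order.tolist(), obd_order.tolist())  # [0, 1, 2] [1, 0, 2]
```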
**Two‑phase optimisation**
A fast greedy pass prunes low‑impact blocks in MLP expansion layers, after which a **Tabu‑Search** meta‑heuristic explores cross‑layer combinations for a better global trade‑off between sparsity and perplexity/KL divergence.
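
The two‑phase search can be sketched on a toy problem. All of the following is hypothetical: each prunable block has an individual damage score, some block pairs interact (pruning both costs extra, a stand‑in for the cross‑layer effects the Tabu search targets), and the number of pruned blocks is fixed. The real objective, sparsity versus perplexity/KL divergence, is not reproduced. The greedy pass picks the individually cheapest blocks; the tabu search then explores swap moves, forbidding recently moved blocks for `tenure` iterations unless a move beats the best solution found so far (aspiration criterion).

```python
from itertools import product

def total_cost(selection, damage, interaction):
    """Toy objective: damage of each pruned block plus pairwise interaction penalties.
    Interaction keys are (lower, higher) block-index pairs."""
    blocks = sorted(selection)
    cost = sum(damage[b] for b in blocks)
    for i, a in enumerate(blocks):
        for b in blocks[i + 1:]:
            cost += interaction.get((a, b), 0.0)
    return cost

def greedy(damage, k):
    """Phase 1: prune the k blocks with the lowest individual damage."""
    return set(sorted(range(len(damage)), key=lambda b: damage[b])[:k])

def tabu_search(initial, damage, interaction, iterations=50, tenure=3):
    """Phase 2: swap one pruned block for one kept block per step, with a tabu list."""
    n = len(damage)
    current = set(initial)
    best, best_cost = set(current), total_cost(current, damage, interaction)
    tabu_until = {}  # block -> last iteration in which moving it is forbidden
    for it in range(iterations):
        moves = []
        for out_b, in_b in product(sorted(current), sorted(set(range(n)) - current)):
            neighbour = (current - {out_b}) | {in_b}
            cost = total_cost(neighbour, damage, interaction)
            is_tabu = tabu_until.get(out_b, -1) >= it or tabu_until.get(in_b, -1) >= it
            # Aspiration: a tabu move is still allowed if it beats the best solution so far
            if not is_tabu or cost < best_cost:
                moves.append((cost, out_b, in_b, neighbour))
        if not moves:
            break
        cost, out_b, in_b, current = min(moves, key=lambda m: m[0])
        tabu_until[out_b] = tabu_until[in_b] = it + tenure
        if cost < best_cost:
            best, best_cost = set(current), cost
    return best, best_cost

# Hypothetical example: 5 prunable blocks, prune 2; pruning blocks 0 and 1 together is costly
damage = [0.1, 0.2, 0.3, 0.4, 0.5]
interaction = {(0, 1): 1.0}
start = greedy(damage, k=2)
best, best_cost = tabu_search(start, damage, interaction)
print(sorted(start), round(total_cost(start, damage, interaction), 2))  # [0, 1] 1.3
print(sorted(best), round(best_cost, 2))                                # [0, 2] 0.4
```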
**Post‑pruning knowledge distillation**
The pruned 15 B student is distilled from the calibrated 24 B teacher using the fused loss `L_SquareHead + L_KL + 0.25 · L_CE` over 20 B multilingual tokens, restoring > 96 % of the original quality in ≤ 2 epochs on up to 64 MI300A nodes.
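
A framework‑agnostic NumPy sketch of such a fused objective, for illustration only. Assumptions: the SquareHead term is a per‑layer MSE between student and teacher hidden states normalised by the teacher's mean squared activation, and "pooled" is read as a mean over layers; the KL term is KL(teacher ‖ student) at temperature 1, averaged over tokens; CE is against the ground‑truth next tokens. Helper names are hypothetical, and the actual training code likely differs (e.g. PyTorch, temperatures, masking).

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the vocabulary (last) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def squarehead_loss(student_hidden, teacher_hidden, eps: float = 1e-8) -> float:
    """Per-layer MSE normalised by the teacher's mean squared activation,
    averaged over layers (ASSUMPTION: our reading of "pooled")."""
    per_layer = [np.mean((s - t) ** 2) / (np.mean(t ** 2) + eps)
                 for s, t in zip(student_hidden, teacher_hidden)]
    return float(np.mean(per_layer))

def kl_loss(student_logits: np.ndarray, teacher_logits: np.ndarray) -> float:
    """Token-averaged KL(teacher || student) at temperature 1 (ASSUMPTION)."""
    log_p_t = log_softmax(teacher_logits)
    log_p_s = log_softmax(student_logits)
    return float(np.mean(np.sum(np.exp(log_p_t) * (log_p_t - log_p_s), axis=-1)))

def ce_loss(student_logits: np.ndarray, labels: np.ndarray) -> float:
    """Token-averaged cross-entropy of the student against ground-truth token ids."""
    log_p_s = log_softmax(student_logits)
    picked = np.take_along_axis(log_p_s, labels[..., None], axis=-1)
    return float(-np.mean(picked))

def fused_loss(student_hidden, teacher_hidden, student_logits, teacher_logits,
               labels, ce_weight: float = 0.25) -> float:
    """L_SquareHead + L_KL + 0.25 * L_CE."""
    return (squarehead_loss(student_hidden, teacher_hidden)
            + kl_loss(student_logits, teacher_logits)
            + ce_weight * ce_loss(student_logits, labels))

# Toy check: a student identical to the teacher only pays the weighted CE term
rng = np.random.default_rng(0)
hidden = [rng.standard_normal((3, 5)) for _ in range(2)]  # 2 layers, 3 tokens, width 5
logits = np.zeros((3, 4))                                 # uniform over a 4-token vocabulary
labels = np.array([0, 1, 2])
print(fused_loss(hidden, hidden, logits, logits, labels))  # 0.25 * ln(4) ≈ 0.3466
```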
### Results
A ≈ 37.5 % parameter reduction (24 B → 15 B) delivers 2× lower TTFT and ≈ 40 % higher tokens/s than the uncompressed teacher while matching its perplexity and divergence metrics, validating SimplePrune as an effective route to deploying KafkaLM in memory‑constrained, sparsity‑accelerated environments.
| Metric | Mistral‑24B | **KafkaLM‑15B** | Δ |
|--------|-------------|-----------------|---|
| Time‑to‑First‑Token | 4.91 s | **2.46 s** | −50% |
| Prompts / s | 4.70 | **6.55** | +39% |
| Tokens / s | 579 | **812** | +40% |
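
The Δ column is the relative change with respect to Mistral‑24B, (new − baseline) / baseline. A small sketch that recomputes it from the table values; the helper name `relative_change` is ours.

```python
def relative_change(baseline: float, new: float) -> float:
    """Relative change in percent: (new - baseline) / baseline * 100."""
    return (new - baseline) / baseline * 100.0

# Values copied from the table: (Mistral-24B, KafkaLM-15B)
metrics = {
    "Time-to-First-Token (s)": (4.91, 2.46),
    "Prompts / s": (4.70, 6.55),
    "Tokens / s": (579, 812),
}
for name, (baseline, pruned) in metrics.items():
    print(f"{name}: {relative_change(baseline, pruned):+.1f} %")
```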
<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/4rDhaeC-1GMj6KWbB27f9.png" width="480" height="480" alt="image/png">
### Training scalability (distillation run, MI300A cluster)
| Nodes | Tokens / s | Speed‑up (vs. 4 nodes) |
|-------|------------|----------|
| 4 | 1 461 | – |
| 8 | 3 327 | 2.3 × |
| 16 | 7 423 | 5.1 × |
| 32 | 15 286 | 10.5 × |
| 64 | 25 455 | 17.4 × |
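Scaling is near‑linear, thanks to sharded ZeRO‑3 + RCCL optimisations; relative to the 4‑node baseline it is even slightly super‑linear (17.4 × at 64 nodes vs. an ideal 16 ×).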
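
The speed‑up column is throughput relative to the 4‑node run. A short sketch that also compares it with ideal linear scaling (nodes / 4); a parallel efficiency above 1.0 means super‑linear scaling. The helper `scaling_efficiency` is hypothetical; the numbers are copied from the table.

```python
def scaling_efficiency(nodes, throughput):
    """Return (nodes, speed-up, ideal speed-up, efficiency) relative to the first entry."""
    base_n, base_t = nodes[0], throughput[0]
    rows = []
    for n, t in zip(nodes, throughput):
        speedup = t / base_t  # measured speed-up vs. the smallest run
        ideal = n / base_n    # linear-scaling expectation
        rows.append((n, speedup, ideal, speedup / ideal))
    return rows

rows = scaling_efficiency([4, 8, 16, 32, 64], [1461, 3327, 7423, 15286, 25455])
for n, speedup, ideal, eff in rows:
    print(f"{n:>2} nodes: {speedup:5.1f}x measured, {ideal:4.0f}x ideal, efficiency {eff:.2f}")
```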
# Inference
## Transformers
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "seedboxai/KafkaLM-15B"
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
# prepare the model input
prompt = "Why did Kafka hit different?"
messages = [
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# conduct text completion
generated_ids = model.generate(
**model_inputs,
max_new_tokens=1024
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
response = tokenizer.decode(output_ids, skip_special_tokens=True)
print(response)
```
## vLLM
```python
"""
This example shows how to use KafkaLM-15B in vLLM for offline inference,
with and without the LoRA reasoning adapter.
"""
from typing import Optional
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
import transformers

MODEL_NAME = "seedboxai/KafkaLM-15B"
SYS_MESSAGE = 'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.'
# The tokenizer must match the checkpoint served by the engine
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)


def create_test_prompts(lora_path: str) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters:
    one request for the base model and one for the reasoning LoRA adapter.
    """
    return [
        ("Why did Kafka hit different?",
         SamplingParams(temperature=0.7,
                        top_p=0.95,
                        max_tokens=1024),
         None),  # base model only
        ("Create a Markdown table comparing SHA‑256, BLAKE3, and SHA‑3 with columns: internal structure, block size, and throughput.",
         SamplingParams(temperature=0.6,
                        top_p=0.95,
                        top_k=20,
                        logprobs=1,
                        prompt_logprobs=1,
                        stop=["</s>", "<eos>"],
                        max_tokens=4096),
         LoRARequest("reasoning-lora", 1, lora_path)),  # routed through the GRPO adapter
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: list[tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            user_input, sampling_params, lora_request = test_prompts.pop(0)
            # Wrap the user prompt with the system message using the model's chat template
            prompt = tokenizer.apply_chat_template(
                [{"role": "system", "content": SYS_MESSAGE},
                 {"role": "user", "content": user_input}],
                tokenize=False,
                add_generation_prompt=True,
            )
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1
        request_outputs: list[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine with LoRA support enabled."""
    engine_args = EngineArgs(model=MODEL_NAME,
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=128,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
    lora_path = snapshot_download(repo_id="seedboxai/KafkaLM-15B-GRPO_LoRA_Exp")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()
```
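
The reasoning adapter is prompted (via `SYS_MESSAGE`) to wrap its output as `<think> ... </think> <answer> ... </answer>`. A minimal, standard‑library sketch for splitting such a completion, for example `request_output.outputs[0].text` from the loop above. `split_reasoning` is a hypothetical helper; it assumes the model follows the requested tag format and falls back to the whole text (as the answer) when the tags are missing, e.g. for base‑model output.

```python
import re
from typing import NamedTuple, Optional

class ReasoningOutput(NamedTuple):
    reasoning: Optional[str]
    answer: str

_THINK = re.compile(r"<think>(.*?)</think>", re.DOTALL)
_ANSWER = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

def split_reasoning(text: str) -> ReasoningOutput:
    """Split a completion into its <think> and <answer> parts."""
    think = _THINK.search(text)
    answer = _ANSWER.search(text)
    reasoning = think.group(1).strip() if think else None
    if answer:
        return ReasoningOutput(reasoning, answer.group(1).strip())
    # No <answer> tag: use whatever follows </think>, or the whole text
    rest = text[think.end():] if think else text
    return ReasoningOutput(reasoning, rest.strip())

result = split_reasoning("<think> compare the designs </think> <answer>| Hash | Block size |</answer>")
print(result.reasoning)  # compare the designs
print(result.answer)     # | Hash | Block size |
```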
## Citation
```bibtex
@misc{kafkalm2025,
title={Evaluating AMD's MI300A APU: Performance Insights on LLM Training via Knowledge Distillation},
author={Dennis Dickmann and Philipp Offenhäuser and Rishabh Saxena and George S. Markomanolis and Alessandro Rigazzi and Patrick Keller and Dennis Hoppe},
howpublished={Cray User Group Conference, 2025},
note={to be published},
year={2025}
}
```