---
library_name: transformers
tags:
- pruning
- distillation
- sparsity‑2:4
license: apache-2.0
language:
- en
- de
- fr
- es
- it
- pt
base_model:
- doubledsbv/KafkaLM-15B-Base
pipeline_tag: text-generation
---

<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/EgsjPDWd37LjAtamiICxk.png" width="480" height="480" alt="image/png">


# Model Description

**KafkaLM‑15B‑Base** is a 15‑billion‑parameter, sparsity‑aware language model distilled from *Mistral‑Small‑24B‑Base‑2501* and further post‑trained with SFT, DPO, and GRPO with verifiable rewards.

This experimental model was created in five stages:

| Stage | What we did | Why it matters |
|-------|-------------|----------------|
| **1. SimplePrune** | Applied a hierarchical, hardware‑aware pruning pipeline that combines block‑, channel‑ and 2:4 structured sparsity (≈ 37.5 % parameter reduction) | Slashes the memory footprint while minimizing perplexity degradation |
| **2. Teacher calibration** | Briefly fine‑tuned the unpruned 24 B teacher on a 10 B‑token multilingual European corpus on an AMD MI300A cluster | Produces stable logits and hidden states for distillation |
| **3. Knowledge distillation** | Distilled the calibrated teacher into the pruned 15 B student using a **fused loss**:<br/>`L = L_PooledSquareHead + L_KL + 0.25 * L_CE` | Transfers teacher capabilities effectively with < 15 B tokens **(< 2 epochs)** on 64 MI300A nodes |
| **4. SFT + DPO** | Supervised fine‑tuning + Direct Preference Optimization on curated open‑source multilingual and multitask datasets | Enhances alignment with human preferences while preserving multilingual capabilities |
| **5. RL** | Trained GRPO as a separate LoRA adapter, kept optional and easy to serve | Enables flexible deployment with optional reinforcement‑learning benefits without modifying the base model |

**Key capabilities**

* Balanced for **multitask** use, multilingual conversation, and long‑context handling  
* Structured **2:4 sparsity** → runs up to **40 % faster** on sparsity‑aware kernels  
* Distilled on a combination of multilingual pretraining data and synthetic data  
* Training pipeline optimized for unified‑memory GPUs (AMD MI300A) but runs on any CUDA / ROCm device

---

### LoRA‑based reasoning capabilities

[Download the adapter from Hugging Face](https://huggingface.co/seedboxai/KafkaLM-15B-GRPO_LoRA_Exp)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("seedboxai/KafkaLM-15B")
tokenizer = AutoTokenizer.from_pretrained("seedboxai/KafkaLM-15B")

# Apply LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    "seedboxai/KafkaLM-15B-GRPO_LoRA_Exp",
    adapter_name="grpo_lora"
)
```
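
A minimal usage sketch, assuming the adapter loaded above is active and the model's chat template is used for prompting (prompt text and generation settings are illustrative):

```python
# Illustrative generation with the LoRA adapter applied above
messages = [{"role": "user", "content": "Why did Kafka hit different?"}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(base_model.device)

outputs = model.generate(inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```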

## Pruning Process

**Pruning & Distillation Strategy — SimplePrune**
SimplePrune is a hardware‑aware, hierarchical pipeline: it starts with coarse block‑level pruning, drills down to channel‑ and neuron‑level removals, and finishes with 2:4 structured sparsity. This staged approach converts compression ratios into real memory‑bandwidth and latency gains.
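
As an illustration of the final 2:4 step (not the SimplePrune implementation), the following sketch keeps the two largest‑magnitude weights in every group of four along a linear layer's input dimension:

```python
import torch

def apply_2_4_sparsity(weight: torch.Tensor) -> torch.Tensor:
    """Illustrative 2:4 mask: zero the 2 smallest-magnitude weights in each group of 4."""
    out_features, in_features = weight.shape
    groups = weight.abs().reshape(out_features, in_features // 4, 4)
    # Keep the indices of the 2 largest-magnitude entries per group
    keep = groups.topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return weight * mask.reshape(out_features, in_features).to(weight.dtype)

w = torch.randn(8, 16)
w_sparse = apply_2_4_sparsity(w)
# Sparsity-aware kernels exploit this pattern: at most 2 non-zeros per group of 4
assert (w_sparse.reshape(8, -1, 4) != 0).sum(dim=-1).max() <= 2
```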

**Sensitivity‑guided selection**
Each stage is driven by activation‑magnitude profiles and Hessian‑based importance scores captured asynchronously during training, allowing the framework to run inside the MI300A’s 512 GB unified memory without OOM interruptions.
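
A hedged sketch of what activation‑magnitude profiling could look like on a toy MLP (the hooks, granularity, and scoring below are assumptions, not the SimplePrune internals):

```python
import torch
import torch.nn as nn

# Toy stand-in for one transformer MLP block; the real profiling targets KafkaLM's layers
mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
activation_scores = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Mean absolute activation per output channel, used as an importance proxy
        score = output.detach().abs().mean(dim=0)
        activation_scores[name] = activation_scores.get(name, 0.0) + score
    return hook

handles = [m.register_forward_hook(make_hook(n))
           for n, m in mlp.named_modules() if isinstance(m, nn.Linear)]

with torch.no_grad():
    mlp(torch.randn(32, 64))  # one "calibration" batch; in practice captured during training

for h in handles:
    h.remove()

# Channels with the lowest accumulated scores are candidates for channel-level pruning
print({name: score.shape for name, score in activation_scores.items()})
```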

**Two‑phase optimisation**
A fast greedy pass prunes low‑impact blocks in MLP expansion layers, after which a **Tabu‑Search** meta‑heuristic explores cross‑layer combinations for a better global trade‑off between sparsity and perplexity/KL divergence.
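
A toy sketch of the search phase, with a placeholder objective and neighborhood (not the actual SimplePrune search), just to make the tabu mechanism concrete:

```python
def tabu_search(initial, score, neighbors, iters=100, tenure=10):
    """Generic tabu search: keep the best configuration seen while
    forbidding recently visited ones to escape local optima."""
    current, best = initial, initial
    tabu = []
    for _ in range(iters):
        candidates = [n for n in neighbors(current) if n not in tabu]
        if not candidates:
            break
        current = max(candidates, key=score)
        tabu.append(current)
        if len(tabu) > tenure:
            tabu.pop(0)
        if score(current) > score(best):
            best = current
    return best

# Toy usage: choose which of 10 blocks to prune, trading sparsity against a fake quality penalty
def score(cfg):  # cfg is a frozenset of pruned block indices
    return 0.5 * len(cfg) - sum(0.1 * i for i in cfg)  # placeholder objective

def neighbors(cfg):
    return [cfg ^ {i} for i in range(10)]  # toggle one block at a time

print(tabu_search(frozenset(), score, neighbors))
```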

**Post‑pruning knowledge distillation**
The pruned 15 B student is distilled from a calibrated 24 B teacher using a fused `L_SquareHead + L_KL + 0.25 · L_CE` loss across 20 B multilingual tokens, restoring > 96 % of the original quality in ≤ 2 epochs on up to 64 MI300A nodes.
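
A minimal sketch of the fused objective, assuming per‑layer hidden‑state matching for the SquareHead‑style term and standard KL and cross‑entropy terms (pooling, normalization, and weighting are simplified here):

```python
import torch
import torch.nn.functional as F

def fused_distillation_loss(student_logits, teacher_logits,
                            student_hidden, teacher_hidden,
                            labels, ce_weight=0.25, temperature=1.0):
    # SquareHead-style term: MSE between normalized hidden states, averaged over layers
    l_squarehead = sum(
        F.mse_loss(F.normalize(s, dim=-1), F.normalize(t, dim=-1))
        for s, t in zip(student_hidden, teacher_hidden)
    ) / len(student_hidden)

    # KL divergence between teacher and student token distributions
    l_kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2

    # Standard cross-entropy on ground-truth tokens, down-weighted by 0.25
    l_ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)),
                           labels.view(-1))

    return l_squarehead + l_kl + ce_weight * l_ce
```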

### Results
The ≈ 37.5 % parameter reduction (24 B → 15 B) delivers 2× lower TTFT and ≈ 40 % higher tokens/s versus the uncompressed teacher while matching perplexity and divergence metrics—validating SimplePrune as an effective route to deploying KafkaLM in memory‑constrained, sparsity‑accelerated environments.

| Metric | Mistral‑24B | **KafkaLM‑15B** | Δ |
|--------|-------------|-----------------|---|
| Time‑to‑First‑Token | 4.91 s | **2.46 s** | −50% |
| Prompts / s | 4.70 | **6.55** | +38% |
| Tokens / s | 579 | **812** | +40% |


<img src="https://cdn-uploads.huggingface.co/production/uploads/645ded34a45b4182d7f5c385/4rDhaeC-1GMj6KWbB27f9.png" width="480" height="480" alt="image/png">


### Training scalability (distillation run, MI300A cluster)

| Nodes | Tokens / s | Speed‑up |
|-------|------------|----------|
| 4 | 1 461 | – |
| 8 | 3 327 | 2.3 × |
| 16 | 7 423 | 5.1 × |
| 32 | 15 286 | 10.5 × |
| 64 | 25 455 | 17.4 × |

Near‑linear scaling thanks to sharded ZeRO‑3 + RCCL optimisations.

# Inference


### Transformers

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "seedboxai/KafkaLM-15B"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

# prepare the model input
prompt = "Why did Kafka hit different?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=1024
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() 

response = tokenizer.decode(output_ids, skip_special_tokens=True)

print(response)
```

### vLLM

```python

"""
This example shows how to use KafkaLM-15B in vLLM with and without LoRA reasoning functionality
for offline inference.
"""

from typing import Optional
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
import transformers

SYS_MESSAGE = 'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.'
tokenizer = transformers.AutoTokenizer.from_pretrained("seedboxai/KafkaLM-15B")  # same model as the engine below

def create_test_prompts(lora_path: str) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters:
    one request for the base model and one request for the LoRA adapter.
    """
    return [
        ("Why did Kafka hit different?",
         SamplingParams(temperature=0.7,
                        top_p=0.95,
                        max_tokens=1024), None),
        
        ("Create a Markdown table comparing SHA‑256, BLAKE3, and SHA‑3 with columns: internal structure, block size, and throughput.",
         SamplingParams(temperature=0.6,
                        top_p=0.95,
                        top_k=20,
                        logprobs=1,
                        prompt_logprobs=1,
                        stop=["</s>", "<eos>"],
                        max_tokens=4096),
                        LoRARequest("reasoning-lora", 1, lora_path)),
       
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: list[tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            user_prompt, sampling_params, lora_request = test_prompts.pop(0)
            prompt = tokenizer.apply_chat_template(
                [{'role': 'system', 'content': SYS_MESSAGE},
                 {'role': 'user', 'content': user_prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine."""

    engine_args = EngineArgs(model="seedboxai/KafkaLM-15B",
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=128,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
    lora_path = snapshot_download(repo_id="seedboxai/KafkaLM-15B-GRPO_LoRA_Exp")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()

```


## Citation
```bibtex
@misc{kafkalm2025,
  title={Evaluating AMD's MI300A APU: Performance Insights on LLM Training via Knowledge Distillation},
  author={Dennis Dickmann and Philipp Offenhäuser and Rishabh Saxena and George S. Markomanolis and Alessandro Rigazzi and Patrick Keller and Dennis Hoppe},
  howpublished={Cray User Group Conference, 2025},
  note={to be published},
  year={2025}
}
```