Upload folder using huggingface_hub
- .gitignore +7 -0
- LICENSE +21 -0
- README.md +170 -3
- config.py +9 -0
- models/__init__.py +0 -0
- models/model.py +113 -0
- scripts/generate.py +183 -0
- scripts/memory.py +42 -0
- scripts/prepare_data.py +26 -0
- scripts/tokenizer_setup.py +33 -0
- scripts/train.py +133 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
+data/*
+*.pt
+*.json
+.idea
+__pycache__
+venv
+memory.db
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Brett Moore
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,170 @@
-
-
-
+# Microformer
+
+**Microformer** is a minimal, educational-scale transformer language model built from scratch in PyTorch.
+Inspired by [nanoGPT](https://github.com/karpathy/nanoGPT) and OpenAI's GPT-1, Microformer is designed for learning, experimentation, and prototyping on lightweight datasets like [text8](https://mattmahoney.net/dc/textdata.html) or Tiny Shakespeare.
+
+---
+
+## Features
+
+- Decoder-only transformer (GPT-style) architecture
+- **Stacked adapters per layer for dual memory:**
+  - **Long-term adapters** (for corpus/knowledge facts)
+  - **Session adapters** (for rapid, online, user/session-specific learning)
+- Choice of character-level **or** subword/BPE tokenization (configurable)
+- Sinusoidal positional encoding
+- Multi-head self-attention
+- Configurable depth, embedding size, sequence length, and attention heads
+- Simple end-to-end pipeline: preprocessing, training, and text generation
+- Modular, readable code ideal for educational use and tinkering
+- Temperature, top-k/top-p, and multinomial sampling in text generation
+
+---
+
+## What's Unique: Stacked Adapters for Dual-Memory Learning
+
+Microformer implements **two adapters in every transformer block**:
+
+- **Long-term adapter:**
+  Trained with your full corpus during batch/corpus training.
+  Stores stable, general "knowledge" (e.g., literary style, factual info).
+
+- **Session adapter:**
+  Starts blank and is trained *on the fly* during chat or interactive teaching.
+  Lets you rapidly "teach" new facts, styles, or user preferences without overwriting core knowledge.
+
+At inference, each block applies both adapters as small residual updates on top of the core transformer output, giving the model both stable knowledge and flexible, session-specific memory, loosely analogous to a brain's "core memory" and "temporal lobe". A minimal sketch of the stacking follows.
+
+---
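+A minimal, illustrative sketch of the stacking (mirroring the `Adapter` class in `models/model.py`; the dimensions below are arbitrary):
+
+```python
+import torch
+import torch.nn as nn
+
+class Adapter(nn.Module):
+    # Residual bottleneck adapter: down-project, ReLU, up-project, add back.
+    def __init__(self, dim, bottleneck=32):
+        super().__init__()
+        self.down = nn.Linear(dim, bottleneck)
+        self.up = nn.Linear(bottleneck, dim)
+
+    def forward(self, x):
+        return x + self.up(torch.relu(self.down(x)))
+
+long_term = Adapter(128)              # trained on the corpus, then frozen
+session = Adapter(128)                # trained online during chat
+block_out = torch.randn(1, 16, 128)   # (batch, seq_len, embed_dim)
+y = session(long_term(block_out))     # the two residual updates stack
+```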
+
+## Project Structure
+
+```
+microformer/
+├── config.py                # Hyperparameters and model settings
+├── data/
+│   ├── corpus.txt           # Raw training text
+│   ├── train.pt             # Preprocessed training tensor (token IDs)
+│   ├── val.pt               # Validation tensor (token IDs)
+│   ├── vocab.json           # Vocabulary (char or subword, stoi/itos mapping)
+│   └── tokenizer.json       # (optional) BPE tokenizer file if using subwords
+├── models/
+│   └── model.py             # Transformer model definition (Microformer)
+├── scripts/
+│   ├── prepare_data.py      # Data preprocessing/tokenization
+│   ├── train.py             # Training script (core weights + long-term adapters)
+│   ├── generate.py          # Inference/generation + online learning (session adapters)
+│   ├── memory.py            # SQLite helpers for logging prompt/response pairs
+│   └── tokenizer_setup.py   # BPE tokenizer training
+└── README.md
+```
+
+---
+
+## Quickstart
+
+1. **Prepare your corpus**
+
+   Place your text data in `data/corpus.txt`.
+
+2. **Choose your tokenizer:**
+
+   - **Character-level (default):**
+     No extra steps needed.
+
+   - **BPE/Subword (recommended for rich/modern text):**
+     ```bash
+     python scripts/tokenizer_setup.py
+     ```
+     (The corpus path, vocabulary size, and special tokens are currently set inside the script rather than via command-line flags.)
+
+3. **Prepare the dataset**
+
+   ```bash
+   python scripts/prepare_data.py
+   ```
+
+4. **Train the model (long-term knowledge)**
+
+   ```bash
+   python scripts/train.py
+   ```
+   - This trains the core weights and **long-term adapters**.
+   - Session adapters remain untrained (blank) until chat time.
+
+5. **Generate text and teach interactively (session memory)**
+
+   ```bash
+   python scripts/generate.py
+   ```
+   - Loads your trained model.
+   - Prompts for a seed string and temperature.
+   - **Allows you to "teach" new facts on the fly!**
+   - New knowledge is stored in the session adapters and does *not* overwrite long-term knowledge.
+
+---
+
+## Example Config (`config.py`)
+
+```python
+EMBED_DIM = 128
+NUM_HEADS = 4
+NUM_LAYERS = 2
+FF_DIM = 256
+MAX_SEQ_LEN = 128
+BATCH_SIZE = 32
+ADAPTER_DIM = 32   # Used for both long-term and session adapters
+VOCAB_SIZE = 100   # Set automatically from tokenizer/vocab
+```
+
+---
+
+## Using the Dual-Memory System
+
+- **Long-term adapters:**
+  Learned during `train.py` and persist between runs.
+
+- **Session adapters:**
+  Learned during interactive chat in `generate.py`; optionally resettable between users/sessions.
+
+- **Teach new facts by entering a prompt and providing your ideal answer.**
+  The model will "remember" this during the session, even if it wasn't present in the training corpus. A programmatic sketch of this flow follows.
+
+---
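+The same teach-then-recall flow, driven from code rather than the interactive prompt (an illustrative sketch; the fact text is made up, and importing `scripts.generate` loads `microformer.pt` and `data/tokenizer.json` as a side effect):
+
+```python
+from config import MAX_SEQ_LEN
+from scripts.generate import (
+    model, tokenizer, device, criterion, optimizer,
+    generate, online_unsupervised_update,
+)
+
+# Session adapters (and the output head) are the only trainable parts.
+model.freeze_except_adapters(session_only=True, include_output=True)
+
+# Teach a fact that was never in the training corpus...
+loss = online_unsupervised_update(
+    model, tokenizer, "Buck is a St. Bernard/Scotch shepherd mix.",
+    optimizer, criterion, device, max_len=MAX_SEQ_LEN,
+)
+print(f"online update loss: {loss:.4f}")
+
+# ...then sample around it; the session adapters now carry the update.
+print(generate("Who is Buck?", length=50, temperature=0.8))
+```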
+
+## Customization & Ideas
+
+- Use BPE/subword tokenization for more expressive modeling (recommended for non-trivial datasets)
+- Add more adapters or experiment with gating (e.g., blend adapters by context)
+- Combine with a key-value retrieval or buffer for truly persistent "user memory"
+- Visualize training with TensorBoard or wandb
+- Tinker with alternative attention or memory mechanisms
+
+---
+
+## Requirements
+
+- Python 3.8+
+- [PyTorch](https://pytorch.org/)
+- [tokenizers](https://github.com/huggingface/tokenizers) (for BPE/subword)
+
+Install dependencies with:
+```bash
+pip install torch tokenizers
+```
+
+---
+
+## Credits
+
+- Inspired by [nanoGPT](https://github.com/karpathy/nanoGPT) and [minGPT](https://github.com/karpathy/minGPT) by Andrej Karpathy
+- Adapter and continual-learning inspiration from recent NLP research ([Houlsby et al. 2019](https://arxiv.org/abs/1902.00751))
+- Built using concepts from the original [GPT-1 paper](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)
+
+---
+
+## License
+
+MIT License – Use freely for learning and experimentation.
+
+---
+
+**Happy tinkering with dual-memory transformers!**
config.py
ADDED
@@ -0,0 +1,9 @@
+# Hyperparameters and config settings
+
+EMBED_DIM = 256     # Size of token embeddings
+NUM_HEADS = 8       # Number of attention heads
+NUM_LAYERS = 4      # Number of transformer blocks
+FF_DIM = 512        # Feedforward layer dimension
+MAX_SEQ_LEN = 256   # Maximum sequence length
+VOCAB_SIZE = 100    # Placeholder (overridden from the tokenizer/vocab at runtime)
+ADAPTER_DIM = 32    # Bottleneck width of the long-term and session adapters
models/__init__.py
ADDED
File without changes
models/model.py
ADDED
@@ -0,0 +1,113 @@
+import math
+
+import torch
+import torch.nn as nn
+
+class PositionalEncoding(nn.Module):
+    """Fixed sinusoidal positional encoding (Vaswani et al. 2017)."""
+    def __init__(self, d_model, max_len=5000):
+        super().__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len).unsqueeze(1).float()
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        return x + self.pe[:, :x.size(1)]
+
+class MultiHeadSelfAttention(nn.Module):
+    def __init__(self, embed_dim, num_heads):
+        super().__init__()
+        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
+
+    def forward(self, x):
+        # Causal mask so each position attends only to itself and earlier
+        # positions (required for autoregressive, GPT-style training).
+        seq_len = x.size(1)
+        causal_mask = torch.triu(
+            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1
+        )
+        attn_output, _ = self.attn(x, x, x, attn_mask=causal_mask)
+        return attn_output
+
+class FeedForward(nn.Module):
+    def __init__(self, embed_dim, ff_dim):
+        super().__init__()
+        self.ff = nn.Sequential(
+            nn.Linear(embed_dim, ff_dim),
+            nn.ReLU(),
+            nn.Linear(ff_dim, embed_dim)
+        )
+
+    def forward(self, x):
+        return self.ff(x)
+
+# --- Adapter Block ---
+class Adapter(nn.Module):
+    """Residual bottleneck adapter: down-project, ReLU, up-project."""
+    def __init__(self, dim, bottleneck=32):
+        super().__init__()
+        self.down = nn.Linear(dim, bottleneck)
+        self.relu = nn.ReLU()
+        self.up = nn.Linear(bottleneck, dim)
+
+    def forward(self, x):
+        return x + self.up(self.relu(self.down(x)))  # residual connection
+
+class TransformerBlock(nn.Module):
+    def __init__(self, embed_dim, num_heads, ff_dim,
+                 long_term_adapter_dim=None, session_adapter_dim=None):
+        super().__init__()
+        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embed_dim)
+        self.ff = FeedForward(embed_dim, ff_dim)
+        self.norm2 = nn.LayerNorm(embed_dim)
+        # Two adapters: one for long-term (rarely updated), one for session (online)
+        self.long_term_adapter = Adapter(embed_dim, long_term_adapter_dim) if long_term_adapter_dim else None
+        self.session_adapter = Adapter(embed_dim, session_adapter_dim) if session_adapter_dim else None
+
+    def forward(self, x):
+        x = self.norm1(x + self.attn(x))
+        x = self.norm2(x + self.ff(x))
+        # Apply both adapters' residual updates, if present
+        if self.long_term_adapter is not None:
+            x = self.long_term_adapter(x)
+        if self.session_adapter is not None:
+            x = self.session_adapter(x)
+        return x
+
+
+class Microformer(nn.Module):
+    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, max_seq_len,
+                 long_term_adapter_dim=None, session_adapter_dim=None):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len)
+        self.layers = nn.ModuleList([
+            TransformerBlock(
+                embed_dim, num_heads, ff_dim,
+                long_term_adapter_dim=long_term_adapter_dim,
+                session_adapter_dim=session_adapter_dim
+            )
+            for _ in range(num_layers)
+        ])
+        self.output = nn.Linear(embed_dim, vocab_size)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.positional_encoding(x)
+        for layer in self.layers:
+            x = layer(x)
+        return self.output(x)
+
+    def freeze_except_adapters(self, session_only=True, include_output=True):
+        """Freeze all parameters, then re-enable gradients for the session
+        adapters (and, if session_only=False, the long-term adapters too)."""
+        for param in self.parameters():
+            param.requires_grad = False
+        for layer in self.layers:
+            if getattr(layer, 'session_adapter', None) is not None:
+                for param in layer.session_adapter.parameters():
+                    param.requires_grad = True
+            if not session_only and getattr(layer, 'long_term_adapter', None) is not None:
+                for param in layer.long_term_adapter.parameters():
+                    param.requires_grad = True
+        if include_output:
+            for param in self.output.parameters():
+                param.requires_grad = True
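A quick shape sanity-check for `Microformer` (illustrative hyperparameters; assumes only PyTorch and this repo's `models/model.py`):

```python
import torch
from models.model import Microformer

model = Microformer(vocab_size=100, embed_dim=128, num_heads=4, ff_dim=256,
                    num_layers=2, max_seq_len=64,
                    long_term_adapter_dim=32, session_adapter_dim=32)
tokens = torch.randint(0, 100, (2, 16))  # (batch, seq_len) of token IDs
logits = model(tokens)
print(logits.shape)                      # torch.Size([2, 16, 100]): per-token vocab logits
```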
scripts/generate.py
ADDED
@@ -0,0 +1,183 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+
+import sqlite3
+from datetime import datetime
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tokenizers import Tokenizer
+
+from models.model import Microformer
+from config import EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, MAX_SEQ_LEN, ADAPTER_DIM
+
+# --- Load tokenizer and model ---
+tokenizer = Tokenizer.from_file("data/tokenizer.json")
+VOCAB_SIZE = tokenizer.get_vocab_size()
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model = Microformer(
+    vocab_size=VOCAB_SIZE,
+    embed_dim=EMBED_DIM,
+    num_heads=NUM_HEADS,
+    ff_dim=FF_DIM,
+    num_layers=NUM_LAYERS,
+    max_seq_len=MAX_SEQ_LEN,
+    long_term_adapter_dim=ADAPTER_DIM,
+    session_adapter_dim=ADAPTER_DIM
+)
+model.load_state_dict(torch.load("microformer.pt", map_location=device))
+model.to(device)
+model.eval()
+
+# --- Freeze all but session adapters and output for online learning ---
+model.freeze_except_adapters(session_only=True, include_output=True)
+
+# Ignore <PAD> positions when computing the online-learning loss
+criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("<PAD>"))
+optimizer = optim.Adam(
+    filter(lambda p: p.requires_grad, model.parameters()),
+    lr=1e-2  # High LR for visible learning during teaching
+)
+
+# --- Memory DB setup ---
+conn = sqlite3.connect("memory.db")
+c = conn.cursor()
+c.execute("""
+CREATE TABLE IF NOT EXISTS memory (
+    timestamp TEXT,
+    prompt TEXT,
+    response TEXT
+)
+""")
+conn.commit()
+
+def top_k_top_p_filtering(logits, top_k=50, top_p=0.9):
+    logits = logits.squeeze(0)  # [1, vocab] -> [vocab]
+    probs = torch.softmax(logits, dim=-1)
+
+    # Sort probabilities in descending order
+    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+    # Top-p mask (shifted right so the first token over the threshold is kept)
+    sorted_mask = cumulative_probs > top_p
+    sorted_mask[1:] = sorted_mask[:-1].clone()
+    sorted_mask[0] = False
+
+    # Top-k mask
+    if top_k < sorted_probs.size(0):
+        sorted_mask[top_k:] = True
+
+    # Zero out masked values
+    sorted_probs[sorted_mask] = 0.0
+
+    # Normalize and sample
+    sorted_probs /= sorted_probs.sum()
+    sampled_relative_index = torch.multinomial(sorted_probs, 1).item()
+    sampled_token_id = sorted_indices[sampled_relative_index].item()
+
+    return sampled_token_id
+
+def generate(prompt, length=100, temperature=1.0, top_p=0.9, top_k=50):
+    input_ids = tokenizer.encode(prompt).ids
+    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
+
+    eos_token_id = tokenizer.token_to_id("<EOS>")
+
+    for _ in range(length):
+        with torch.no_grad():
+            logits = model(input_tensor)
+            logits = logits[:, -1, :] / temperature
+
+        # Repetition penalty: dampen tokens already in the context. Dividing
+        # positive logits and multiplying negative ones both push toward
+        # "less likely" (a blanket *= 0.8 would boost negative logits).
+        for token_id in set(input_tensor[0].tolist()):
+            if logits[0, token_id] > 0:
+                logits[0, token_id] /= 1.25
+            else:
+                logits[0, token_id] *= 1.25
+
+        next_token_id = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token_id]], device=device)], dim=1)
+
+        if next_token_id == eos_token_id:
+            break
+
+    output_ids = input_tensor[0].tolist()
+    decoded = tokenizer.decode(output_ids)
+
+    if "<EOS>" in decoded:
+        decoded = decoded.split("<EOS>")[0].strip()
+
+    return decoded
+
+def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
+    # Always called after freeze_except_adapters(session_only=True)
+    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
+    if len(ids) < 2:
+        return None
+
+    # Shift by one token (predict ids[1:] from ids[:-1]) and pad to max_len
+    ids = ids[:max_len + 1]
+    input_ids = ids[:-1]
+    target_ids = ids[1:]
+    input_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(input_ids))
+    target_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(target_ids))
+    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
+    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)
+
+    model.train()
+    logits = model(input_tensor)
+    logits = logits.view(-1, logits.size(-1))
+    targets = target_tensor.view(-1)
+    loss = loss_fn(logits, targets)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    model.eval()
+    return loss.item()
+
+# Optional: Reset session adapter weights between sessions
+def reset_session_adapters(model):
+    # Zeroing every weight turns each session adapter into an identity
+    # mapping (the residual passes through unchanged), i.e. a clean slate.
+    for layer in model.layers:
+        if getattr(layer, 'session_adapter', None) is not None:
+            for param in layer.session_adapter.parameters():
+                nn.init.zeros_(param.data)
+
+if __name__ == "__main__":
+    while True:
+        prompt = input("\nEnter a prompt (or 'exit' to quit): ")
+        if prompt.lower() in {"exit", "quit"}:
+            break
+        temp = float(input("Temperature (e.g. 0.7, 1.0): "))
+
+        output = generate(prompt, length=100, temperature=temp, top_p=0.9, top_k=50)
+        print("\nGenerated text:\n")
+        print(output)
+
+        # Online learning: always update session adapters only
+        teach = input("\nDo you want to teach the model a better answer? (y/N): ").strip().lower()
+        if teach == "y":
+            your_answer = input("Type your ideal response for this prompt: ")
+            model.freeze_except_adapters(session_only=True, include_output=True)
+            online_text = prompt + " " + your_answer
+            loss = online_unsupervised_update(
+                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
+            )
+            print(f"[Online update loss: {loss:.4f}]")
+        else:
+            model.freeze_except_adapters(session_only=True, include_output=True)
+            online_text = prompt + " " + output
+            loss = online_unsupervised_update(
+                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
+            )
+            print(f"[Online (self-improve) update loss: {loss:.4f}]")
+
+        # Store the interaction in the memory DB
+        c.execute("INSERT INTO memory (timestamp, prompt, response) VALUES (?, ?, ?)",
+                  (datetime.now().isoformat(timespec='seconds'), prompt, output))
+        conn.commit()
+
+        print("\nRecent memory:")
+        for row in c.execute("SELECT * FROM memory ORDER BY timestamp DESC LIMIT 5"):
+            print(f"[{row[0]}] {row[1]} -> {row[2]}")
+
+        # Optional: reset fast memory (session adapters) between users/sessions
+        # reset_session_adapters(model)
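The truncated sampling in `top_k_top_p_filtering` can be sanity-checked in isolation on a toy logits vector (values made up; note that importing `scripts.generate` loads the trained checkpoint and tokenizer as a side effect):

```python
import torch
from scripts.generate import top_k_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # shape [1, vocab_size]
samples = [top_k_top_p_filtering(logits, top_k=2, top_p=0.9) for _ in range(10)]
print(samples)  # only indices 0 and 1 survive the top-k/top-p truncation
```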
scripts/memory.py
ADDED
@@ -0,0 +1,42 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+
+import sqlite3
+
+# Connect to SQLite database (created on first use).
+# Note: scripts/generate.py writes to the same memory.db file.
+conn = sqlite3.connect("memory.db")
+cursor = conn.cursor()
+
+# Create memory table if it doesn't exist
+cursor.execute("""
+CREATE TABLE IF NOT EXISTS memory (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    prompt TEXT NOT NULL,
+    response TEXT NOT NULL,
+    timestamp TEXT DEFAULT CURRENT_TIMESTAMP
+)
+""")
+conn.commit()
+
+def save_memory(prompt: str, response: str):
+    """Save a prompt-response pair to the memory database."""
+    cursor.execute(
+        "INSERT INTO memory (prompt, response) VALUES (?, ?)",
+        (prompt, response)
+    )
+    conn.commit()
+
+def recall_memories(limit: int = 5):
+    """Retrieve the most recent prompt-response pairs."""
+    cursor.execute(
+        "SELECT prompt, response, timestamp FROM memory ORDER BY timestamp DESC LIMIT ?",
+        (limit,)
+    )
+    return cursor.fetchall()
+
+def clear_memory():
+    """Delete all memory records."""
+    cursor.execute("DELETE FROM memory")
+    conn.commit()
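A short usage sketch for these helpers (the prompt/response strings are illustrative):

```python
from scripts.memory import save_memory, recall_memories, clear_memory

save_memory("Who is Buck?", "Buck is the dog the story follows.")
for prompt, response, timestamp in recall_memories(limit=5):
    print(f"[{timestamp}] {prompt} -> {response}")
# clear_memory()  # uncomment to wipe all stored interactions
```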
scripts/prepare_data.py
ADDED
@@ -0,0 +1,26 @@
+import torch
+from tokenizers import Tokenizer
+
+# Load tokenizer
+tokenizer = Tokenizer.from_file("data/tokenizer.json")
+VOCAB_SIZE = tokenizer.get_vocab_size()
+
+# Load corpus
+with open("data/corpus.txt", "r", encoding="utf-8") as f:
+    text = f.read()
+
+# Encode with BPE tokenizer
+encoded = tokenizer.encode(text).ids
+
+# Convert to tensor and split into train/val (90/10)
+data = torch.tensor(encoded, dtype=torch.long)
+split = int(0.9 * len(data))
+train_data = data[:split]
+val_data = data[split:]
+
+# Save outputs
+torch.save(train_data, "data/train.pt")
+torch.save(val_data, "data/val.pt")
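After running it, the saved tensors can be inspected directly (a quick check, nothing more):

```python
import torch

train = torch.load("data/train.pt")
val = torch.load("data/val.pt")
print(train.shape, train.dtype)              # 1-D torch.long tensor of token IDs
print(train[:10])                            # first ten token IDs of the corpus
print(len(val) / (len(train) + len(val)))    # ~0.1 validation fraction
```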
scripts/tokenizer_setup.py
ADDED
@@ -0,0 +1,33 @@
+from tokenizers import Tokenizer, models, trainers, pre_tokenizers
+from pathlib import Path
+import json
+
+# Paths
+corpus_path = Path("data/corpus.txt")
+tokenizer_path = Path("data/tokenizer.json")
+
+# Read corpus
+with corpus_path.open("r", encoding="utf-8") as f:
+    lines = [line.strip() for line in f if line.strip()]
+
+# Initialize tokenizer with BPE model
+tokenizer = Tokenizer(models.BPE())
+tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
+    pre_tokenizers.Whitespace(),
+    pre_tokenizers.Punctuation()
+])
+
+# Train tokenizer (vocab size and special tokens are set here, not via CLI flags)
+trainer = trainers.BpeTrainer(vocab_size=5000, special_tokens=["<PAD>", "<UNK>", "<EOS>"])
+tokenizer.train_from_iterator(lines, trainer)
+
+# Save tokenizer
+tokenizer.save(str(tokenizer_path))
+
+# Create vocab.json for compatibility (stoi: token -> id, itos: id -> token)
+vocab = tokenizer.get_vocab()
+stoi = vocab
+itos = {v: k for k, v in vocab.items()}
+
+with open("data/vocab.json", "w") as f:
+    json.dump({"stoi": stoi, "itos": itos}, f)
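After training, the saved tokenizer can be sanity-checked with a quick round-trip (the example string is arbitrary):

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_file("data/tokenizer.json")
enc = tok.encode("Buck did not read the newspapers.")
print(enc.tokens)            # subword pieces
print(enc.ids)               # integer IDs fed to the model
print(tok.decode(enc.ids))   # back to (approximately) the original text
```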
scripts/train.py
ADDED
@@ -0,0 +1,133 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import json
+from models.model import Microformer
+from config import *
+
+# ------------------------
+# LOAD DATA AND VOCAB
+# ------------------------
+with open("data/vocab.json", "r") as f:
+    vocab = json.load(f)
+stoi = vocab["stoi"]
+itos = {int(k): v for k, v in vocab["itos"].items()}
+VOCAB_SIZE = len(stoi)  # override the config placeholder
+
+data = torch.load("data/train.pt")
+SEQ_LEN = MAX_SEQ_LEN
+BATCH_SIZE = 32
+
+# Drop the remainder for a clean batch shape
+num_batches = len(data) // (SEQ_LEN * BATCH_SIZE)
+trimmed_len = num_batches * SEQ_LEN * BATCH_SIZE
+data = data[:trimmed_len]
+data = data.view(BATCH_SIZE, -1)  # shape: (BATCH_SIZE, n_chunks)
+
+def get_batch(start_idx):
+    # Targets are the inputs shifted one token to the right
+    x = data[:, start_idx:start_idx+SEQ_LEN]
+    y = data[:, start_idx+1:start_idx+SEQ_LEN+1]
+    return x, y
+
+# ------------------------
+# DEVICE SETUP
+# ------------------------
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# ------------------------
+# MODEL INSTANTIATION (with stacked adapters)
+# ------------------------
+model = Microformer(
+    VOCAB_SIZE,
+    EMBED_DIM,
+    NUM_HEADS,
+    FF_DIM,
+    NUM_LAYERS,
+    MAX_SEQ_LEN,
+    long_term_adapter_dim=ADAPTER_DIM,  # <-- set in config
+    session_adapter_dim=ADAPTER_DIM     # <-- set in config
+)
+model.to(device)
+
+# ------------------------
+# TRAIN CORE WEIGHTS + LONG-TERM ADAPTERS
+# ------------------------
+# Corpus training updates everything *except* the session adapters, which are
+# reserved for online learning at chat time. (freeze_except_adapters() is not
+# used here because it would also freeze the core transformer weights.)
+for layer in model.layers:
+    if getattr(layer, 'session_adapter', None) is not None:
+        for param in layer.session_adapter.parameters():
+            param.requires_grad = False
+
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
+
+# ------------------------
+# MAIN BATCH TRAINING LOOP (CORPUS)
+# ------------------------
+for epoch in range(6):
+    for i in range(0, data.shape[1] - SEQ_LEN, SEQ_LEN):
+        inputs, targets = get_batch(i)
+        inputs, targets = inputs.to(device), targets.to(device)
+        optimizer.zero_grad()
+        out = model(inputs)
+        loss = criterion(out.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
+        loss.backward()
+        optimizer.step()
+
+    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
+
+torch.save(model.state_dict(), "microformer.pt")
+
+# ------------------------
+# ONLINE (SESSION) LEARNING UTILITY
+# ------------------------
+def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
+    # Only update session adapters/output layer; call freeze_except_adapters before this as needed.
+    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
+    if len(ids) < 2:
+        return None  # not enough tokens
+
+    ids = ids[:max_len + 1]
+    input_ids = ids[:-1]
+    target_ids = ids[1:]
+    input_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(input_ids))
+    target_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(target_ids))
+    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
+    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)
+
+    model.train()
+    logits = model(input_tensor)
+    logits = logits.view(-1, logits.size(-1))
+    targets = target_tensor.view(-1)
+    loss = loss_fn(logits, targets)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    model.eval()
+    return loss.item()
+
+# ------------------------
+# SESSION ADAPTER RESET FUNCTION (OPTIONAL)
+# ------------------------
+def reset_session_adapters(model):
+    # Zeroing the weights turns each session adapter into an identity mapping.
+    for layer in model.layers:
+        if getattr(layer, 'session_adapter', None) is not None:
+            for param in layer.session_adapter.parameters():
+                nn.init.zeros_(param.data)
+
+# ------------------------
+# USAGE FOR ONLINE LEARNING (after chat, NOT in main batch loop):
+# ------------------------
+# from tokenizers import Tokenizer
+# tokenizer = Tokenizer.from_file("data/tokenizer.json")
+# model.freeze_except_adapters(session_only=True, include_output=True)
+# message = "Who is Buck?"
+# loss = online_unsupervised_update(model, tokenizer, message, optimizer, criterion, device, max_len=SEQ_LEN)
+# print(f"Online update loss: {loss}")
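One gap worth noting: `prepare_data.py` saves `data/val.pt`, but `train.py` never reads it. A minimal held-out check could be appended to the end of `train.py` (a sketch reusing the names defined above; for large validation sets, evaluate in batches instead of one pass):

```python
# Sketch: held-out loss on data/val.pt (append after the training loop).
val = torch.load("data/val.pt")
model.eval()
with torch.no_grad():
    n = (len(val) - 1) // SEQ_LEN * SEQ_LEN   # trim to whole sequences
    x = val[:n].view(-1, SEQ_LEN).to(device)
    y = val[1:n + 1].view(-1, SEQ_LEN).to(device)
    val_loss = criterion(model(x).reshape(-1, VOCAB_SIZE), y.reshape(-1))
print(f"Validation loss: {val_loss.item():.4f}")
```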