Neural Memory Model for Russian Text Generation
This model implements a neural memory architecture for Russian text generation using PyTorch and the Titans library. The architecture is based on the implementation from lucidrains/titans-pytorch.
Model Description
The model uses a Transformer architecture enhanced with neural memory from the Titans library, which improves context handling and the modeling of long-range dependencies in text generation.
Architecture Source
The core architecture is derived from the Titans PyTorch implementation by Phil Wang (@lucidrains). The following key components of that implementation are used here (composed as in the sketch below):
- Memory-enhanced Transformer architecture
- Flexible attention mechanisms
- Neural memory layers
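To make the composition concrete, here is a minimal sketch of how these pieces fit together. It is not the full model definition; the dimensions and layer indices mirror the defaults used in the training code further down, and the random token batch is purely illustrative.

import torch
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

# MLP that acts as the neural (long-term) memory module
neural_memory_model = MemoryMLP(dim=64, depth=2)

# Transformer whose context is augmented with persistent and long-term memory tokens
transformer = MemoryAsContextTransformer(
    num_tokens=50257,                 # vocabulary size
    dim=384,
    depth=8,
    segment_len=32,                   # attention window length
    num_persist_mem_tokens=4,
    num_longterm_mem_tokens=4,
    neural_memory_layers=(2, 4, 6),   # layers equipped with neural memory
    neural_memory_model=neural_memory_model,
    sliding_window_attn=True,
)

tokens = torch.randint(0, 50257, (1, 128))    # illustrative random batch
loss = transformer(tokens, return_loss=True)  # next-token prediction loss
logits = transformer(tokens)                  # or raw logits

The full training code below additionally enables flex attention and runs the model on CUDA.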
Key Features
- Neural memory architecture with customizable depth and size
- Sliding window attention mechanism
- Gradient accumulation for stable training (see the sketch after this list)
- CUDA-optimized implementation
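The gradient accumulation feature refers to the update pattern used in the training and fine-tuning loops below. The following toy, self-contained sketch shows the pattern in isolation; the linear model, random batches, and AdamW optimizer are stand-ins and not part of this project.

import torch
from torch import nn

GRADIENT_ACCUMULATE_EVERY = 4  # same setting as the training code below

# Toy stand-ins so the pattern is runnable on its own
model = nn.Linear(16, 16)
optim = torch.optim.AdamW(model.parameters(), lr=2e-4)

def infinite_batches():
    while True:
        yield torch.randn(8, 16)

loader = infinite_batches()

for step in range(10):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):               # several micro-batches per optimizer step
        loss = model(next(loader)).pow(2).mean()
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()        # scale so gradients average over micro-batches
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clipping, as in the fine-tuning loop
    optim.step()
    optim.zero_grad()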
Requirements
Environment
- Python: 3.9.21
- CUDA: 11.8
- GPU with at least 16GB VRAM recommended
Key Dependencies
adam-atan2-pytorch==0.1.18
datasets==3.2.0
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvtx-cu12==12.4.127
titans-pytorch==0.3.25
torchaudio==2.5.1
torchvision==0.20.1
transformers==4.48.3
triton==3.1.0
wandb==0.19.6
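The pinned packages above can be installed with pip, for example from a requirements.txt containing this list. The nvidia-* entries are CUDA runtime wheels that pip normally pulls in automatically as PyTorch dependencies, so explicitly installing titans-pytorch, transformers, datasets, adam-atan2-pytorch, and wandb is usually sufficient.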
Example
The repository includes complete training and inference code. Key components:
- Data preprocessing (WikiDatasetPreprocessor)
- Custom dataset implementation (WikiTextDataset)
- Training loop with gradient accumulation
- Validation and checkpointing
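For orientation, the end-to-end flow through these components is summarized in the condensed sketch below; it is not standalone code, and all names, paths, and limits are taken from the training script reproduced under "Train Code" further down.

from pathlib import Path

# All of these helpers are defined in the Train Code section below.
preprocessor = WikiDatasetPreprocessor('cache', 'processed_data')
preprocessor.process_and_save(max_articles=10000)                  # tokenize and chunk ru-wikipedia

train_loader, val_loader = create_dataloaders(
    Path('processed_data') / 'processed_wiki.pt',
    batch_size=4,
    seq_len=512,
)
train_loader, val_loader = cycle(train_loader), cycle(val_loader)  # infinite iterators

model = create_model()                                             # MemoryAsContextTransformer on CUDA
model = train_model(model, train_loader, val_loader, num_batches=10000)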
Example Code
import os
import warnings
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
GPT2TokenizerFast,
PreTrainedModel,
PreTrainedTokenizer,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
PretrainedConfig,
GenerationMixin,
pipeline
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from huggingface_hub import HfApi, login
from datasets import load_dataset
from tqdm import tqdm
from adam_atan2_pytorch import AdoptAtan2
from titans_pytorch import (
MemoryAsContextTransformer,
MemoryMLP,
MemoryAttention
)
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 100000
torch._dynamo.config.disable = True
# CUDA memory settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'
# Constants
repo_id = 'Grpp/memory-transformer-ru'
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
PRIME_LENGTH = 100
GENERATE_LENGTH = 512
SHOULD_GENERATE = True
SEQ_LEN = 512
# Neural memory constants
NEURAL_MEMORY_DEPTH = 2
NUM_PERSIST_MEM = 4
NUM_LONGTERM_MEM = 4
NEURAL_MEM_LAYERS = (2, 4, 6)
NEURAL_MEM_GATE_ATTN_OUTPUT = False
NEURAL_MEM_MOMENTUM = True
NEURAL_MEM_MOMENTUM_ORDER = 1
NEURAL_MEM_QK_NORM = True
NEURAL_MEM_MAX_LR = 1e-1
USE_MEM_ATTENTION_MODEL = False
WINDOW_SIZE = 32
NEURAL_MEM_SEGMENT_LEN = 4
NEURAL_MEM_BATCH_SIZE = 128
SLIDING_WINDOWS = True
STORE_ATTN_POOL_CHUNKS = True
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
NEURAL_MEM_WEIGHT_RESIDUAL = True
class MemoryTransformerConfig(PretrainedConfig):
model_type = "memory_transformer"
def __init__(
self,
vocab_size=50257,
dim=384,
depth=8,
segment_len=32,
num_persist_mem=4,
num_longterm_mem=4,
neural_mem_layers=(2, 4, 6),
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
**kwargs
):
self.vocab_size = vocab_size
self.dim = dim
self.depth = depth
self.segment_len = segment_len
self.num_persist_mem = num_persist_mem
self.num_longterm_mem = num_longterm_mem
self.neural_mem_layers = neural_mem_layers
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
class MemoryTransformerForCausalLM(PreTrainedModel, GenerationMixin):
config_class = MemoryTransformerConfig
supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
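        # Choose the neural memory module: attention-based or MLP-based,
        # depending on the USE_MEM_ATTENTION_MODEL flag above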
neural_memory_model = (
MemoryAttention(dim=64) if USE_MEM_ATTENTION_MODEL
else MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH)
)
self.transformer = MemoryAsContextTransformer(
num_tokens=config.vocab_size,
dim=config.dim,
depth=config.depth,
segment_len=config.segment_len,
num_persist_mem_tokens=config.num_persist_mem,
num_longterm_mem_tokens=config.num_longterm_mem,
neural_memory_layers=config.neural_mem_layers,
neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN,
neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE,
neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT,
neural_mem_weight_residual=NEURAL_MEM_WEIGHT_RESIDUAL,
use_flex_attn=True,
sliding_window_attn=SLIDING_WINDOWS,
neural_memory_model=neural_memory_model,
neural_memory_kwargs=dict(
dim_head=64,
heads=4,
attn_pool_chunks=STORE_ATTN_POOL_CHUNKS,
qk_rmsnorm=NEURAL_MEM_QK_NORM,
momentum=NEURAL_MEM_MOMENTUM,
momentum_order=NEURAL_MEM_MOMENTUM_ORDER,
default_step_transform_max_lr=NEURAL_MEM_MAX_LR,
use_accelerated_scan=True,
per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR
)
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(input_ids)
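        # The call above returns logits; when labels are provided the transformer is
        # called again with return_loss=True, so a supervised forward pass runs the
        # model twice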
if labels is not None:
loss = self.transformer(input_ids, return_loss=True)
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=outputs,
past_key_values=None,
hidden_states=None,
attentions=None,
cross_attentions=None
)
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=outputs,
past_key_values=None,
hidden_states=None,
attentions=None,
cross_attentions=None
)
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
attention_mask=None,
**kwargs
):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {
"input_ids": input_ids,
"past_key_values": past,
"attention_mask": attention_mask,
}
@property
def device(self):
return next(self.parameters()).device
def setup_custom_model():
"""Регистрация кастомной модели"""
AutoConfig.register("memory_transformer", MemoryTransformerConfig)
AutoModelForCausalLM.register(MemoryTransformerConfig, MemoryTransformerForCausalLM)
def generate_example(model, tokenizer, text, max_length=100):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
attention_mask = torch.ones_like(input_ids, device=device)
print(f"Model device: {next(model.parameters()).device}")
print(f"Input device: {input_ids.device}")
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_return_sequences=1,
no_repeat_ngram_size=2,
do_sample=True,
top_k=50,
top_p=0.95,
temperature=0.7,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
if __name__ == "__main__":
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
setup_custom_model()
config = AutoConfig.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
test_text = "Московский кремль является"
generated_text = generate_example(model, tokenizer, test_text)
print(generated_text)
Finetune Code
import os
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from tqdm import tqdm
from adam_atan2_pytorch import AdoptAtan2
# Import the data helpers from the training code
from run_train_pep8 import (
WikiDatasetPreprocessor,
WikiTextDataset,
create_dataloaders,
cycle
) # From Train Code
from test_load import setup_custom_model # From Example Code
# CUDA memory settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'
# Fine-tuning constants
BATCH_SIZE = 2
GRADIENT_ACCUMULATE_EVERY = 2
LEARNING_RATE = 1e-5
NUM_EPOCHS = 3
STEPS_PER_EPOCH = 1000  # Number of steps per epoch
SEQ_LEN = 256
PROCESSED_DATA_DIR = 'processed_data'
CACHE_DIR = 'cache'
REPO_ID = 'Grpp/memory-transformer-ru'
def finetune_model(
model,
train_loader,
val_loader,
num_epochs,
device,
save_path='finetuned_model'
):
"""Файнтьюнинг модели."""
model = model.to(device)
optimizer = AdoptAtan2(model.parameters(), lr=LEARNING_RATE)
best_val_loss = float('inf')
for epoch in range(num_epochs):
model.train()
total_train_loss = 0
train_steps = 0
        # Progress bar over a fixed number of steps per epoch
train_pbar = tqdm(range(STEPS_PER_EPOCH),
desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
for step in train_pbar:
total_loss = 0
            # Gradient accumulation
for _ in range(GRADIENT_ACCUMULATE_EVERY):
batch = next(train_loader)
batch = batch.to(device)
                # Split the batch into inputs and shifted labels
inputs = batch[:, :-1]
labels = batch[:, 1:]
                # Forward pass
outputs = model(input_ids=inputs, labels=labels)
loss = outputs.loss / GRADIENT_ACCUMULATE_EVERY
                # Backward pass
loss.backward()
total_loss += loss.item()
            # Parameter update
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
total_train_loss += total_loss
train_steps += 1
            # Update the progress bar
train_pbar.set_postfix({
'loss': f'{total_loss:.4f}',
'avg_loss': f'{total_train_loss/train_steps:.4f}'
})
            # Validation every 100 steps
if step % 100 == 0:
model.eval()
val_loss = 0
val_steps = 0
with torch.no_grad():
                    for _ in range(10):  # Limit the number of validation steps
val_batch = next(val_loader)
val_batch = val_batch.to(device)
val_inputs = val_batch[:, :-1]
val_labels = val_batch[:, 1:]
val_outputs = model(input_ids=val_inputs, labels=val_labels)
val_loss += val_outputs.loss.item()
val_steps += 1
avg_val_loss = val_loss / val_steps
print(f"\nValidation loss: {avg_val_loss:.4f}")
                # Save the best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': best_val_loss,
}, f'{save_path}_best.pt')
model.train()
        # Save a checkpoint after every epoch
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': total_train_loss / train_steps,
}, f'{save_path}_epoch_{epoch}.pt')
print(f"\nEpoch {epoch+1} completed. Average loss: {total_train_loss/train_steps:.4f}")
return model
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
    # Load and prepare the data
processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt'
if not processed_data_path.exists():
print("Processing dataset...")
preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR)
preprocessor.process_and_save(max_articles=10000)
print("Creating dataloaders...")
train_loader, val_loader = create_dataloaders(
processed_data_path,
batch_size=BATCH_SIZE,
seq_len=SEQ_LEN
)
train_loader = cycle(train_loader)
val_loader = cycle(val_loader)
    # Load the pretrained model
print("Loading pretrained model...")
setup_custom_model()
config = AutoConfig.from_pretrained(REPO_ID)
model = AutoModelForCausalLM.from_pretrained(REPO_ID)
print("Starting finetuning...")
    # Fine-tune the model
model = finetune_model(
model,
train_loader,
val_loader,
NUM_EPOCHS,
device
)
    # Save the final model
print("Saving final model...")
model.save_pretrained('final_finetuned_model')
return model
if __name__ == "__main__":
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.benchmark = True
try:
model = main()
print("Finetuning completed successfully!")
except Exception as e:
print(f"An error occurred: {str(e)}")
Training
The model was trained on a cleaned subset of Russian Wikipedia articles using the following parameters:
- Batch size: 4
- Sequence length: 512
- Learning rate: 2e-4
- Gradient accumulation steps: 4
- Neural memory depth: 2
- Window size: 32
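With gradient accumulation, each optimizer step therefore covers an effective batch of 4 × 4 = 16 sequences, or roughly 16 × 512 = 8,192 tokens.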
Train Code
import json
import os
import random
import re
from pathlib import Path
from typing import List, Dict
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2TokenizerFast
from tqdm import tqdm
from datasets import load_dataset
from adam_atan2_pytorch import AdoptAtan2
from titans_pytorch import (
MemoryAsContextTransformer,
MemoryMLP,
MemoryAttention
)
# CUDA memory settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'
# Training constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
PRIME_LENGTH = 100
GENERATE_LENGTH = 512
SHOULD_GENERATE = True
SEQ_LEN = 512
# Neural memory constants
NEURAL_MEMORY_DEPTH = 2
NUM_PERSIST_MEM = 4
NUM_LONGTERM_MEM = 4
NEURAL_MEM_LAYERS = (2, 4, 6)
NEURAL_MEM_GATE_ATTN_OUTPUT = False
NEURAL_MEM_MOMENTUM = True
NEURAL_MEM_MOMENTUM_ORDER = 1
NEURAL_MEM_QK_NORM = True
NEURAL_MEM_MAX_LR = 1e-1
USE_MEM_ATTENTION_MODEL = False
WINDOW_SIZE = 32
NEURAL_MEM_SEGMENT_LEN = 4
NEURAL_MEM_BATCH_SIZE = 128
SLIDING_WINDOWS = True
STORE_ATTN_POOL_CHUNKS = True
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
NEURAL_MEM_WEIGHT_RESIDUAL = True
# Initialize tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
class WikiDatasetPreprocessor:
def __init__(self, cache_dir: str = 'cache', output_dir: str = 'processed_data'):
self.cache_dir = Path(cache_dir)
self.output_dir = Path(output_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.tokenizer = GPT2TokenizerFast.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2'
)
def load_wiki_dataset(self):
"""Загрузка датасета из Hugging Face."""
print("Loading Wikipedia dataset...")
dataset = load_dataset(
"misterkirill/ru-wikipedia",
cache_dir=str(self.cache_dir)
)
print(f"Dataset loaded. Size: {len(dataset['train'])} articles")
return dataset
def clean_text(self, text: str) -> str:
"""Базовая очистка текста."""
return ' '.join(text.split())
def process_wiki_article(self, text: str) -> List[str]:
"""Обработка одной статьи из википедии."""
processed_chunks = []
clean_text = self.clean_text(text)
tokens = self.tokenizer.encode(clean_text)
chunk_size = 256
stride = 192
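        # Overlapping windows: with a stride of 192 and chunks of 256 tokens,
        # consecutive chunks share 64 tokens of context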
for i in range(0, len(tokens), stride):
chunk = tokens[i:i + chunk_size]
if len(chunk) > 50:
processed_chunks.append(chunk)
return processed_chunks
def process_and_save(
self,
batch_size: int = 1000,
test_size: float = 0.1,
max_articles: int = 10000
):
"""Обработка статей из датасета и сохранение результатов."""
dataset = self.load_wiki_dataset()
total_articles = min(len(dataset['train']), max_articles)
print(f"Processing {total_articles} articles out of {len(dataset['train'])}")
all_chunks = []
for i in tqdm(range(0, total_articles, batch_size), desc="Processing articles"):
batch = dataset['train'][i:i + batch_size]
for text in batch['text']:
chunks = self.process_wiki_article(text)
all_chunks.extend(chunks)
if len(all_chunks) > 50000:
break
if len(all_chunks) > 50000:
break
print(f"Total chunks created: {len(all_chunks)}")
random.seed(42)
random.shuffle(all_chunks)
test_size = int(len(all_chunks) * test_size)
train_chunks = all_chunks[:-test_size]
test_chunks = all_chunks[-test_size:]
print(f"Saving {len(train_chunks)} training chunks and {len(test_chunks)} test chunks...")
torch.save(
{
'train': train_chunks,
'test': test_chunks
},
self.output_dir / 'processed_wiki.pt'
)
class WikiTextDataset(Dataset):
def __init__(self, chunks: List[List[int]], seq_len: int = 512):
self.chunks = chunks
self.seq_len = seq_len
def __len__(self):
return len(self.chunks)
def __getitem__(self, idx):
chunk = self.chunks[idx]
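        # Pad short chunks to seq_len + 1 with token id 50256 (used here as the padding id);
        # longer chunks are truncated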
if len(chunk) < self.seq_len + 1:
chunk = chunk + [50256] * (self.seq_len + 1 - len(chunk))
else:
chunk = chunk[:self.seq_len + 1]
return torch.tensor(chunk, device='cuda').long()
def create_dataloaders(
processed_data_path: str,
batch_size: int = 4,
seq_len: int = 512,
train_test_split: float = 0.9
) -> tuple:
"""Создание загрузчиков данных для обучения и валидации."""
print(f"Loading processed data from {processed_data_path}")
data = torch.load(processed_data_path)
train_chunks = data['train']
test_chunks = data['test']
train_dataset = WikiTextDataset(train_chunks, seq_len)
test_dataset = WikiTextDataset(test_chunks, seq_len)
print(f"Created datasets with {len(train_dataset)} training and "
f"{len(test_dataset)} test samples")
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
pin_memory=False
)
val_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=False
)
return train_loader, val_loader
def cycle(loader):
"""Бесконечный итератор по загрузчику данных."""
while True:
for data in loader:
yield data
def create_model():
"""Создание модели нейронной сети."""
try:
if USE_MEM_ATTENTION_MODEL:
neural_memory_model = MemoryAttention(dim=64)
else:
neural_memory_model = MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH)
model = MemoryAsContextTransformer(
num_tokens=len(tokenizer),
dim=384,
depth=8,
segment_len=WINDOW_SIZE,
num_persist_mem_tokens=NUM_PERSIST_MEM,
num_longterm_mem_tokens=NUM_LONGTERM_MEM,
neural_memory_layers=NEURAL_MEM_LAYERS,
neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN,
neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE,
neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT,
neural_mem_weight_residual=NEURAL_MEM_WEIGHT_RESIDUAL,
use_flex_attn=True,
sliding_window_attn=SLIDING_WINDOWS,
neural_memory_model=neural_memory_model,
neural_memory_kwargs=dict(
dim_head=64,
heads=4,
attn_pool_chunks=STORE_ATTN_POOL_CHUNKS,
qk_rmsnorm=NEURAL_MEM_QK_NORM,
momentum=NEURAL_MEM_MOMENTUM,
momentum_order=NEURAL_MEM_MOMENTUM_ORDER,
default_step_transform_max_lr=NEURAL_MEM_MAX_LR,
use_accelerated_scan=True,
per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR
)
).cuda()
assert next(model.parameters()).is_cuda, "Model is not on CUDA"
return model
except Exception as e:
print(f"Error creating model: {e}")
raise e
def train_model(model, train_loader, val_loader, num_batches=int(1e4)):
"""Обучение модели."""
optim = AdoptAtan2(model.parameters(), lr=2e-4)
torch.cuda.empty_cache()
pbar = tqdm(range(num_batches), desc='Training')
running_loss = 0.0
try:
for i in pbar:
model.train()
total_loss = 0
            for __ in range(GRADIENT_ACCUMULATE_EVERY):
batch = next(train_loader)
loss = model(batch, return_loss=True)
                loss = loss / GRADIENT_ACCUMULATE_EVERY
loss.backward()
total_loss += loss.item()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % 100 == 0:
torch.cuda.empty_cache()
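            # Smooth the displayed loss with an exponential moving average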
avg_loss = total_loss
running_loss = 0.9 * running_loss + 0.1 * avg_loss if i > 0 else avg_loss
pbar.set_postfix({
'loss': f'{running_loss:.4f}',
'batch_loss': f'{avg_loss:.4f}'
})
if i % 100 == 0:
model.eval()
with torch.no_grad():
val_batch = next(val_loader)
val_loss = model(val_batch, return_loss=True)
pbar.set_postfix({
'train_loss': f'{running_loss:.4f}',
'val_loss': f'{val_loss.item():.4f}'
})
if i % 1000 == 0 and i > 0:
torch.save({
'epoch': i,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optim.state_dict(),
'loss': running_loss,
}, f'checkpoint_{i}.pt')
except KeyboardInterrupt:
print("\nTraining interrupted by user")
except Exception as e:
print(f"\nTraining stopped due to error: {e}")
raise e
return model
def main():
"""Основная функция программы."""
try:
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available. This code requires GPU.")
print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
BATCH_SIZE = 4
SEQ_LEN = 512
CACHE_DIR = 'cache'
PROCESSED_DATA_DIR = 'processed_data'
NUM_BATCHES = 10000
preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR)
processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt'
if not processed_data_path.exists():
print("Processing Wikipedia dataset...")
preprocessor.process_and_save(max_articles=10000)
train_loader, val_loader = create_dataloaders(
processed_data_path,
batch_size=BATCH_SIZE,
seq_len=SEQ_LEN
)
train_loader = cycle(train_loader)
val_loader = cycle(val_loader)
model = create_model()
model = train_model(model, train_loader, val_loader, num_batches=NUM_BATCHES)
torch.save(model.state_dict(), 'final_model.pt')
return model, train_loader, val_loader
except Exception as e:
print(f"Error in main: {e}")
raise e
if __name__ == "__main__":
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.benchmark = True
model, train_loader, val_loader = main()
License
This project is licensed under the MIT License. See LICENSE file for details.
Citation
If you use this model in your research, please cite:
@software{neural_memory_model,
title = {Neural Memory Model for Russian Text Generation},
year = {2025},
url = {https://huggingface.co/Grpp/memory-transformer-ru}
}