from SmolLm3 import LlamaModel
import torch
import yaml
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import numpy as np
from datasets import load_dataset
import logging
import math

from utils import upload_file_to_s3


logger = logging.getLogger(__name__)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('training.log')
file_handler.setFormatter(formatter)  # Set formatter on the handler, not the logger
logger.addHandler(file_handler)
logger.setLevel(logging.INFO)

def encode_text(examples, tokenizer, seq_length):
    """Tokenize and prepare text examples for training."""
    tokens = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=seq_length + 1,
        return_tensors="pt",
    )
    # Copy-construct via clone().detach(), as PyTorch recommends over torch.tensor(tensor)
    input_ids = tokens["input_ids"].squeeze(0).clone().detach()
    # Guard against token ids outside the base vocabulary range
    input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
    labels = input_ids.clone().detach()
    labels = labels[1:].to(torch.int64)
    input_ids = input_ids[:-1].to(torch.int64)

    return {"input_ids": input_ids, "labels": labels}

def load_cosmopedia_dataset(batch_size=8, seq_length=1024, tokenizer=None):
    """
    Returns a torch dataloader for the cosmopedia dataset
    """
    # Disable tokenizers parallelism to avoid fork-related warnings and deadlocks with DataLoader workers
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    logger.info("tokenizer parallelism set to false")
    try:
        # Increase timeout and retries for dataset loading
        from datasets import config
        config.HF_DATASETS_TIMEOUT = 300  # 5 minutes timeout
        config.MAX_RETRIES = 10  # Increase retry attempts
        logger.info("dataset loading config set")
        train_dataset = load_dataset(
            "HuggingFaceTB/smollm-corpus",
            name="cosmopedia-v2",
            split="train",
            streaming=True,
        )
        logger.info("dataset loaded")

        # Use partial to bind tokenizer and seq_length to the encode function
        from functools import partial
        encode_fn = partial(encode_text, tokenizer=tokenizer, seq_length=seq_length)
        
        train_dataset = train_dataset.map(
            encode_fn, 
            remove_columns=["text"], 
            batched=False
        )
        train_dataset = train_dataset.with_format("torch")
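        # Note: a streamed dataset is an IterableDataset with no __len__ and no
        # automatic shuffling; with num_workers > 1, recent `datasets` releases
        # split the stream's shards across DataLoader workers.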
        
        train_dataloader = DataLoader(
            train_dataset, 
            batch_size=batch_size,
            num_workers=2,
            pin_memory=True,
            prefetch_factor=4,
            persistent_workers=True
        )
        return train_dataloader
    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        return None
    

def generate(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
    logger.info(f"Generating on device {device}")
    model = model.to(device)
    idx = idx.to(device)
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_length:]
        with torch.no_grad():
            logits, _ = model(idx_cond)  # Unpack both logits and loss (ignore loss)
            logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size'])  # Reshape to [batch, seq, vocab]
            
        # Get the logits for the last token only
        logits = logits[:, -1, :]  # Shape: [batch_size, vocab_size]
        
        if top_k is not None:
            # top k sampling
            top_logits, _ = torch.topk(logits, top_k)
            min_logit = top_logits[:, -1].unsqueeze(-1)
            logits = torch.where(logits < min_logit,
                               torch.tensor(float('-inf')).to(logits.device),
                               logits)
        
        # temperature scaling
        if temperature > 0.0:
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            
        # Stop early on the end-of-sequence token (assumes batch size 1, hence .item())
        if eos_token is not None and idx_next.item() == eos_token:
            break
            
        idx = torch.cat((idx, idx_next), dim=1)
    model.train()
    return idx

def sync_device(device):
    if device.startswith('cuda'):
        torch.cuda.synchronize()
    elif device == 'cpu':
        if hasattr(torch.cpu, 'synchronize'):
            torch.cpu.synchronize()
    elif device.startswith('mps'):  # For Apple Silicon
        torch.mps.synchronize()

def print_gpu_memory(step_name=""):
    """
    Print GPU memory statistics with a specified step name
    """
    if torch.cuda.is_available():
        logger.info(f"\nGPU Memory Stats {step_name}:")
        logger.info(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logger.info(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        logger.info(f"Max GPU Memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

# Learning rate scheduler
def get_lr_lambda(current_step, warmup_steps, max_steps, max_lr):
    """
    Learning rate schedule, returned as a multiplier of the optimizer's base LR
    (which should equal max_lr) for use with torch.optim.lr_scheduler.LambdaLR:
    1. Linear warmup from 0 to max_lr over `warmup_steps`
    2. Cosine decay from max_lr down to min_lr between `warmup_steps` and `max_steps`
    3. Minimum learning rate of 5% of max_lr
    """
    min_lr = max_lr * 0.05  # Minimum learning rate (5% of max_lr)

    if current_step < warmup_steps:
        # Linear warmup from 0 to max_lr (factor goes from 0 to 1)
        return float(current_step) / float(max(1, warmup_steps))
    else:
        # Cosine decay from max_lr to min_lr; LambdaLR multiplies the base LR
        # by this factor, so divide the target LR by max_lr.
        progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))
        progress = min(progress, 1.0)  # Hold at min_lr once max_steps is reached
        target_lr = min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
        return target_lr / max_lr
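# Sanity check of the factors returned above (with warmup_steps=3000, max_steps=60000
# and max_lr=3e-4, as used in train_model below):
#   step 0      -> factor 0.0   (LR = 0)
#   step 1500   -> factor 0.5   (LR = 1.5e-4)
#   step 3000   -> factor 1.0   (LR = max_lr = 3e-4)
#   step 60000+ -> factor 0.05  (LR = min_lr = 1.5e-5)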


def train_model(config, model, train_loader, test_loader, optimizer, device, num_epochs, eval_freq, eval_iter, start_context="Jack Gisburn rather a cheap genius- ", tokenizer=None):
    total_loss = 0
    tokens_seen, global_step = 0, -1
    
    # Gradient accumulation: accumulate gradients over several micro-batches so the
    # effective batch size is micro_batch_size * gradient_accumulation_steps
    actual_batch_size = config['tokens']['micro_batch_size']
    effective_batch_size_multiplier = 2  # Kept small to limit memory usage
    target_batch_size = effective_batch_size_multiplier * actual_batch_size
    gradient_accumulation_steps = target_batch_size // actual_batch_size
    
    # Learning rate schedule parameters
    max_lr = 3e-4
    warmup_steps = 3000  # Linear warmup steps
    max_steps = 60000  # Chosen to match roughly 10 hours of training
    min_lr = max_lr * 0.05  # Minimum learning rate (5% of max_lr)
    
    # Create LambdaLR scheduler with the improved lambda function
    lr_lambda = lambda step: get_lr_lambda(step, warmup_steps, max_steps, max_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
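    # LambdaLR multiplies the optimizer's base LR by the factor returned from
    # get_lr_lambda, so the optimizer should be constructed with lr == max_lr
    # (3e-4 in the __main__ block below).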
    
    logger.info(f"Training with learning rate schedule:")
    logger.info(f"Max LR: {max_lr}")
    logger.info(f"Warmup Steps: {warmup_steps}")
    logger.info(f"Max Steps: {max_steps}")
    logger.info(f"Min LR: {max_lr * 0.05}")
    logger.info(f"Gradient Accumulation Steps: {gradient_accumulation_steps}")
    logger.info(f"Effective Batch Size: {actual_batch_size * gradient_accumulation_steps}")
    
    print_gpu_memory("at start of training")
    
    # Free cached GPU memory and let cuDNN autotune kernels for the fixed input shapes
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()  # Zero gradients at start of epoch
        
        for batch_idx, batch in enumerate(train_loader):
            input_batch = batch['input_ids'].to(device)
            target_batch = batch['labels'].to(device)
            
            # Forward pass
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, original_loss = model(input_batch, target_batch)
            
            # Scale the loss so the accumulated gradient averages over the
            # effective batch instead of summing the micro-batch gradients
            scaled_loss = original_loss / gradient_accumulation_steps
            scaled_loss.backward()
            
            # Track the unscaled loss for logging (the scaling above is only for accumulation)
            total_loss += original_loss.item()
            tokens_seen += input_batch.numel()
            
            # Calculate running average loss
            total_batches = batch_idx + 1
            avg_loss = total_loss / total_batches
            if batch_idx % 25 == 0:
                logger.info(f"Batch {batch_idx + 1}, Running Avg Loss: {avg_loss:.5f}")
            # Only update weights after accumulating gradients
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
                scheduler.step()  # Update learning rate
                optimizer.zero_grad()
                global_step += 1
            
            # Evaluation / checkpoint block: run only right after an optimizer update,
            # otherwise the same global_step would trigger it on every micro-batch
            # in the accumulation window
            if (global_step > 0 and global_step % eval_freq == 0
                    and (batch_idx + 1) % gradient_accumulation_steps == 0):
                current_lr = scheduler.get_last_lr()[0]
                optimizer_lr = optimizer.param_groups[0]['lr']
                
                print_gpu_memory(f"at step {global_step}")
                logger.info(f"learning rate: {current_lr:.8f}")
                logger.info(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Avg loss {avg_loss:.3f} | {tokens_seen} tokens seen")
                logger.info(f"optimizer lr: {optimizer_lr:.8f}")
                logger.info(f"scheduler lr: {current_lr:.8f}")
                
                # Generate sample text
                encoded_text = tokenizer.encode(start_context, return_tensors="pt")
                random_topk = np.random.randint(1, 10)
                logger.info(f"random_topk: {random_topk}")
                random_temperature = np.random.uniform(0.7, 0.9)
                logger.info(f"random_temperature: {random_temperature}")
                logger.info(f"global step {global_step} , batch_idx {batch_idx} => generating text")
                generated_text = generate(model, 
                                       idx=encoded_text,
                                       max_new_tokens=256,
                                       context_length=256, 
                                       temperature=random_temperature, 
                                       top_k=random_topk, 
                                       eos_token=tokenizer.eos_token_id, 
                                       device=device)
                logger.info(f"+++"*30)
                logger.info(tokenizer.decode(generated_text.squeeze(0)))
                logger.info(f"+++"*30)
                
                # Save checkpoint
                model_file_name = f"model_{global_step}_steps_avg_loss_{avg_loss:.5f}_optimizer_lr_{optimizer_lr:.8f}.pth"
                torch.save({
                    'step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': avg_loss,
                }, model_file_name)
                
                s3_path = upload_file_to_s3(model_file_name, config['model']['model_config']['s3_bucket'], 
                                          config['model']['model_config']['s3_checkpoint_folder'])
                logger.info(f"Model saved to S3: {s3_path}")

                log_path = upload_file_to_s3(config['model']['model_config']['s3_log_file_name'], config['model']['model_config']['s3_bucket'], 
                                              config['model']['model_config']['s3_log_folder'])
                logger.info(f"Log saved to S3: {log_path}")
            
            if batch_idx % 100 == 0:
                logger.info(f"Batch {batch_idx} finished")
                logger.info(f"+++"*30)

    logger.info("Training complete")

if __name__ == "__main__":
    config = yaml.load(open("config_smollm2_135M.yaml", "r"), Loader=yaml.FullLoader)
    logger.info(config)
    
    # Memory-efficient settings. PYTORCH_CUDA_ALLOC_CONF only takes effect if it is
    # set before the CUDA caching allocator makes its first allocation.
    import os
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
    torch.set_float32_matmul_precision('high')
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Empty cache before model creation
    torch.cuda.empty_cache()
    
    model = LlamaModel(config['model'])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Gradient checkpointing could be enabled here for extra memory savings (left disabled)
    # model.gradient_checkpointing_enable()
    
    model.to(device)
    model = torch.compile(model)
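    # torch.compile adds compilation overhead on the first few iterations and
    # should speed up steady-state training steps afterwards.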
    logger.info(model)
    logger.info("++"*30)
    
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=3e-4, 
        weight_decay=0.15,
        betas=(0.9, 0.95)
    )
    
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
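    # Use EOS as the pad token in case the tokenizer defines none; encode_text
    # pads to max_length and therefore needs a valid pad token.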
    tokenizer.pad_token = tokenizer.eos_token
    vocab_size = tokenizer.vocab_size
    
    train_loader = load_cosmopedia_dataset(
        batch_size=16,
        seq_length=1024,
        tokenizer=tokenizer
    )
    
    import time
    t1 = time.time()
    
    train_model(
        config, 
        model, 
        train_loader,
        train_loader,  # No separate test set; test_loader is unused in train_model
        optimizer=optimizer,
        device=device,
        num_epochs=1,
        eval_freq=1000,  # Evaluate and checkpoint every 1000 optimizer steps
        eval_iter=1000,
        start_context="Once Upon a Time far far away in a galaxy", 
        tokenizer=tokenizer
    )
    t2 = time.time()
    logger.info(f"Time taken for training: {t2 - t1:.2f} seconds")