import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import argparse

# 1. Dataset Class
class DepressionDataset(Dataset):
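    """Torch Dataset over a DataFrame with 'clean_text' and 'is_depression' columns.

    Each item is tokenized for RoBERTa and returned as padded/truncated tensors
    plus the integer label.
    """
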
    def __init__(self, df, tokenizer, max_length=256):
        self.texts = df['clean_text'].values
        self.labels = df['is_depression'].values
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# 2. Model Class
class DepressionClassifier(nn.Module):
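    """RoBERTa encoder with dropout and a linear head for binary classification."""
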
    def __init__(self, dropout_rate=0.1):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.roberta.config.hidden_size, 2)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Use the hidden state of the first token (<s>, RoBERTa's equivalent of [CLS]) as the sequence representation
        sequence_output = outputs.last_hidden_state[:, 0, :]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        
        return logits

# 3. Prepare data loaders
def prepare_dataloaders(df, batch_size=16):
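    """Split the data into train/val/test sets and build the corresponding DataLoaders."""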
    # Split data: 70% train, 15% validation, 15% test, stratified by label
    train_df, temp_df = train_test_split(df, test_size=0.3, stratify=df['is_depression'], random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['is_depression'], random_state=42)
    
    # Initialize tokenizer
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    
    # Create datasets
    train_dataset = DepressionDataset(train_df, tokenizer)
    val_dataset = DepressionDataset(val_df, tokenizer)
    test_dataset = DepressionDataset(test_df, tokenizer)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, val_loader, test_loader

# 4. Training function
def train_model(model, train_loader, val_loader, device, epochs=3, learning_rate=2e-5):
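    """Fine-tune the classifier, checkpointing the weights with the best validation accuracy."""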
    # Move model to device
    model = model.to(device)
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    # Initialize loss function
    loss_fn = nn.CrossEntropyLoss()
    
    # Training loop
    best_accuracy = 0
    
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')
        
        # TRAINING
        model.train()
        train_loss = 0
        train_preds = []
        train_labels = []
        
        # Progress bar for training
        progress_bar = tqdm(train_loader, desc="Training")
        
        for batch in progress_bar:
            # Get batch data
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            # Track metrics
            train_loss += loss.item()
            _, preds = torch.max(outputs, dim=1)
            train_preds.extend(preds.cpu().tolist())
            train_labels.extend(labels.cpu().tolist())
            
            # Update progress bar
            progress_bar.set_postfix({'loss': loss.item()})
        
        # Calculate training metrics
        avg_train_loss = train_loss / len(train_loader)
        train_accuracy = accuracy_score(train_labels, train_preds)
        
        # VALIDATION
        model.eval()
        val_loss = 0
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                # Get batch data
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                
                # Forward pass
                outputs = model(input_ids, attention_mask)
                loss = loss_fn(outputs, labels)
                
                # Track metrics
                val_loss += loss.item()
                _, preds = torch.max(outputs, dim=1)
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(labels.cpu().tolist())
        
        # Calculate validation metrics
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = accuracy_score(val_labels, val_preds)
        
        # Print metrics
        print(f'Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}')
        print(f'Val Loss: {avg_val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}')
        
        # Save best model
        if val_accuracy > best_accuracy:
            torch.save(model.state_dict(), 'best_model.pt')
            best_accuracy = val_accuracy
            print(f'New best model saved with accuracy: {val_accuracy:.4f}')
        
        print('-' * 50)
    
    # Load best model
    model.load_state_dict(torch.load('best_model.pt', map_location=device))
    return model

# 5. Evaluation function
def evaluate_model(model, test_loader, device):
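    """Run the model on the test set and return accuracy, precision, recall, and F1."""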
    model.eval()
    test_preds = []
    test_labels = []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask)
            _, preds = torch.max(outputs, dim=1)
            
            test_preds.extend(preds.cpu().tolist())
            test_labels.extend(labels.cpu().tolist())
    
    # Calculate metrics
    accuracy = accuracy_score(test_labels, test_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_preds, average='binary'
    )
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

# 6. Main function
def main():
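    """Parse arguments, load the dataset, then train and evaluate the classifier."""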
    parser = argparse.ArgumentParser(description='Train depression classifier')
    parser.add_argument('--data_path', type=str, default='depression_dataset_reddit_cleaned_final.csv', 
                        help='Path to the cleaned dataset')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=3, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    args = parser.parse_args()
    
    # Check for GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    
    # Load data
    df = pd.read_csv(args.data_path)
    print(f'Loaded dataset with {len(df)} examples')
    
    # Prepare data
    train_loader, val_loader, test_loader = prepare_dataloaders(
        df, batch_size=args.batch_size
    )
    print(f'Training samples: {len(train_loader.dataset)}')
    print(f'Validation samples: {len(val_loader.dataset)}')
    print(f'Testing samples: {len(test_loader.dataset)}')
    
    # Create model
    model = DepressionClassifier()
    print('Model created')
    
    # Train model
    print('Starting training...')
    trained_model = train_model(
        model, 
        train_loader, 
        val_loader, 
        device,
        epochs=args.epochs,
        learning_rate=args.learning_rate
    )
    
    # Evaluate model
    print('Evaluating model...')
    metrics = evaluate_model(trained_model, test_loader, device)
    
    # Print results
    print('\nTest Results:')
    for metric, value in metrics.items():
        print(f'{metric}: {value:.4f}')

if __name__ == '__main__':
    main()