import os
import xml.etree.ElementTree as ET
import torch
import torch.nn as nn
from typing import List, Dict, Any, Optional
from collections import defaultdict
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset
from torch.cuda.amp import GradScaler, autocast

class DynamicModel(nn.Module):
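    """A model assembled at runtime from a configuration dictionary.

    ``sections`` maps a section name to a list of layer-parameter dicts
    (keys such as ``input_size``, ``output_size``, ``activation``,
    ``batch_norm``, ``dropout``). Each section becomes an ``nn.ModuleList``
    of sequential blocks built by ``create_layer``.
    """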
    def __init__(self, sections: Dict[str, List[Dict[str, Any]]]):
        super().__init__()
        self.sections = nn.ModuleDict()
        if not sections:
            sections = {
                'default': [{
                    'input_size': 128,
                    'output_size': 256,
                    'activation': 'relu',
                    'batch_norm': True,
                    'dropout': 0.1
                }]
            }
        for section_name, layers in sections.items():
            self.sections[section_name] = nn.ModuleList()
            for layer_params in layers:
                print(f"Creating layer in section '{section_name}' with params: {layer_params}")
                self.sections[section_name].append(self.create_layer(layer_params))

    def create_layer(self, layer_params: Dict[str, Any]) -> nn.Module:
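        """Build one sequential block from a layer-parameter dict.

        The block is Linear -> optional BatchNorm1d -> activation ->
        optional Dropout, optionally followed by nested ``hidden_layers``
        and the memory/attention layers, which are enabled by default.
        """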
        layers = []
        layers.append(nn.Linear(layer_params['input_size'], layer_params['output_size']))
        if layer_params.get('batch_norm', False):
            layers.append(nn.BatchNorm1d(layer_params['output_size']))
        activation = layer_params.get('activation', 'relu')
        if activation == 'relu':
            layers.append(nn.ReLU(inplace=True))
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        elif activation == 'sigmoid':
            layers.append(nn.Sigmoid())
        elif activation == 'leaky_relu':
            layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
        elif activation == 'elu':
            layers.append(nn.ELU(alpha=1.0, inplace=True))
        elif activation not in (None, 'none'):
            raise ValueError(f"Unsupported activation function: {activation}")
        if dropout_rate := layer_params.get('dropout', 0.0):
            layers.append(nn.Dropout(p=dropout_rate))
        if hidden_layers := layer_params.get('hidden_layers', []):
            for hidden_layer_params in hidden_layers:
                layers.append(self.create_layer(hidden_layer_params))
        if layer_params.get('memory_augmentation', True):
            layers.append(MemoryAugmentationLayer(layer_params['output_size']))
        if layer_params.get('hybrid_attention', True):
            layers.append(HybridAttentionLayer(layer_params['output_size']))
        if layer_params.get('dynamic_flash_attention', True):
            layers.append(DynamicFlashAttentionLayer(layer_params['output_size']))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, section_name: Optional[str] = None) -> torch.Tensor:
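        """Run the input through one named section, or through every section
        in insertion order when ``section_name`` is None (which requires
        adjacent sections to have compatible widths)."""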
        if section_name is not None:
            if section_name not in self.sections:
                raise KeyError(f"Section '{section_name}' not found in model")
            for layer in self.sections[section_name]:
                x = layer(x)
        else:
            for layers in self.sections.values():
                for layer in layers:
                    x = layer(x)
        return x

class MemoryAugmentationLayer(nn.Module):
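    """Adds a learned bias vector (broadcast over the batch) to its input."""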
    def __init__(self, size: int):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.memory

class HybridAttentionLayer(nn.Module):
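    """Multi-head self-attention applied to each sample as a length-1 sequence."""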
    def __init__(self, size: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(size, num_heads=8, batch_first=True)  # size must be divisible by num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)  # (batch, features) -> (batch, seq_len=1, features)
        attn_output, _ = self.attention(x, x, x)
        return attn_output.squeeze(1)

class DynamicFlashAttentionLayer(nn.Module):
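    """Same structure as HybridAttentionLayer: standard multi-head
    self-attention over a length-1 sequence. Recent PyTorch releases may
    dispatch nn.MultiheadAttention to fused (flash-style) kernels when the
    inputs allow it, but no flash-specific code is used here."""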
    def __init__(self, size: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(size, num_heads=8, batch_first=True)  # size must be divisible by num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)  # (batch, features) -> (batch, seq_len=1, features)
        attn_output, _ = self.attention(x, x, x)
        return attn_output.squeeze(1)

def parse_xml_file(file_path: str) -> List[Dict[str, Any]]:
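    """Extract layer-parameter dicts from every <layer> element in an XML file.

    Expected attributes (all optional, with defaults): input_size,
    output_size, activation. For example (element names around <layer> are
    arbitrary, since the search uses the './/layer' XPath):

        <model>
            <layer input_size="128" output_size="256" activation="relu"/>
        </model>
    """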
    tree = ET.parse(file_path)
    root = tree.getroot()
    layers = []
    for layer in root.findall('.//layer'):
        layer_params = {}
        layer_params['input_size'] = int(layer.get('input_size', 128))
        layer_params['output_size'] = int(layer.get('output_size', 256))
        layer_params['activation'] = layer.get('activation', 'relu').lower()
        if layer_params['activation'] not in ['relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none']:
            raise ValueError(f"Unsupported activation function: {layer_params['activation']}")
        if layer_params['input_size'] <= 0 or layer_params['output_size'] <= 0:
            raise ValueError("Layer dimensions must be positive integers")
        layers.append(layer_params)
    if not layers:
        layers.append({
            'input_size': 128,
            'output_size': 256,
            'activation': 'relu'
        })
    return layers

def create_model_from_folder(folder_path: str) -> DynamicModel:
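    """Walk ``folder_path`` for .xml files and build a DynamicModel.

    Each directory containing XML files becomes one section, named after the
    directory (with '.' replaced by '_'); layers from all XML files in that
    directory are concatenated. Falls back to the default configuration when
    the folder or XML files are missing.
    """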
    sections = defaultdict(list)
    if not os.path.exists(folder_path):
        print(f"Warning: Folder {folder_path} does not exist. Creating model with default configuration.")
        return DynamicModel({})
    xml_files_found = False
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith('.xml'):
                xml_files_found = True
                file_path = os.path.join(root, file)
                try:
                    layers = parse_xml_file(file_path)
                    section_name = os.path.basename(root).replace('.', '_')
                    sections[section_name].extend(layers)
                except Exception as e:
                    print(f"Error processing {file_path}: {str(e)}")
    if not xml_files_found:
        print("Warning: No XML files found. Creating model with default configuration.")
        return DynamicModel({})
    return DynamicModel(dict(sections))

def main():
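    """Build a model from the ./data folder, run a single-sample smoke test,
    then train for a few epochs on randomly generated data (binary labels
    scored against the model's final output width as logits) using
    Accelerate and mixed precision."""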
    folder_path = 'data'
    model = create_model_from_folder(folder_path)
    print(f"Created dynamic PyTorch model with sections: {list(model.sections.keys())}")
    # Print the model architecture
    print(model)
    first_section = next(iter(model.sections.keys()))
    first_layer = model.sections[first_section][0]
    input_features = first_layer[0].in_features  # width expected by the first Linear layer
    # Quick smoke test: eval mode and no_grad avoid BatchNorm1d errors on a batch of one
    model.eval()
    with torch.no_grad():
        sample_input = torch.randn(1, input_features)
        output = model(sample_input)
    print(f"Sample output shape: {output.shape}")

    accelerator = Accelerator()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    num_epochs = 10
    dataset = TensorDataset(
        torch.randn(100, input_features),
        torch.randint(0, 2, (100,))
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True
    )

    model, optimizer, train_dataloader = accelerator.prepare(
        model,
        optimizer,
        train_dataloader
    )

    scaler = GradScaler(enabled=torch.cuda.is_available())  # gradient scaling only has an effect on CUDA
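    # Note: Accelerate can also manage mixed precision on its own
    # (e.g. Accelerator(mixed_precision="fp16") together with accelerator.backward(loss));
    # here scaling is handled manually with GradScaler instead.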

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_idx, (inputs, labels) in enumerate(train_dataloader):
            optimizer.zero_grad()
            with autocast(enabled=torch.cuda.is_available()):  # mixed precision on CUDA, full precision elsewhere
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

if __name__ == "__main__":
    main()