# process_hf_dataset.py
from datasets import load_dataset
import re
from parser import parse_python_code, create_vector
from database import init_chromadb, store_program, DB_NAME, HF_DATASET_NAME, create_collection
import chromadb
import os
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm  # For progress bar
import time
import logging

# Configure root logging once for the script and obtain this module's logger.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Populate os.environ from a local .env file (e.g. HF_KEY used for the Hub push).
load_dotenv()

# Lazily filled CodeBERT cache: load_codebert_model() sets these on first use so
# the model is loaded at most once per process instead of on every call.
model_name = "microsoft/codebert-base"
tokenizer = model = device = None

def load_codebert_model(use_gpu=False):
    """Return the cached ``(tokenizer, model, device)`` triple, loading CodeBERT on first use.

    On the first call the device is CUDA when ``use_gpu`` is true and
    ``torch.cuda.is_available()``; otherwise CPU. Once both the tokenizer and
    the model are cached, later calls return them immediately and ``use_gpu``
    has no effect.

    Raises:
        Whatever ``from_pretrained`` / ``.to(device)`` raise; the error is logged
        before being re-raised.
    """
    global tokenizer, model, device
    if tokenizer is not None and model is not None:
        # Already loaded: reuse the module-level cache.
        return tokenizer, model, device

    try:
        wants_cuda = use_gpu and torch.cuda.is_available()
        device = torch.device("cuda" if wants_cuda else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(device)
        logger.info(f"CodeBERT model loaded on {device}")
    except Exception as exc:
        logger.error(f"Error loading CodeBERT model: {exc}")
        raise
    return tokenizer, model, device

def rename_variables(code, variable_prefixes=None):
    """Rename variables bound by *code* to role-based placeholder names.

    Every name the code itself binds is renamed to ``<role><n>``. Bound names
    are function/lambda parameters, assignment targets, ``for``/``with``/walrus/
    comprehension targets and ``except ... as`` names. The role is chosen by this
    priority:

    * ``input``    - the name is a function or lambda parameter;
    * ``returned`` - the name appears inside a ``return`` expression;
    * ``assigned`` - anything else.

    Numbering is per role and follows first appearance in the source.

    Some names and text are left untouched so that valid input stays valid
    Python:

    * keywords, and names that are only read and never bound (builtins, globals);
    * imported names;
    * function and class names;
    * attribute names (``obj.attr``);
    * keyword-argument names in calls (``f(key=...)``);
    * text inside strings and comments.

    Args:
        code: Python source text.
        variable_prefixes: Optional mapping overriding role prefixes. Keys are
            ``'input'``, ``'assigned'`` and ``'returned'``; the defaults are
            ``'input_variable'``, ``'assigned_variable'`` and
            ``'returned_variable'``.

    Returns:
        ``(new_code, var_map)``. ``var_map`` maps each original name to its new
        name, in first-appearance order. Source that cannot be parsed or
        tokenized is returned unchanged with an empty map.
    """
    import ast
    import io
    import tokenize

    prefixes = {
        'input': 'input_variable',
        'assigned': 'assigned_variable',
        'returned': 'returned_variable',
    }
    if variable_prefixes is not None:
        prefixes.update(variable_prefixes)

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):  # ValueError: e.g. null bytes on older Pythons
        return code, {}

    # Classify names via the AST instead of substring tests on raw lines.
    params, stored, returned, imported = set(), set(), set(), set()
    # (lineno, utf-8 byte col) of `name=` keyword arguments in calls / class bases.
    keyword_arg_positions = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.arg):
            params.add(node.arg)
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
            stored.add(node.id)
        elif isinstance(node, ast.ExceptHandler) and node.name:
            stored.add(node.name)
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            imported.update(alias.asname or alias.name.split('.')[0] for alias in node.names)
        elif isinstance(node, ast.Return) and node.value is not None:
            returned.update(n.id for n in ast.walk(node.value) if isinstance(n, ast.Name))
        elif isinstance(node, ast.keyword) and node.arg is not None and hasattr(node, 'lineno'):
            # ast.keyword carries positions only on Python 3.9+.
            keyword_arg_positions.add((node.lineno, node.col_offset))

    # Renaming an imported name would break its import statement, so skip those.
    bound = (params | stored) - imported

    def role_for(name):
        if name in params:
            return prefixes['input']
        if name in returned:
            return prefixes['returned']
        return prefixes['assigned']

    # Both tokenize (via StringIO.readline) and this split break lines on '\n' only,
    # so token rows index directly into `lines`.
    lines = code.split('\n')
    var_map = {}
    counters = {}
    edits = []  # (row, start_col, end_col, replacement) in source order
    prev = None
    try:
        for tok in tokenize.generate_tokens(io.StringIO(code).readline):
            if tok.type in (tokenize.NL, tokenize.COMMENT):
                continue
            is_attribute = prev is not None and prev.type == tokenize.OP and prev.string == '.'
            prev = tok
            if tok.type != tokenize.NAME or tok.string not in bound or is_attribute:
                continue
            row, col = tok.start
            # ast col_offset is a UTF-8 byte offset; tokenize columns are characters.
            byte_col = len(lines[row - 1][:col].encode('utf-8'))
            if (row, byte_col) in keyword_arg_positions:
                continue
            new_name = var_map.get(tok.string)
            if new_name is None:
                role = role_for(tok.string)
                counters[role] = counters.get(role, 0) + 1
                new_name = var_map[tok.string] = f"{role}{counters[role]}"
            edits.append((row, col, tok.end[1], new_name))
    except (tokenize.TokenError, SyntaxError):
        return code, {}

    # Apply right-to-left so earlier column offsets on the same line stay valid.
    # NOTE: on Python < 3.12 names inside f-string expressions are part of a
    # single STRING token and are therefore not renamed.
    for row, start, end, new_name in reversed(edits):
        line = lines[row - 1]
        lines[row - 1] = line[:start] + new_name + line[end:]
    return '\n'.join(lines), var_map

def generate_description_tokens(sequence, vectors, var_map=None):
    """Build semantic description tokens for a parsed program.

    Categories in *sequence* that have a description produce three tokens,
    using the paired vector from *vectors*:

    * ``"<description>:<category>"``
    * ``"level:<vec[1]>"``
    * ``"span:<vec[3]:.2f>"``

    Unknown categories are skipped. Pairs are formed with ``zip``, so extra
    items in the longer argument are ignored. Each vector must be indexable at
    positions 1 and 3, and ``vec[3]`` must be numeric.

    Args:
        sequence: Category names, one per program part.
        vectors: Per-part vectors aligned with *sequence*.
        var_map: Optional ``{original_name: new_name}`` mapping, as returned by
            ``rename_variables``. Each entry adds
            ``"variable:<old>=<new>:<role>"``, where the role is the new name
            with its trailing numeric counter removed.

    Returns:
        List of token strings.
    """
    category_descriptions = {
        'import': 'imports module',
        'function': 'defines function',
        'assigned_variable': 'assigns variable',
        'input_variable': 'input parameter',
        'returned_variable': 'returns value',
        'if': 'conditional statement',
        'return': 'returns result',
        'try': 'try block',
        'except': 'exception handler',
        'expression': 'expression statement',
        'spacer': 'empty line or comment'
    }

    tokens = []
    for cat, vec in zip(sequence, vectors):
        description = category_descriptions.get(cat)
        if description is None:
            continue
        tokens.append(f"{description}:{cat}")
        # Vector-derived structural features.
        tokens.append(f"level:{vec[1]}")
        tokens.append(f"span:{vec[3]:.2f}")

    if var_map:
        for old_var, new_var in var_map.items():
            # Drop the per-role counter ("input_variable3" -> "input_variable"); unlike
            # splitting on the literal "variable", this also works for custom prefixes.
            role = new_var.rstrip('0123456789')
            tokens.append(f"variable:{old_var}={new_var}:{role}")

    return tokens

def generate_semantic_vector(description, total_lines=100, use_gpu=False):
    """Embed *description* with CodeBERT and reduce it to a 6-element list.

    The embedding is the mean of CodeBERT's last hidden states over the token
    axis. The "projection" to 6D keeps the first six components, zero-padding
    when there are fewer. If every component is zero, a keyword heuristic builds
    a placeholder vector instead.

    In the heuristic, the last whitespace-separated word that contains any
    category keyword decides the result; within a word, the first keyword in
    dict order wins.

    Args:
        description: Text to embed.
        total_lines: Currently unused; kept for interface compatibility.
        use_gpu: Forwarded to load_codebert_model when the model is not yet cached.

    Returns:
        A list of six numbers.
    """
    global tokenizer, model, device
    if model is None or tokenizer is None:
        tokenizer, model, device = load_codebert_model(use_gpu)

    encoded = tokenizer(description, return_tensors="pt", padding=True, truncation=True, max_length=512)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
        # Mean-pool over tokens, then drop the batch dimension.
        vector = hidden.mean(dim=1).squeeze().cpu().numpy().tolist()

    # Keep the first six components, zero-padding short vectors.
    vector = (vector + [0] * 6)[:6]

    if not any(component != 0 for component in vector):
        logger.warning(f"Default vector detected for description: {description}")
        # Fallback heuristic used when the embedding carries no signal.
        category_map = {
            'import': 1, 'function': 2, 'assign': 17, 'input': 18, 'return': 19, 'if': 5, 'try': 8, 'except': 14
        }
        last_hit = None
        for word in description.lower().split():
            first_match = next((cid for cat, cid in category_map.items() if cat in word), None)
            if first_match is not None:
                last_hit = first_match
        if last_hit is None:
            vector = [0] * 6
        else:
            # [category_id, level, center_pos, span, parent_depth, parent_weight]
            vector = [last_hit, 1, 0.5, 0.1, 1, last_hit / len(category_map)]

    logger.debug(f"Generated semantic vector for '{description}': {vector}")
    return vector

def _build_record(entry, use_gpu=False):
    """Convert one dataset row into ``(program_id, document, metadata, embedding)`` for ChromaDB.

    Args:
        entry: A dataset row with ``'instruction'`` (text) and ``'output'`` (code) fields.
        use_gpu: Forwarded to generate_semantic_vector.

    Returns:
        ``(program_id, processed_code, metadata, semantic_vector)``. The ID is a
        SHA-256 hex digest of the renamed code. It is stable across runs,
        unlike the builtin ``hash()`` of a str, which is randomized per process.
    """
    import hashlib

    instruction = entry['instruction']
    # Rename variables to align with vector categories.
    processed_code, var_map = rename_variables(entry['output'])

    # Parser-produced 6D vectors describe the program structure.
    parts, sequence = parse_python_code(processed_code)
    program_vectors = [part['vector'] for part in parts]

    description_tokens = f"task:{instruction.replace(' ', '_')}"
    description_tokens += " " + " ".join(generate_description_tokens(sequence, program_vectors, var_map))

    # The 6D semantic vector of the instruction is used as the ChromaDB embedding.
    semantic_vector = generate_semantic_vector(instruction, use_gpu=use_gpu)

    program_id = hashlib.sha256(processed_code.encode('utf-8', 'surrogatepass')).hexdigest()
    metadata = {
        "sequence": ",".join(sequence),
        "description_tokens": description_tokens,
        "program_vectors": str(program_vectors),
    }
    return program_id, processed_code, metadata, semantic_vector


def process_hf_dataset(batch_size=100, use_gpu=False):
    """Process the Hugging Face dataset in batches, store programs in ChromaDB, then push to the Hub.

    Each row goes through _build_record. Rows that raise are logged with their
    dataset index and skipped. Duplicate programs within a batch are added once,
    because ChromaDB rejects duplicate IDs in a single ``add`` call. Batches with
    no valid rows are skipped. Finishes by calling save_chromadb_to_hf().

    Args:
        batch_size: Number of dataset rows per ChromaDB ``add`` call; must be >= 1.
        use_gpu: Run CodeBERT on CUDA when available.

    Raises:
        ValueError: If ``batch_size`` < 1 or the ChromaDB collection is unusable.
        Errors from dataset loading, ChromaDB access/insertion and the Hub push
        are logged and re-raised.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    try:
        dataset = load_dataset("iamtarun/python_code_instructions_18k_alpaca", split="train")
        dataset_list = list(dataset)
        logger.info(f"Loaded dataset with {len(dataset_list)} entries")
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        raise

    client = init_chromadb()

    # Do not clear or populate with defaults here—let UI buttons handle this
    try:
        collection = client.get_or_create_collection(DB_NAME)
        # Validate before using the collection (count() would fail on a bad handle).
        if collection is None or not hasattr(collection, 'add'):
            raise ValueError("ChromaDB collection access failed")
        logger.info(f"Using existing or new ChromaDB collection: {DB_NAME}, contains {collection.count()} entries")
        logger.info("Verified ChromaDB collection is valid")
    except Exception as e:
        logger.error(f"Error accessing ChromaDB collection: {e}")
        raise

    total_entries = len(dataset_list)
    for start in tqdm(range(0, total_entries, batch_size), desc="Processing Hugging Face Dataset"):
        batch_number = start // batch_size + 1
        batch_ids, batch_documents, batch_metadatas, batch_embeddings = [], [], [], []
        seen_ids = set()

        for offset, entry in enumerate(dataset_list[start:start + batch_size]):
            index = start + offset
            try:
                program_id, document, metadata, embedding = _build_record(entry, use_gpu=use_gpu)
            except Exception as e:
                logger.error(f"Error processing entry {index}: {e}")
                continue  # Skip failed entries but continue processing
            if program_id in seen_ids:
                logger.debug(f"Skipping duplicate program in batch {batch_number}: {program_id}")
                continue
            seen_ids.add(program_id)
            batch_ids.append(program_id)
            batch_documents.append(document)
            batch_metadatas.append(metadata)
            batch_embeddings.append(embedding)
            logger.debug(f"Processed entry: {program_id}, Vector: {embedding}")

        if not batch_ids:
            # ChromaDB rejects an add() with no IDs; nothing to store for this batch.
            logger.warning(f"Batch {batch_number} produced no valid entries; skipping ChromaDB add")
            continue

        try:
            collection.add(
                documents=batch_documents,
                metadatas=batch_metadatas,
                ids=batch_ids,
                embeddings=batch_embeddings
            )
            logger.info(f"Added batch {batch_number} to ChromaDB with {len(batch_ids)} entries")
            count = collection.count()
            logger.info(f"ChromaDB now contains {count} entries after adding batch")
        except Exception as e:
            logger.error(f"Error adding batch to ChromaDB: {e}")
            raise

    save_chromadb_to_hf()

def _parse_program_vectors(text):
    """Turn the stored ``str(list)`` form of program vectors back into a list.

    Uses ast.literal_eval rather than eval, so metadata read from the database
    cannot execute code. Values that are not a valid literal are logged and
    replaced by an empty list.
    """
    import ast

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        logger.warning(f"Could not parse program_vectors metadata ({e}); using []")
        return []


def save_chromadb_to_hf(dataset_name=HF_DATASET_NAME, token=None):
    """Export every program in the ChromaDB collection to a Hugging Face Hub dataset.

    Dataset columns:

    * ``code`` - the stored documents;
    * ``sequence`` and ``description_tokens`` - read from metadata;
    * ``vectors`` - the stored 6D semantic embeddings, as plain float lists;
    * ``program_vectors`` - structural vectors parsed back from their string form.

    An empty collection is not pushed, so an existing Hub dataset is never
    overwritten with nothing.

    Args:
        dataset_name: Hub repository id to push to.
        token: Hugging Face access token. When None, the ``HF_KEY`` environment
            variable is read at call time rather than at import time.

    Raises:
        Any error from ChromaDB, dataset construction or the Hub push; it is
        logged before being re-raised.
    """
    from datasets import Dataset  # only load_dataset is imported at module level

    if token is None:
        token = os.getenv("HF_KEY")
    try:
        client = init_chromadb()
        collection = client.get_collection(DB_NAME)

        results = collection.get(include=["documents", "metadatas", "embeddings"])
        documents = list(results["documents"] or [])
        # Entries stored without metadata come back as None.
        metadatas = [meta or {} for meta in (results["metadatas"] or [])]
        embeddings = results["embeddings"]
        if embeddings is None:  # may be a numpy array, so no truthiness test
            embeddings = []

        if not documents:
            logger.warning(f"ChromaDB collection {DB_NAME} is empty; not pushing to {dataset_name}")
            return

        data = {
            "code": documents,
            "sequence": [meta.get("sequence", "") for meta in metadatas],
            "vectors": [[float(x) for x in emb] for emb in embeddings],  # Semantic 6D vectors
            "description_tokens": [meta.get('description_tokens', '') for meta in metadatas],
            "program_vectors": [_parse_program_vectors(meta.get('program_vectors', '[]')) for meta in metadatas],
        }

        dataset = Dataset.from_dict(data)
        logger.info(f"Created Hugging Face Dataset with {len(data['code'])} entries")

        # push_to_hub replaces the split's data on the Hub; it takes no exist_ok argument.
        dataset.push_to_hub(dataset_name, token=token)
        logger.info(f"Dataset pushed to Hugging Face Hub as {dataset_name}, overwriting existing dataset")
        logger.info(f"Verified Hugging Face dataset push with {len(dataset)} entries")
    except Exception as e:
        logger.error(f"Error pushing dataset to Hugging Face Hub: {e}")
        raise

if __name__ == "__main__":
    # Script entry point: run the full ingestion with CPU-only CodeBERT embeddings.
    process_hf_dataset(use_gpu=False, batch_size=100)