from flask import Flask, render_template, request, jsonify
from flask_socketio import SocketIO
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import shutil
import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

class Predictor:
    """Wrap a SAM2 image predictor and track whether an image has been loaded."""

    def __init__(self, model_cfg, checkpoint, device):
        """Build the SAM2 model from ``model_cfg``/``checkpoint`` on ``device``."""
        self.device = device
        sam_model = build_sam2(model_cfg, checkpoint, device=device)
        self.model = sam_model
        self.predictor = SAM2ImagePredictor(sam_model)
        # No image is available until set_image() has succeeded.
        self.image_set = False

    def set_image(self, image):
        """Store ``image`` and pass it to the underlying predictor."""
        self.image = image
        self.predictor.set_image(image)
        # Only flagged once the underlying predictor accepted the image.
        self.image_set = True

    def predict(self, point_coords, point_labels, multimask_output=False):
        """Return the underlying predictor's output for the given point prompts.

        Raises:
            RuntimeError: if no image has been set via ``set_image``.
        """
        if self.image_set:
            return self.predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=multimask_output,
            )
        raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
from utils.helpers import (
    blend_mask_with_image,
    save_mask_as_png,
    convert_mask_to_yolo,
)
import torch
from ultralytics import YOLO
import threading
from threading import Lock
import subprocess
import time
import logging
import multiprocessing
import json


# Flask application and the Socket.IO server bound to it
app = Flask(__name__)
socketio = SocketIO(app)

# Directory containing this file; every folder below is resolved against it
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# Working folders for uploads, masks and blended/segmented previews
UPLOAD_FOLDERS = {
    key: os.path.join(BASE_DIR, relative)
    for key, relative in (
        ('input', 'static/uploads/input'),
        ('segmented_voids', 'static/uploads/segmented/voids'),
        ('segmented_chips', 'static/uploads/segmented/chips'),
        ('mask_voids', 'static/uploads/mask/voids'),
        ('mask_chips', 'static/uploads/mask/chips'),
        ('automatic_segmented', 'static/uploads/segmented/automatic'),
    )
}

# Saved images and their per-class masks
HISTORY_FOLDERS = {
    key: os.path.join(BASE_DIR, relative)
    for key, relative in (
        ('images', 'static/history/images'),
        ('masks_chip', 'static/history/masks/chip'),
        ('masks_void', 'static/history/masks/void'),
    )
}

# YOLO dataset splits, temporary backup area and model storage
DATASET_FOLDERS = {
    key: os.path.join(BASE_DIR, relative)
    for key, relative in (
        ('train_images', 'dataset/train/images'),
        ('train_labels', 'dataset/train/labels'),
        ('val_images', 'dataset/val/images'),
        ('val_labels', 'dataset/val/labels'),
        ('temp_backup', 'temp_backup'),
        ('models', 'models'),
        ('models_old', 'models/old'),
    )
}

# Create every configured folder up front.
# NOTE(review): this logging.info call runs before the module configures
# logging, so it goes through the root logger's implicit default setup.
for folder_name, folder_path in {**UPLOAD_FOLDERS, **HISTORY_FOLDERS, **DATASET_FOLDERS}.items():
    os.makedirs(folder_path, exist_ok=True)
    logging.info(f"Ensured folder exists: {folder_name} -> {folder_path}")

# Handle of the running YOLO training subprocess (None when idle)
training_process = None


def initialize_training_status():
    """Reset the module-level ``training_status`` dict to its idle defaults."""
    global training_status
    training_status = dict(running=False, cancelled=False)

def persist_training_status():
    """Write the current ``training_status`` as JSON to ``training_status.json`` in BASE_DIR."""
    status_path = os.path.join(BASE_DIR, 'training_status.json')
    with open(status_path, 'w') as status_file:
        status_file.write(json.dumps(training_status))

def load_training_status():
    """Load the global training status from ``training_status.json`` in BASE_DIR.

    Falls back to the idle defaults when the file is missing, unreadable,
    not valid JSON, or does not contain a JSON object. Keys missing from the
    file are filled in from the defaults.

    A persisted ``running=True`` is reset to ``False`` when this module holds
    no training process handle: nothing in this process could be running the
    training, and leaving the flag set would make every retrain request be
    rejected as "already in progress".
    """
    global training_status
    defaults = {'running': False, 'cancelled': False}
    status_path = os.path.join(BASE_DIR, 'training_status.json')

    loaded = {}
    try:
        with open(status_path, 'r') as status_file:
            loaded = json.load(status_file)
    except FileNotFoundError:
        # First start: no status has been persisted yet.
        pass
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError is a ValueError subclass.
        logging.warning(f"Ignoring unreadable training status file {status_path}: {exc}")

    if not isinstance(loaded, dict):
        logging.warning(f"Ignoring training status file {status_path}: expected a JSON object.")
        loaded = {}

    status = {**defaults, **loaded}
    if status.get('running') and training_process is None:
        logging.warning("Persisted status says training is running but no process exists; resetting.")
        status['running'] = False

    training_status = status

# Configure logging before anything else in this section runs.
# force=True is required: a logging.info() call earlier in the module has
# already given the root logger a default handler, and without force
# basicConfig() would silently do nothing (no INFO level, no app.log file).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(BASE_DIR, "app.log"))  # Log to a file
    ],
    force=True,
)

# Restore the persisted training status (or fall back to defaults)
load_training_status()

os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"

# Initialize SAM Predictor
MODEL_CFG = os.path.join(BASE_DIR, "sam2", "sam2_hiera_l.yaml")
CHECKPOINT = os.path.join(BASE_DIR, "sam2", "sam2.1_hiera_large.pt")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The device is resolved first so the startup message can actually name it.
print(f"Chargement de {MODEL_CFG} et {CHECKPOINT} sur l'appareil {DEVICE}.")
predictor = Predictor(MODEL_CFG, CHECKPOINT, DEVICE)

# Initialize YOLO-seg from the currently promoted weights
YOLO_CFG = os.path.join(DATASET_FOLDERS['models'], "best.pt")
yolo_model = YOLO(YOLO_CFG)


@app.route('/')
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = render_template('index.html')
    return page

@app.route('/upload', methods=['POST'])
def upload_image():
    """Store an uploaded image and load it into the SAM predictor.

    Expects a multipart form with a ``file`` part.

    Returns:
        JSON ``{'image_url': ...}`` with the web path of the stored image, or
        a JSON ``{'error': ...}`` with status 400 when no file was sent, the
        file name is unusable, or the file cannot be decoded as an image.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400

    # The file name comes from the client: keep only its last path component
    # so it cannot escape the upload folder (e.g. "../../app.py").
    safe_name = os.path.basename(file.filename.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        return jsonify({'error': 'Invalid file name'}), 400

    # Save the uploaded file to the input folder
    input_path = os.path.join(UPLOAD_FOLDERS['input'], safe_name)
    file.save(input_path)

    # Decode the image; the context manager closes the file handle promptly.
    try:
        with Image.open(input_path) as uploaded:
            image = np.array(uploaded.convert("RGB"))
    except OSError as exc:  # includes PIL's "cannot identify image file" error
        logging.error(f"Uploaded file is not a readable image: {input_path} ({exc})")
        return jsonify({'error': 'Uploaded file is not a readable image'}), 400

    # Set the uploaded image in the predictor
    predictor.set_image(image)

    # Return a web-accessible URL instead of the file system path
    web_accessible_url = f"/static/uploads/input/{safe_name}"
    print(f"Image uploaded and set for prediction: {input_path}")
    return jsonify({'image_url': web_accessible_url})

@app.route('/segment', methods=['POST'])
def segment():
    """Run point-prompted SAM segmentation on the current image.

    Expects a JSON object with ``points`` (list of ``[x, y]`` pairs),
    ``labels`` (one label per point) and an optional ``class``
    (defaults to ``'voids'``).

    Returns:
        JSON ``{'segmented_url': ...}`` pointing at the blended preview.
        Invalid input (bad body, unknown class, malformed prompts, no image
        loaded) yields status 400; any other failure yields status 500.
    """
    try:
        # A missing or non-JSON body must be a client error, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object.")

        points = np.array(data.get('points', []))
        labels = np.array(data.get('labels', []))
        current_class = data.get('class', 'voids')  # Default to 'voids' if class not provided

        # Validate the class before doing any expensive work.
        mask_folder = UPLOAD_FOLDERS.get(f'mask_{current_class}')
        segmented_folder = UPLOAD_FOLDERS.get(f'segmented_{current_class}')
        if not mask_folder or not segmented_folder:
            raise ValueError(f"Invalid class '{current_class}' provided.")

        # The predictor needs N (x, y) points and exactly N labels.
        if points.ndim != 2 or points.shape[1] != 2 or points.shape[0] == 0:
            raise ValueError("'points' must be a non-empty list of [x, y] pairs.")
        if labels.ndim != 1 or labels.shape[0] != points.shape[0]:
            raise ValueError("'labels' must contain exactly one label per point.")

        # Ensure predictor has an image set
        if not predictor.image_set:
            raise ValueError("No image set for prediction.")

        # Perform SAM prediction
        masks, _, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=False
        )

        # Check if masks exist and have non-zero elements
        if masks is None or masks.size == 0:
            raise RuntimeError("No masks were generated by the predictor.")

        os.makedirs(mask_folder, exist_ok=True)
        os.makedirs(segmented_folder, exist_ok=True)

        # Save the raw mask
        mask_path = os.path.join(mask_folder, 'raw_mask.png')
        save_mask_as_png(masks[0], mask_path)

        # Generate blended image: green for voids, blue for chips
        blend_color = [34, 139, 34] if current_class == 'voids' else [30, 144, 255]
        blended_image = blend_mask_with_image(predictor.image, masks[0], blend_color)

        # Save blended image
        blended_filename = f"blended_{current_class}.png"
        blended_path = os.path.join(segmented_folder, blended_filename)
        Image.fromarray(blended_image).save(blended_path)

        # Return URL for frontend access
        segmented_url = f"/static/uploads/segmented/{current_class}/{blended_filename}"
        logging.info(f"Segmentation completed for {current_class}. Points: {points}, Labels: {labels}")
        return jsonify({'segmented_url': segmented_url})

    except ValueError as ve:
        logging.error(f"Value error during segmentation: {ve}")
        return jsonify({'error': str(ve)}), 400

    except Exception as e:
        # logging.exception keeps the traceback for unexpected failures.
        logging.exception(f"Unexpected error during segmentation: {e}")
        return jsonify({'error': 'Segmentation failed', 'details': str(e)}), 500

@app.route('/automatic_segment', methods=['POST'])
def automatic_segment():
    """Run YOLO segmentation on an uploaded image and report void metrics.

    Expects a multipart form with a ``file`` part. The annotated prediction
    is saved to the automatic-segmentation folder, and for every detected
    chip the bounding-box area, the total void percentage and the largest
    single void percentage are computed. A void belongs to the first chip
    whose bounding box contains the void's bounding-box centre.

    Returns:
        JSON with ``segmented_url`` and ``table_data`` (image name plus one
        entry per chip); status 400 for a missing/invalid file and 500 when
        segmentation fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400

    # The file name comes from the client: keep only its last path component
    # so it cannot escape the upload folder.
    safe_name = os.path.basename(file.filename.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        return jsonify({'error': 'Invalid file name'}), 400

    input_path = os.path.join(UPLOAD_FOLDERS['input'], safe_name)
    file.save(input_path)

    try:
        # Perform YOLO segmentation
        results = yolo_model.predict(input_path, save=False, save_txt=False)
        if not results:
            raise RuntimeError("YOLO returned no results for the uploaded image.")

        output_folder = UPLOAD_FOLDERS['automatic_segmented']
        os.makedirs(output_folder, exist_ok=True)
        result_filename = f"{safe_name.rsplit('.', 1)[0]}_pred.jpg"
        result_path = os.path.join(output_folder, result_filename)

        chips_data = []

        for result in results:
            # plot() returns a BGR array; PIL expects RGB, so flip the
            # channel axis before saving or the colours come out swapped.
            annotated_image = result.plot()
            Image.fromarray(annotated_image[..., ::-1]).save(result_path)

            # Detections are collected per result so boxes from one result
            # are never matched against chips of another.
            chips = []
            voids = []
            for i, label in enumerate(result.boxes.cls):  # YOLO labels
                label_name = result.names[int(label)]  # e.g. 'chip' or 'void'
                box = result.boxes.xyxy[i].cpu().numpy()  # (x1, y1, x2, y2)
                area = float((box[2] - box[0]) * (box[3] - box[1]))

                if label_name == 'chip':
                    chips.append({'box': box, 'area': area, 'voids': []})
                elif label_name == 'void':
                    voids.append({'box': box, 'area': area})

            # Assign each void to the first chip containing its centre
            for void in voids:
                centre_x = (void['box'][0] + void['box'][2]) / 2
                centre_y = (void['box'][1] + void['box'][3]) / 2
                for chip in chips:
                    if (chip['box'][0] <= centre_x <= chip['box'][2] and
                            chip['box'][1] <= centre_y <= chip['box'][3]):
                        chip['voids'].append(void)
                        break

            # Calculate metrics for each chip
            for chip in chips:
                chip_area = chip['area']
                void_areas = [float(void['area']) for void in chip['voids']]
                total_void_area = sum(void_areas)
                max_void_area = max(void_areas, default=0)

                # Guard against degenerate (zero-area) boxes.
                void_percentage = (total_void_area / chip_area) * 100 if chip_area > 0 else 0
                max_void_percentage = (max_void_area / chip_area) * 100 if chip_area > 0 else 0

                chips_data.append({
                    "chip_number": len(chips_data) + 1,
                    "chip_area": round(chip_area, 2),
                    "void_percentage": round(void_percentage, 2),
                    "max_void_percentage": round(max_void_percentage, 2)
                })

        # Return the segmented image URL and table data
        segmented_url = f"/static/uploads/segmented/automatic/{result_filename}"
        return jsonify({
            "segmented_url": segmented_url,  # Use the URL for frontend access
            "table_data": {
                "image_name": safe_name,
                "chips": chips_data
            }
        })

    except Exception as e:
        logging.exception(f"Error in automatic segmentation: {e}")
        return jsonify({'error': 'Segmentation failed.'}), 500

@app.route('/save_both', methods=['POST'])
def save_both():
    """Copy an uploaded image and its current masks into the history folders.

    Expects a JSON object with ``image_name``. Only the last path component
    of the name is used. The image is copied from the input folder to the
    history images folder; each class's ``raw_mask.png`` (if present) is
    copied to the matching history mask folder under ``<image stem>.png``.

    Returns:
        A success message (200); 400 when no usable image name was given,
        404 when the image is not in the input folder, 500 on copy failure.
    """
    data = request.get_json(silent=True)
    raw_name = data.get('image_name') if isinstance(data, dict) else None

    if not raw_name or not isinstance(raw_name, str):
        return jsonify({'error': 'Image name not provided'}), 400

    # Strip any directory part *before* validating, so names such as "dir/"
    # or ".." are rejected instead of resolving to a folder.
    image_name = os.path.basename(raw_name)
    if image_name in ('', '.', '..'):
        return jsonify({'error': 'Image name not provided'}), 400

    try:
        logging.info(f"Sanitized Image Name: {image_name}")

        # Resolve the input image path
        input_image_path = os.path.join(UPLOAD_FOLDERS['input'], image_name)
        if not os.path.isfile(input_image_path):
            logging.warning(f"Input image does not exist: {input_image_path}")
            return jsonify({'error': f'Input image not found: {input_image_path}'}), 404

        # Copy the image to history/images
        image_history_path = os.path.join(HISTORY_FOLDERS['images'], image_name)
        os.makedirs(os.path.dirname(image_history_path), exist_ok=True)
        shutil.copy(input_image_path, image_history_path)
        logging.info(f"Image saved to history: {image_history_path}")

        # Copy each class's latest raw mask next to the image, keyed by stem.
        mask_name = f"{os.path.splitext(image_name)[0]}.png"
        mask_targets = (
            ('Voids', UPLOAD_FOLDERS['mask_voids'], HISTORY_FOLDERS['masks_void']),
            ('Chips', UPLOAD_FOLDERS['mask_chips'], HISTORY_FOLDERS['masks_chip']),
        )
        for label, upload_folder, history_folder in mask_targets:
            mask_path = os.path.join(upload_folder, 'raw_mask.png')
            if not os.path.exists(mask_path):
                # A missing mask is not an error: that class was not segmented.
                logging.info(f"{label} mask not found: {mask_path}")
                continue
            mask_history_path = os.path.join(history_folder, mask_name)
            os.makedirs(history_folder, exist_ok=True)
            shutil.copy(mask_path, mask_history_path)
            logging.info(f"{label} mask saved to history: {mask_history_path}")

        return jsonify({'message': 'Image and masks saved successfully!'}), 200

    except Exception as e:
        logging.exception(f"Error saving files: {e}")
        return jsonify({'error': 'Failed to save files.', 'details': str(e)}), 500

@app.route('/get_history', methods=['GET'])
def get_history():
    """Return the names of all entries in the history images folder."""
    try:
        history_dir = HISTORY_FOLDERS['images']
        payload = {'status': 'success', 'images': os.listdir(history_dir)}
        return jsonify(payload), 200
    except Exception as e:
        failure = {'status': 'error', 'message': f'Failed to fetch history: {e}'}
        return jsonify(failure), 500


@app.route('/delete_history_item', methods=['POST'])
def delete_history_item():
    """Delete a history image and the masks stored under the same stem.

    Expects a JSON object with ``image_name``. Only the last path component
    of the name is used, so a request cannot delete files outside the
    history folders. Files that do not exist are skipped silently.

    Returns:
        A success message (200); 400 when no usable image name was given,
        500 when a deletion fails.
    """
    data = request.get_json(silent=True)
    raw_name = data.get('image_name') if isinstance(data, dict) else None

    if not raw_name or not isinstance(raw_name, str):
        return jsonify({'error': 'Image name not provided'}), 400

    # The name is client-controlled and ends up in os.remove(): drop any
    # directory part so "../../x" cannot reach outside the history folders.
    image_name = os.path.basename(raw_name)
    if image_name in ('', '.', '..'):
        return jsonify({'error': 'Image name not provided'}), 400

    try:
        mask_name = f"{os.path.splitext(image_name)[0]}.png"
        targets = (
            os.path.join(HISTORY_FOLDERS['images'], image_name),
            os.path.join(HISTORY_FOLDERS['masks_void'], mask_name),
            os.path.join(HISTORY_FOLDERS['masks_chip'], mask_name),
        )
        for path in targets:
            if os.path.isfile(path):
                os.remove(path)

        return jsonify({'message': f'{image_name} and associated masks deleted successfully.'}), 200
    except Exception as e:
        return jsonify({'error': f'Failed to delete files: {e}'}), 500

# Serialises writes to the shared training_status dict
status_lock = Lock()

def update_training_status(key, value):
    """Store ``value`` under ``key`` in ``training_status`` while holding ``status_lock``."""
    status_lock.acquire()
    try:
        training_status[key] = value
    finally:
        status_lock.release()

@app.route('/retrain_model', methods=['POST'])
def retrain_model():
    """Start the retraining workflow.

    Marks training as running, backs up the history data, builds the YOLO
    dataset and launches ``run_yolo_training`` in a background thread.

    Returns:
        200 when training was started, 400 when training is already running,
        500 (with the running flag cleared) when preparation fails.
    """
    # Check and claim the "running" flag in one locked step; checking first
    # and setting later would let two concurrent requests both start training.
    with status_lock:
        if training_status.get('running', False):
            return jsonify({'error': 'Training is already in progress'}), 400
        training_status['running'] = True
        training_status['cancelled'] = False

    try:
        logging.info("Training status updated. Starting training workflow.")

        # Backup masks and images
        backup_masks_and_images()
        logging.info("Backup completed successfully.")

        # Prepare YOLO labels
        prepare_yolo_labels()
        logging.info("YOLO labels prepared successfully.")

        # Start YOLO training in a separate thread so the request returns
        threading.Thread(target=run_yolo_training).start()
        return jsonify({'message': 'Training started successfully!'}), 200

    except Exception as e:
        logging.exception(f"Error during training preparation: {e}")
        # Release the flag so a later request can try again.
        update_training_status('running', False)
        return jsonify({'error': f"Failed to start training: {e}"}), 500
        
import random


def prepare_yolo_labels():
    """Build the YOLO dataset from the history images.

    Lists the history images (``.jpg``, ``.jpeg``, ``.png``, any letter
    case), shuffles them, and splits them 80/20 into training and validation
    sets. Each image is handed to ``process_image_and_mask`` together with
    the destination folders of its split.

    Raises:
        ValueError: if fewer than two images are available, because then the
            training or the validation split would be empty.
        Exception: any error from listing or processing is logged and
            re-raised.
    """
    images_folder = HISTORY_FOLDERS['images']  # Use history images as the source
    train_labels_folder = DATASET_FOLDERS['train_labels']
    train_images_folder = DATASET_FOLDERS['train_images']
    val_labels_folder = DATASET_FOLDERS['val_labels']
    val_images_folder = DATASET_FOLDERS['val_images']

    # Ensure destination directories exist
    for folder in (train_labels_folder, train_images_folder, val_labels_folder, val_images_folder):
        os.makedirs(folder, exist_ok=True)

    try:
        all_images = [
            img for img in os.listdir(images_folder)
            if img.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        if len(all_images) < 2:
            # With 0 or 1 images one of the splits is empty; fail here with a
            # clear message instead of inside the training subprocess.
            raise ValueError(
                f"At least 2 saved images are required for training, found {len(all_images)}."
            )
        random.shuffle(all_images)  # Shuffle the images for randomness

        # 80% for training, 20% for validation. For two or more images this
        # leaves at least one image on each side.
        split_idx = int(len(all_images) * 0.8)
        splits = (
            (all_images[:split_idx], train_images_folder, train_labels_folder),
            (all_images[split_idx:], val_images_folder, val_labels_folder),
        )

        for image_names, dest_images_folder, dest_labels_folder in splits:
            for image_name in image_names:
                process_image_and_mask(
                    image_name,
                    source_images_folder=images_folder,
                    dest_images_folder=dest_images_folder,
                    dest_labels_folder=dest_labels_folder
                )

        logging.info("YOLO labels prepared, and images split into train and validation successfully.")

    except Exception as e:
        logging.error(f"Error in preparing YOLO labels: {e}")
        raise


def process_image_and_mask(image_name, source_images_folder, dest_images_folder, dest_labels_folder):
    """Copy one image into the dataset and write its YOLO label file.

    The image is copied to ``dest_images_folder``. Any existing label file
    for the image is removed, then the void mask (class 0) and the chip mask
    (class 1, appended) from the history mask folders are converted through
    ``convert_mask_to_yolo`` when they exist. Errors are logged and re-raised.
    """
    try:
        stem = os.path.splitext(image_name)[0]
        image_path = os.path.join(source_images_folder, image_name)
        label_file_path = os.path.join(dest_labels_folder, f"{stem}.txt")

        # Place the image in the dataset split
        shutil.copy(image_path, os.path.join(dest_images_folder, image_name))

        # Start from a clean label file
        if os.path.exists(label_file_path):
            os.remove(label_file_path)

        # (history folder key, class id, extra conversion options);
        # the chip annotations are appended after the void annotations.
        mask_jobs = (
            ('masks_void', 0, {}),
            ('masks_chip', 1, {'append': True}),
        )
        for history_key, class_id, extra_options in mask_jobs:
            mask_path = os.path.join(HISTORY_FOLDERS[history_key], f"{stem}.png")
            if not os.path.exists(mask_path):
                continue
            convert_mask_to_yolo(
                mask_path=mask_path,
                image_path=image_path,
                class_id=class_id,
                output_path=label_file_path,
                **extra_options
            )

        logging.info(f"Processed {image_name} into YOLO format.")
    except Exception as e:
        logging.error(f"Error processing {image_name}: {e}")
        raise
  
def backup_masks_and_images():
    """Copy the history images and masks into a freshly emptied temp backup.

    The three backup sub-folders are removed and recreated, then every entry
    of the history images, void-mask and chip-mask folders is copied into
    its backup counterpart.

    Raises:
        RuntimeError: if any copy step fails (the cause is logged).
    """
    backup_root = DATASET_FOLDERS['temp_backup']
    temp_backup_paths = {
        'voids': os.path.join(backup_root, 'masks/voids'),
        'chips': os.path.join(backup_root, 'masks/chips'),
        'images': os.path.join(backup_root, 'images')
    }

    # Start from empty backup directories
    for target_dir in temp_backup_paths.values():
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir, exist_ok=True)

    # (history folder key, backup sub-folder key), in copy order
    copy_plan = (
        ('images', 'images'),
        ('masks_void', 'voids'),
        ('masks_chip', 'chips'),
    )

    try:
        for history_key, backup_key in copy_plan:
            source_dir = HISTORY_FOLDERS[history_key]
            target_dir = temp_backup_paths[backup_key]
            for entry in os.listdir(source_dir):
                shutil.copy(os.path.join(source_dir, entry), os.path.join(target_dir, entry))

        logging.info("Masks and images backed up successfully from history.")
    except Exception as e:
        logging.error(f"Error during backup: {e}")
        raise RuntimeError("Backup process failed.")

def run_yolo_training(num_epochs=10):
    """Run the ``yolo train`` subprocess and stream its output.

    Every output line is printed, logged and emitted to the frontend as a
    ``training_update`` event. On exit code 0 the new model is promoted via
    ``finalize_training``. A non-zero exit after a user cancellation is only
    logged; any other failure restores the backup and emits a
    ``training_status`` error event. The running flag is always cleared.

    Args:
        num_epochs: number of training epochs passed to YOLO.
    """
    global training_process

    # Local handle: the cancel handlers may reset the global to None while
    # this thread is still waiting on the process.
    process = None
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        data_cfg_path = os.path.join(BASE_DIR, "models/data.yaml")  # Ensure correct YAML path
        # Absolute path so the run lands where finalize_training() looks
        # (BASE_DIR/runs) regardless of the current working directory.
        runs_dir = os.path.join(BASE_DIR, "runs")

        logging.info(f"Starting YOLO training on {device} with {num_epochs} epochs.")
        logging.info(f"Using dataset configuration: {data_cfg_path}")

        training_command = [
            "yolo",
            "train",
            f"data={data_cfg_path}",
            f"model={os.path.join(DATASET_FOLDERS['models'], 'best.pt')}",
            f"device={device}",
            f"epochs={num_epochs}",
            f"project={runs_dir}",
            "name=train"
        ]

        process = subprocess.Popen(
            training_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            env=os.environ.copy(),
        )
        training_process = process

        # Display and log output in real time; the pipe is closed on exit.
        with process.stdout:
            for line in process.stdout:
                message = line.strip()
                print(message)
                logging.info(message)
                socketio.emit('training_update', {'message': message})  # Send updates to the frontend

        process.wait()

        if process.returncode == 0:
            finalize_training()  # Finalize successfully completed training
        elif training_status.get('cancelled', False):
            # The process was terminated on purpose. The cancel handlers
            # restore the backup and notify the frontend themselves.
            # NOTE(review): this relies on the cancel handler setting the
            # 'cancelled' flag before terminating the process.
            logging.info("YOLO training process stopped after cancellation.")
        else:
            raise RuntimeError("YOLO training process failed. Check logs for details.")
    except Exception as e:
        logging.error(f"Training error: {e}")
        restore_backup()  # Restore the dataset and masks

        # Emit training error event to the frontend
        socketio.emit('training_status', {'status': 'error', 'message': f"Training failed: {str(e)}"})
    finally:
        update_training_status('running', False)
        # Only clear the global handle if it still refers to our process.
        if training_process is process:
            training_process = None


@socketio.on('cancel_training')
def handle_cancel_training():
    """Socket.IO handler that cancels a running YOLO training.

    When no training is running, only a ``button_update`` event is emitted.
    Otherwise the training subprocess is terminated, the backup is restored,
    the train/val folders are emptied and the frontend is notified. Any
    failure is logged and reported as a ``training_status`` error event.
    """
    if not training_status.get('running', False):
        socketio.emit('button_update', {'action': 'retrain'})  # Update button to retrain
        return

    # Work on a local handle: the training thread resets the global to None
    # when it finishes.
    process = training_process
    try:
        if process is None:
            # 'running' is set before the subprocess exists (during backup and
            # label preparation), so there may be nothing to terminate yet.
            raise RuntimeError("Training is still being prepared; there is no process to cancel yet.")

        # Flag the cancellation before terminating so the training thread can
        # tell a user cancellation from a training failure.
        update_training_status('cancelled', True)
        process.terminate()
        process.wait()
        update_training_status('running', False)

        restore_backup()
        cleanup_train_val_directories()

        # Emit button state change
        socketio.emit('button_update', {'action': 'retrain'})
        socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
    except Exception as e:
        logging.error(f"Error cancelling training: {e}")
        if process is not None and process.poll() is None:
            # Termination did not happen; do not leave a stale cancelled flag.
            update_training_status('cancelled', False)
        socketio.emit('training_status', {'status': 'error', 'message': str(e)})

def finalize_training():
    """Promote the newest trained model to ``models/best.pt`` and clean up.

    Picks the most recently modified run directory under ``BASE_DIR/runs``,
    backs up the current ``best.pt`` into the old-models folder with a
    timestamp, moves the run's ``weights/best.pt`` into place, reloads the
    in-memory YOLO model, notifies the frontend and empties the train/val
    folders. Errors are logged and reported as a ``training_status`` error
    event; nothing is raised to the caller.
    """
    global yolo_model

    try:
        # Locate the most recent training directory
        runs_dir = os.path.join(BASE_DIR, 'runs')
        if not os.path.isdir(runs_dir):
            raise FileNotFoundError("Training runs directory does not exist.")

        # Only directories can be training runs; stray files are ignored.
        run_dirs = [
            path
            for path in (os.path.join(runs_dir, entry) for entry in os.listdir(runs_dir))
            if os.path.isdir(path)
        ]
        if not run_dirs:
            raise FileNotFoundError(f"No training run found in {runs_dir}.")

        latest_run = max(run_dirs, key=os.path.getmtime)
        weights_dir = os.path.join(latest_run, 'weights')
        best_model_path = os.path.join(weights_dir, 'best.pt')

        if not os.path.exists(best_model_path):
            raise FileNotFoundError(f"'best.pt' not found in {weights_dir}.")

        # Backup the old model. It is copied, not moved, so models/best.pt
        # exists at every moment even if promoting the new model fails.
        old_model_folder = DATASET_FOLDERS['models_old']
        os.makedirs(old_model_folder, exist_ok=True)
        new_model_dest = os.path.join(DATASET_FOLDERS['models'], 'best.pt')

        if os.path.exists(new_model_dest):
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            shutil.copy2(new_model_dest, os.path.join(old_model_folder, f"old_{timestamp}.pt"))
            logging.info(f"Old model backed up to {old_model_folder}.")

        # Move the new model to the models directory (replaces the old file)
        shutil.move(best_model_path, new_model_dest)
        logging.info(f"New model saved to {new_model_dest}.")

        # Reload the model used for automatic segmentation; otherwise the
        # server keeps predicting with the weights loaded at startup.
        try:
            yolo_model = YOLO(new_model_dest)
            logging.info("Reloaded YOLO model from the newly trained weights.")
        except Exception as reload_error:
            logging.error(f"New model saved but could not be reloaded: {reload_error}")

        # Notify frontend that training is completed
        socketio.emit('training_status', {
            'status': 'completed',
            'message': 'Training completed successfully! Model saved as best.pt.'
        })

        # Clean up train/val directories
        cleanup_train_val_directories()
        logging.info("Train and validation directories cleaned up successfully.")

    except Exception as e:
        logging.error(f"Error finalizing training: {e}")
        # Emit error status to the frontend
        socketio.emit('training_status', {'status': 'error', 'message': f"Error finalizing training: {str(e)}"})

def restore_backup():
    """Restore the history images and masks from the temp backup.

    The backup is taken from the history folders, so it is restored into the
    history folders as well (not into the upload working folders). Each of
    the three folders is restored independently: a missing or failing one is
    logged and does not prevent the others from being restored. Nothing is
    raised to the caller.
    """
    try:
        temp_backup = DATASET_FOLDERS['temp_backup']
        # (backup sub-folder, destination) pairs mirroring the backup layout
        restore_plan = (
            ('masks/voids', HISTORY_FOLDERS['masks_void']),
            ('masks/chips', HISTORY_FOLDERS['masks_chip']),
            ('images', HISTORY_FOLDERS['images']),
        )
    except Exception as e:
        logging.error(f"Error restoring backup: {e}")
        return

    all_restored = True
    for relative_source, destination in restore_plan:
        source = os.path.join(temp_backup, relative_source)
        if not os.path.isdir(source):
            logging.warning(f"Backup folder missing, nothing to restore: {source}")
            all_restored = False
            continue
        try:
            shutil.copytree(source, destination, dirs_exist_ok=True)
        except Exception as e:
            logging.error(f"Error restoring backup: {e}")
            all_restored = False

    if all_restored:
        logging.info("Backup restored successfully.")

@app.route('/cancel_training', methods=['POST'])
def cancel_training():
    """HTTP endpoint that cancels the running YOLO training.

    Terminates the training subprocess, updates the training status,
    restores the backup, empties the train/val folders and notifies the
    frontend.

    Returns:
        200 when training was cancelled, 400 when no training process
        exists, 500 when cancellation fails.
    """
    global training_process

    # Take a local handle once: the training thread may reset the global to
    # None at any time after this check.
    process = training_process
    if process is None:
        logging.error("No active training process to terminate.")
        return jsonify({'error': 'No active training process to cancel.'}), 400

    try:
        # Flag the cancellation before terminating so the training thread can
        # tell a user cancellation from a training failure.
        update_training_status('cancelled', True)
        process.terminate()
        process.wait()
        if training_process is process:
            training_process = None  # Reset the process after termination

        # Update training status
        update_training_status('running', False)

        # Check if the model is already saved as best.pt
        best_model_path = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
        if os.path.exists(best_model_path):
            logging.info(f"Model already saved as best.pt at {best_model_path}.")
            socketio.emit('button_update', {'action': 'revert'})  # Notify frontend to revert button state
        else:
            logging.info("Training canceled, but no new model was saved.")

        # Restore backup if needed
        restore_backup()
        cleanup_train_val_directories()

        # Emit status update to frontend
        socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
        return jsonify({'message': 'Training canceled and data restored successfully.'}), 200

    except Exception as e:
        logging.error(f"Error cancelling training: {e}")
        if process.poll() is None:
            # Termination did not happen; do not leave a stale cancelled flag.
            update_training_status('cancelled', False)
        return jsonify({'error': f"Failed to cancel training: {e}"}), 500

@app.route('/clear_history', methods=['POST'])
def clear_history():
    """Empty the history images and mask folders, leaving the folders in place."""
    try:
        history_dirs = [HISTORY_FOLDERS[key] for key in ('images', 'masks_chip', 'masks_void')]
        for history_dir in history_dirs:
            # Drop the folder with its contents, then recreate it empty.
            shutil.rmtree(history_dir, ignore_errors=True)
            os.makedirs(history_dir, exist_ok=True)
        return jsonify({'message': 'History cleared successfully!'}), 200
    except Exception as e:
        return jsonify({'error': f'Failed to clear history: {e}'}), 500

@app.route('/training_status', methods=['GET'])
def get_training_status():
    """Report whether training is running, was cancelled, or is idle."""
    if training_status.get('running', False):
        status, message = 'running', 'Training in progress.'
    elif training_status.get('cancelled', False):
        status, message = 'cancelled', 'Training was cancelled.'
    else:
        status, message = 'idle', 'No training is currently running.'
    return jsonify({'status': status, 'message': message}), 200

def cleanup_train_val_directories():
    """Empty the train/val image and label folders, recreating each one.

    Failures are logged and not raised.
    """
    try:
        dataset_dirs = [
            DATASET_FOLDERS[key]
            for key in ('train_images', 'train_labels', 'val_images', 'val_labels')
        ]
        for dataset_dir in dataset_dirs:
            shutil.rmtree(dataset_dir, ignore_errors=True)  # Remove the folder and its contents
            os.makedirs(dataset_dir, exist_ok=True)  # Recreate it empty
        logging.info("Train and validation directories cleaned up successfully.")
    except Exception as e:
        logging.error(f"Error cleaning up train/val directories: {e}")


# Script entry point: only runs when this file is executed directly.
if __name__ == '__main__':
    # Select the 'spawn' start method before the server starts.
    multiprocessing.set_start_method('spawn')  # Required for multiprocessing on Windows
    # Start Flask's built-in server with the debugger on and the reloader off.
    # NOTE(review): the app creates a SocketIO instance but is started through
    # app.run() rather than socketio.run(app); confirm this is intended for the
    # installed Flask-SocketIO version.
    # NOTE(review): debug=True enables the interactive debugger; confirm this
    # entry point is never used for a network-exposed deployment.
    app.run(debug=True, use_reloader=False)