from ultralytics import YOLO
import cv2
from stockfish import Stockfish
import os 
import numpy as np 
import streamlit as st

# Constants
FEN_MAPPING = {
    "black-pawn": "p", "black-rook": "r", "black-knight": "n", "black-bishop": "b", "black-queen": "q", "black-king": "k",
    "white-pawn": "P", "white-rook": "R", "white-knight": "N", "white-bishop": "B", "white-queen": "Q", "white-king": "K"
}
GRID_BORDER = 10  # Border size in pixels around the playable area
GRID_SIZE = 204  # Effective grid size: pixels 10-213 of the 224x224 board crop
BLOCK_SIZE = GRID_SIZE // 8  # Each square is ~25px
X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']  # File labels for the x-axis
Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1]  # Rank labels for the y-axis, listed top to bottom
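# The cropped board is resized to 224x224 in main(); with a 10px border on each side the
# playable area is 204x204 (GRID_BORDER + GRID_SIZE + GRID_BORDER = 224).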

# Functions
def get_grid_coordinate(pixel_x, pixel_y):
    """
    Function to determine the grid coordinate of a pixel, considering a 10px border and
    the grid where bottom-left is (a, 1) and top-left is (h, 8).
    """
    # Grid settings
    border = 10  # 10px border
    grid_size = 204  # Effective grid size (10px to 214px)
    block_size = grid_size // 8  # Each block is ~25px

    x_labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']  # Labels for x-axis (a to h)
    y_labels = [8, 7, 6, 5, 4, 3, 2, 1]  # Reversed labels for y-axis (8 to 1)

    # Adjust pixel_x and pixel_y by subtracting the border (grid starts at pixel 10)
    adjusted_x = pixel_x - border
    adjusted_y = pixel_y - border

    # Check bounds
    if adjusted_x < 0 or adjusted_y < 0 or adjusted_x >= grid_size or adjusted_y >= grid_size:
        return "Pixel outside grid bounds"

    # Determine the grid column and row
    x_index = adjusted_x // block_size
    y_index = adjusted_y // block_size

    if x_index < 0 or x_index >= len(x_labels) or y_index < 0 or y_index >= len(y_labels):
        return "Pixel outside grid bounds"

    # Convert indices to grid coordinates
    x_index = adjusted_x // block_size  # Determine the column index (0-7)
    y_index = adjusted_y // block_size  # Determine the row index (0-7)

    # Convert row index to the correct label, with '8' at the bottom
    y_labeld = y_labels[y_index]  # Correct index directly maps to '8' to '1'
    x_label = x_labels[x_index]
    y_label = 8 - y_labeld + 1

    return f"{x_label}{y_label}"
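
# Quick sanity check, assuming the 224x224 board crop produced in main() and a pixel_y
# already measured from the bottom edge (as done by the caller):
#   get_grid_coordinate(22, 22)    # -> "a1"  (bottom-left square)
#   get_grid_coordinate(200, 200)  # -> "h8"  (top-right square)
#   get_grid_coordinate(5, 5)      # -> "Pixel outside grid bounds"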

def predict_next_move(fen, stockfish):
    """
    Validate the FEN, feed it to Stockfish, and return the engine's best move
    mapped back through transform_string (180-degree board rotation).
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"
    stockfish.set_fen_position(fen)

    best_move = stockfish.get_best_move()
    if not best_move:
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {transform_string(best_move)}"
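
# Minimal standalone sketch (assumes a local Stockfish binary; adjust the path for your system):
#   sf = Stockfish(path="/usr/bin/stockfish")
#   predict_next_move("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", sf)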




def process_image(image_path):
    """
    Segment the chessboard out of the input photo and return the cropped board
    region as a BGR numpy array, or None if no confident detection is found.
    """
    # Ensure output directory exists
    os.makedirs('output', exist_ok=True)

    # Load the segmentation model
    segmentation_model = YOLO("segmentation.pt")

    # Run inference to get segmentation results
    results = segmentation_model.predict(
        source=image_path,
        conf=0.8  # Confidence threshold
    )

    # Keep the mask and bounding box of the first detection with confidence >= 0.8
    segmentation_mask = None
    bbox = None
    for result in results:
        if len(result.boxes) > 0 and result.boxes.conf[0] >= 0.8:
            segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
            bbox = result.boxes.xyxy[0].cpu().numpy()  # Bounding box coordinates (x1, y1, x2, y2)
            break

    if segmentation_mask is None or bbox is None:
        print("No segmentation mask with confidence above 0.8 found.")
        return None

    # Load the image
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not read image at {image_path}.")
        return None

    # Crop the segmented region based on the bounding box
    x1, y1, x2, y2 = bbox
    cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]

    # Save the cropped segmented image for inspection
    cropped_image_path = 'output/cropped_segment.jpg'
    cv2.imwrite(cropped_image_path, cropped_segment)
    print(f"Cropped segmented image saved to {cropped_image_path}")

    return cropped_segment
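
# Example usage (hypothetical file name):
#   cropped = process_image("board_photo.jpg")  # BGR board crop, or None if detection failed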

def transform_string(input_str):
    """
    Rotate a four-character UCI-style move (e.g. "e2e4") by 180 degrees:
    files are mirrored a<->h and ranks 1<->8. Promotion moves (five characters,
    e.g. "e7e8q") are not handled and return "Invalid input".
    """
    # Remove extra spaces and convert to lowercase
    input_str = input_str.strip().lower()

    # Check if input is valid
    if len(input_str) != 4 or not input_str[0].isalpha() or not input_str[1].isdigit() or \
       not input_str[2].isalpha() or not input_str[3].isdigit():
        return "Invalid input"

    # Define mappings
    letter_mapping = {
        'a': 'h', 'b': 'g', 'c': 'f', 'd': 'e',
        'e': 'd', 'f': 'c', 'g': 'b', 'h': 'a'
    }
    number_mapping = {
        '1': '8', '2': '7', '3': '6', '4': '5',
        '5': '4', '6': '3', '7': '2', '8': '1'
    }

    # Transform string
    result = ""
    for i, char in enumerate(input_str):
        if i % 2 == 0:  # Letters
            result += letter_mapping.get(char, "Invalid")
        else:  # Numbers
            result += number_mapping.get(char, "Invalid")
    
    # Check for invalid transformations
    if "Invalid" in result:
        return "Invalid input"

    return result

# Example usage:
# transform_string("e3a7")  # Output: "d6h2"


# Streamlit app
def main():
    st.title("Chessboard Position Detection and Move Prediction")

    # Offer both camera capture and file upload; use whichever returns an image
    image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])

    if image_file is not None:
        # Save the image to a temporary file
        temp_dir = "temp_images"
        os.makedirs(temp_dir, exist_ok=True)
        temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
        with open(temp_file_path, "wb") as f:
            f.write(image_file.getbuffer())

        # Process the image using its file path
        processed_image = process_image(temp_file_path)

        if processed_image is not None:
            # Resize the image to 224x224
            processed_image = cv2.resize(processed_image, (224, 224))
            height, width, _ = processed_image.shape

            # Initialize the YOLO model
            model = YOLO("standard.pt")  # Replace with your trained model weights file

            # Run detection
            results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.6)

            # Initialize an 8x8 board; "8" marks an empty square
            board = [["8"] * 8 for _ in range(8)]

            # Extract predictions and map to FEN board
            for result in results[0].boxes:
                x1, y1, x2, y2 = result.xyxy[0].tolist()
                class_id = int(result.cls[0])
                class_name = model.names[class_id]

                # Convert class_name to FEN notation
                fen_piece = FEN_MAPPING.get(class_name, None)
                if not fen_piece:
                    continue

                # Calculate the center of the bounding box
                center_x = (x1 + x2) / 2
                center_y = (y1 + y2) / 2

                # Convert to integer pixel coordinates; flip the Y axis so pixel_y is
                # measured from the bottom edge, as get_grid_coordinate expects
                pixel_x = int(center_x)
                pixel_y = int(height - center_y)

                # Get grid coordinate
                grid_position = get_grid_coordinate(pixel_x, pixel_y)

                if grid_position != "Pixel outside grid bounds":
                    file = ord(grid_position[0]) - ord('a')  # Column index (0-7)
                    rank = int(grid_position[1]) - 1  # Row index (0-7)

                    # Place the piece on the board
                    board[7 - rank][file] = fen_piece  # Flip rank index for FEN

            # Generate the FEN string
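            # (runs of empty squares are collapsed, e.g. ["8","8","P","8","8","8","8","8"] -> "2P5")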
            fen_rows = []
            for row in board:
                fen_row = ""
                empty_count = 0
                for cell in row:
                    if cell == "8":
                        empty_count += 1
                    else:
                        if empty_count > 0:
                            fen_row += str(empty_count)
                            empty_count = 0
                        fen_row += cell
                if empty_count > 0:
                    fen_row += str(empty_count)
                fen_rows.append(fen_row)

            position_fen = "/".join(fen_rows)

            # Ask the user which side is to move
            move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
            move_side = "w" if move_side.startswith("w") else "b"

            # Append the remaining FEN fields: castling and en passant are unknown ("-"),
            # halfmove clock 0, fullmove number 1
            fen_notation = f"{position_fen} {move_side} - - 0 1"

            st.subheader("Generated FEN Notation:")
            st.code(fen_notation)

            # Initialize the Stockfish engine
            stockfish = Stockfish(
                path=r"D:\Projects\ChessVision\StockFish\stockfish\stockfish-windows-x86-64-avx2.exe",  # Replace with your Stockfish path
                depth=15,
                parameters={"Threads": 2, "Minimum Thinking Time": 30}
            )

            # Predict the next move
            next_move = predict_next_move(fen_notation, stockfish)
            st.subheader("Stockfish Recommended Move:")
            st.write(next_move)

        else:
            st.error("Failed to process the image. Please try again.")

if __name__ == "__main__":
    main()
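
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py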