from inference_sdk import InferenceHTTPClient
from ultralytics import YOLO
import cv2
from stockfish import Stockfish
import os 
import numpy as np 
import streamlit as st

CLIENT = InferenceHTTPClient(
    api_url="https://outline.roboflow.com",
    api_key="9Ez1hwfkqVa2h6pRQQHH"
)

# Constants
FEN_MAPPING = {
    "black-pawn": "p", "black-rook": "r", "black-knight": "n", "black-bishop": "b", "black-queen": "q", "black-king": "k",
    "white-pawn": "P", "white-rook": "R", "white-knight": "N", "white-bishop": "B", "white-queen": "Q", "white-king": "K"
}
GRID_BORDER = 0  # Border size in pixels (the resized board image has no border)
GRID_SIZE = 224  # Size of the resized board image in pixels
BLOCK_SIZE = GRID_SIZE // 8  # Each square is 28px
X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']  # File labels for the x-axis
Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1]  # Rank labels for the y-axis, top to bottom

# Functions
def get_grid_coordinate(pixel_x, pixel_y, perspective):
    """
    Map a pixel coordinate on the resized board image to a chess square.
    For the white ('w') perspective, x = 0 corresponds to file 'a' and y = 0
    to rank 8; passing 'b' flips both axes for black's viewpoint.
    """
    # Shift by the border offset (0 for the resized image)
    adjusted_x = pixel_x - GRID_BORDER
    adjusted_y = pixel_y - GRID_BORDER

    # Reject pixels outside the board area
    if adjusted_x < 0 or adjusted_y < 0 or adjusted_x >= GRID_SIZE or adjusted_y >= GRID_SIZE:
        return "Pixel outside grid bounds"

    # Determine the grid column and row
    x_index = adjusted_x // BLOCK_SIZE
    y_index = adjusted_y // BLOCK_SIZE

    if x_index < 0 or x_index >= len(X_LABELS) or y_index < 0 or y_index >= len(Y_LABELS):
        return "Pixel outside grid bounds"

    # Flip both axes for black's perspective
    if perspective == "b":
        x_index = 7 - x_index
        y_index = 7 - y_index

    file = X_LABELS[x_index]
    rank = Y_LABELS[y_index]

    return f"{file}{rank}"
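
# Illustrative check (values chosen for this sketch, not taken from the app):
# with the 224x224 board image and 28px squares, a piece centred at pixel
# (15, 200) gives x_index 0 and y_index 7, so get_grid_coordinate(15, 200, "w")
# returns "a1", while get_grid_coordinate(15, 200, "b") returns "h8".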


def predict_next_move(fen, stockfish):
    """
    Predict the next move using Stockfish.
    """
    if stockfish.is_fen_valid(fen):
        stockfish.set_fen_position(fen)
    else:
        return "Invalid FEN notation!"

    best_move = stockfish.get_best_move()
    return f"The predicted next move is: {best_move}" if best_move else "No valid move found (checkmate/stalemate)."
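
# Usage sketch (illustrative; assumes a local Stockfish binary at this path):
#   sf = Stockfish(path="./stockfish-ubuntu-x86-64-sse41-popcnt")
#   predict_next_move("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", sf)
# The returned string reports the engine's best move, e.g. "The predicted next
# move is: e2e4" (the exact move depends on depth and engine settings).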

def main():
    st.title("Chessboard Position Detection and Move Prediction")

    # Make the Stockfish engine binary executable (assumed to sit in the working directory)
    os.chmod(os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt"), 0o755)

    # User uploads an image or captures it from their camera
    image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])

    if image_file is not None:
        # Save the image to a temporary file
        temp_dir = "temp_images"
        os.makedirs(temp_dir, exist_ok=True)
        temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
        with open(temp_file_path, "wb") as f:
            f.write(image_file.getbuffer())

        # Load the YOLO models
        model = YOLO("fine_tuned_on_all_data.pt")  # Replace with your trained model weights file

        
        # Alternative: local YOLO segmentation model (kept for reference)
        # seg_model = YOLO("segmentation.pt")
        # img = cv2.imread(temp_file_path)
        # r = seg_model.predict(source=temp_file_path)
        # xyxy = r[0].boxes.xyxy
        # x_min, y_min, x_max, y_max = map(int, xyxy[0])
        # new_img = img[y_min:y_max, x_min:x_max]

        # Load the saved image and verify it before running board segmentation
        image = cv2.imread(temp_file_path)
        if image is None:
            st.write("Error: Image not loaded.")
            return

        # Detect the chessboard with the hosted Roboflow segmentation model
        result = CLIENT.infer(image, model_id="chessboard-segmentation/1")
        prediction_data = result
        
        
        if not prediction_data.get('predictions'):
            st.write("No board found.")
            return 
        else:
            for prediction in prediction_data.get('predictions', []):
                if 'x' in prediction and 'y' in prediction and 'width' in prediction and 'height' in prediction:
                    x, y, w, h = prediction['x'], prediction['y'], prediction['width'], prediction['height']
                    # print(f"Bounding box coordinates: ({x}, {y}), width={w}, height={h}")
            
                    x1, y1 = int(x - w / 2), int(y - h / 2)
                    x2, y2 = int(x + w / 2), int(y + h / 2)
            
                    src_pts = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype="float32")
                    # print(f"Source Points: {src_pts}")
            
                    chessboard_size = 600
                    dst_pts = np.array([[0, 0], [chessboard_size - 1, 0], [chessboard_size - 1, chessboard_size - 1], [0, chessboard_size - 1]], dtype="float32")
                    # print(f"Destination Points: {dst_pts}")
            
                    matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
            
                    # Apply the perspective warp
                    transformed_chessboard = cv2.warpPerspective(image, matrix, (chessboard_size, chessboard_size))
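                    # Because src_pts is an axis-aligned rectangle, this warp is
                    # effectively a crop of the detected board followed by a
                    # resize to chessboard_size x chessboard_size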
                    
            
                    # Convert images to RGB for display
                    new_img = cv2.cvtColor(transformed_chessboard, cv2.COLOR_BGR2RGB)
                    


        # Resize the warped board to GRID_SIZE x GRID_SIZE (224x224) for piece detection
        image = cv2.resize(new_img, (GRID_SIZE, GRID_SIZE))
        st.image(image, caption="Segmented Chessboard", use_container_width=True)
        height, width, _ = image.shape

        # Get user input for perspective
        p = st.radio("Select perspective:", ["b (Black)", "w (White)"])
        p = p[0].lower()

        # Initialize the board; "8" marks an empty square and runs of empties
        # are collapsed into counts when the FEN rows are built below
        board = [["8"] * 8 for _ in range(8)]

        # Run detection
        results = model.predict(source=image, save=False, save_txt=False, conf=0.7)

        # Extract predictions and map to FEN board
        for result in results[0].boxes:
            x1, y1, x2, y2 = result.xyxy[0].tolist()
            class_id = int(result.cls[0])
            class_name = model.names[class_id]

            fen_piece = FEN_MAPPING.get(class_name, None)
            if not fen_piece:
                continue

            # Use the box centre, flipping y (image y grows downward); together
            # with the rank ordering used when the FEN rows are joined below,
            # this places each piece on the correct square
            center_x = (x1 + x2) / 2
            center_y = (y1 + y2) / 2
            pixel_x = int(center_x)
            pixel_y = int(height - center_y)

            grid_position = get_grid_coordinate(pixel_x, pixel_y, p)
            if grid_position != "Pixel outside grid bounds":
                file = ord(grid_position[0]) - ord('a')
                rank = int(grid_position[1]) - 1
                board[rank][file] = fen_piece
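        # Illustrative mapping: grid_position "e2" yields file index 4 and rank
        # index 1, so that piece is written to board[1][4]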

        # Generate the FEN string
        fen_rows = []
        for row in board:
            fen_row = ""
            empty_count = 0
            for cell in row:
                if cell == "8":
                    empty_count += 1
                else:
                    if empty_count > 0:
                        fen_row += str(empty_count)
                        empty_count = 0
                    fen_row += cell
            if empty_count > 0:
                fen_row += str(empty_count)
            fen_rows.append(fen_row)
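        # Example of the run-length encoding above: a row of
        # ["R", "8", "8", "8", "K", "8", "8", "8"] collapses to "R3K3"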

        position_fen = "/".join(fen_rows)
        move_side = st.radio("Select the side to move:", ["w (White)", "b (Black)"])[0].lower()
        # Castling rights and en passant are not detected, so both are "-";
        # halfmove clock 0 and fullmove number 1 complete the FEN
        fen_notation = f"{position_fen} {move_side} - - 0 1"
        st.subheader("Generated FEN Notation:")
        st.code(fen_notation)
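        # For reference, a complete FEN for the starting position reads:
        #   rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
        # This app always emits "-" for castling and en passant because neither
        # can be inferred from a single board image.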

        # Initialize the Stockfish engine
        stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
        stockfish = Stockfish(
            path=stockfish_path,
            depth=10,
            parameters={"Threads": 2, "Minimum Thinking Time": 2}
        )

        # Predict the next move
        next_move = predict_next_move(fen_notation, stockfish)
        st.subheader("Stockfish Recommended Move:")
        st.write(next_move)

if __name__ == "__main__":
    main()