Spaces:
Sleeping
Sleeping
from inference_sdk import InferenceHTTPClient | |
from ultralytics import YOLO | |
import cv2 | |
from stockfish import Stockfish | |
import os | |
import numpy as np | |
import streamlit as st | |
# Roboflow hosted-inference client.
# The API key is read from the ROBOFLOW_API_KEY environment variable instead of
# being committed to source control. The key that was previously hard-coded
# here has been published and should be revoked/rotated.
# NOTE(review): CLIENT is not referenced by the code visible in this file.
CLIENT = InferenceHTTPClient(
    api_url="https://outline.roboflow.com",
    api_key=os.environ.get("ROBOFLOW_API_KEY"),
)
# Constants

# Maps the piece-detection model's class names to FEN piece letters
# (lower case = black, upper case = white).
FEN_MAPPING = {
    "black-pawn": "p", "black-rook": "r", "black-knight": "n",
    "black-bishop": "b", "black-queen": "q", "black-king": "k",
    "white-pawn": "P", "white-rook": "R", "white-knight": "N",
    "white-bishop": "B", "white-queen": "Q", "white-king": "K",
}

# Geometry of the resized (224 x 224) board crop.
# NOTE(review): get_grid_coordinate keeps its own local copies of these values.
GRID_BORDER = 0  # Margin around the board inside the crop, in pixels (none)
GRID_SIZE = 224  # Side length of the board crop, in pixels
BLOCK_SIZE = GRID_SIZE // 8  # Side length of one square: 28 px
X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']  # Files, left to right
Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1]  # Ranks in top-to-bottom row order
# Functions | |
def get_grid_coordinate(pixel_x, pixel_y, perspective):
    """Return the square name (e.g. "e4") under a pixel of the board crop.

    The board image is assumed to be the 224x224 crop produced in ``main``.
    Its origin is at the top-left corner, and y increases downwards.

    From white's perspective (any value other than "b"), the top-left square
    is a8 and the bottom-left square is a1. From black's perspective
    (``perspective == "b"``), both axes are flipped: the top-left square is h1
    and the bottom-right square is a8.

    Args:
        pixel_x: Horizontal pixel position (int or float).
        pixel_y: Vertical pixel position, measured from the top edge.
        perspective: "w" for white's viewpoint, "b" for black's.

    Returns:
        The square as file letter + rank digit, or the sentinel string
        "Pixel outside grid bounds" when the pixel is not on the board.
    """
    border = 0  # no margin around the board in the crop (cf. GRID_BORDER)
    block_size = 28  # pixels per square
    grid_size = 8 * block_size  # 224 px; always a whole number of squares
    files = "abcdefgh"

    adjusted_x = pixel_x - border
    adjusted_y = pixel_y - border
    if not (0 <= adjusted_x < grid_size and 0 <= adjusted_y < grid_size):
        return "Pixel outside grid bounds"

    # int() so float coordinates (e.g. bounding-box centres) can index `files`.
    col = int(adjusted_x // block_size)
    row = int(adjusted_y // block_size)  # 0 = top row of the image

    if perspective == "b":
        # Black's viewpoint: the board is rotated 180 degrees in the image.
        col = 7 - col
        row = 7 - row

    # Top row (row 0) is rank 8 and the bottom row (row 7) is rank 1.
    return f"{files[col]}{8 - row}"
def predict_next_move(fen, stockfish):
    """Ask the engine for its best move in the position given by ``fen``.

    The FEN is validated first; only a valid FEN is loaded into the engine.
    Always returns a human-readable message rather than a raw move:
    the suggested move, a notice that the FEN was rejected, or a notice that
    the engine produced no move.
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"

    stockfish.set_fen_position(fen)
    move = stockfish.get_best_move()

    # A falsy result means the engine had no move to suggest.
    if not move:
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {move}"
def _load_models():
    """Return ``(piece_model, board_model)``, loading the weights once per session.

    Streamlit re-executes the script on every widget interaction, so the models
    are kept in ``st.session_state``. This avoids reloading them on each rerun
    without sharing one model instance between concurrent user sessions.
    """
    if "chess_models" not in st.session_state:
        st.session_state["chess_models"] = (
            YOLO("chessDetection3d.pt"),  # Replace with your trained model weights file
            YOLO("segmentation.pt"),
        )
    return st.session_state["chess_models"]


def _board_to_fen(board):
    """Return the piece-placement field of a FEN string.

    ``board[0]`` is rank 8 and ``board[7]`` is rank 1. Each cell holds a FEN
    piece letter, or None for an empty square. Runs of empty squares are
    written as digits.
    """
    fen_rows = []
    for row in board:
        fen_row = ""
        empty = 0
        for cell in row:
            if cell is None:
                empty += 1
                continue
            if empty:
                fen_row += str(empty)
                empty = 0
            fen_row += cell
        if empty:
            fen_row += str(empty)
        fen_rows.append(fen_row)
    return "/".join(fen_rows)


def main():
    """Render the app: crop the board, build a FEN, and ask Stockfish for a move.

    Pipeline:
    1. Read an image from the camera or an upload.
    2. Crop the single detected board with the segmentation model and resize
       it to 224x224.
    3. Detect pieces and map each box centre to a square.
    4. Pass the resulting FEN to Stockfish.
    """
    st.title("Chessboard Position Detection and Move Prediction")

    # Use one path for both chmod and launching the engine.
    stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
    os.chmod(stockfish_path, 0o755)

    # User uploads an image or captures it from their camera
    image_file = st.camera_input("Capture a chessboard image") or st.file_uploader(
        "Upload a chessboard image", type=["jpg", "jpeg", "png"]
    )
    if image_file is None:
        return

    # Decode in memory: a fixed temp file would be overwritten by other sessions.
    img = cv2.imdecode(np.frombuffer(image_file.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        st.write("Could not decode the image")
        return

    model, seg_model = _load_models()

    # predict() returns one Results object per input image; the detected
    # boards are the boxes inside that single result.
    seg_results = seg_model.predict(source=img)
    boards = seg_results[0].boxes if seg_results else []
    if len(boards) == 0:
        st.write("NO BOARD IN THE IMAGE")
        return
    if len(boards) > 1:
        st.write("Multiple boards are there in the image, please take only one at a time")
        return

    x_min, y_min, x_max, y_max = map(int, boards.xyxy[0])
    crop = img[y_min:y_max, x_min:x_max]
    if crop.size == 0:
        st.write("Detected board region is empty")
        return
    image = cv2.resize(crop, (224, 224))
    # OpenCV images are BGR; tell Streamlit so the colours are not swapped.
    st.image(image, caption="Segmented Chessboard", use_container_width=True, channels="BGR")

    # Get user input for perspective ("b" or "w")
    perspective = st.radio("Select perspective:", ["b (Black)", "w (White)"])[0].lower()

    # board[0] is rank 8 and board[7] is rank 1 (FEN row order); None = empty.
    board = [[None] * 8 for _ in range(8)]
    results = model.predict(source=image, save=False, save_txt=False, conf=0.7)
    for box in results[0].boxes:
        fen_piece = FEN_MAPPING.get(model.names[int(box.cls[0])])
        if fen_piece is None:
            continue
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        # get_grid_coordinate takes image coordinates (origin at the top-left).
        square = get_grid_coordinate(int((x1 + x2) / 2), int((y1 + y2) / 2), perspective)
        if square == "Pixel outside grid bounds":
            continue
        file_index = ord(square[0]) - ord("a")
        rank = int(square[1])
        board[8 - rank][file_index] = fen_piece

    position_fen = _board_to_fen(board)
    move_side = st.radio("Select the side to move:", ["w (White)", "b (Black)"])[0].lower()
    # Castling and en-passant rights cannot be inferred from a photo ("-").
    # The FEN full-move number must be at least 1.
    fen_notation = f"{position_fen} {move_side} - - 0 1"
    st.subheader("Generated FEN Notation:")
    st.code(fen_notation)

    stockfish = Stockfish(
        path=stockfish_path,
        depth=10,
        parameters={"Threads": 2, "Minimum Thinking Time": 2},
    )
    next_move = predict_next_move(fen_notation, stockfish)
    st.subheader("Stockfish Recommended Move:")
    st.write(next_move)


if __name__ == "__main__":
    main()