# Page header captured with this file (Spaces status): Build error
import base64
import io
import os

import chess
import chess.svg
import streamlit as st
import torch
from PIL import Image

from o2_agent import O2Agent
from o2_model import O2Net, board_to_tensor
# Checkpoint that load_agent() hands to torch.load().
# NOTE(review): 'FlameF0X/o2' looks like a model-hub repo id rather than a
# local file path. torch.load() opens a string argument as a file, so this
# default only works if a file exists at that relative path. Set the
# O2_MODEL_PATH environment variable to point at a downloaded checkpoint
# to override it without editing the code.
MODEL_PATH = os.environ.get('O2_MODEL_PATH', 'FlameF0X/o2')
@st.cache_resource
def load_agent(model_path=None):
    """Build an ``O2Agent`` and load its network weights onto the CPU.

    Streamlit re-executes the whole script on every widget interaction, and
    ``main`` calls this function on each run. ``st.cache_resource`` makes
    sure the checkpoint is read once per server process instead of on every
    click.

    NOTE(review): the cached agent is shared by all browser sessions. If
    ``O2Agent`` keeps per-game search state, confirm that sharing is safe.

    Args:
        model_path: Filesystem path of the state-dict checkpoint. ``None``
            (the default) uses the module-level ``MODEL_PATH``.

    Returns:
        The agent, with its model switched to eval mode.
    """
    if model_path is None:
        # Resolved at call time, as the original did, not at definition time.
        model_path = MODEL_PATH
    agent = O2Agent()
    # NOTE(review): torch.load unpickles the file. If the checkpoint can come
    # from an untrusted source, consider weights_only=True (torch >= 1.13).
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    agent.model.load_state_dict(state_dict)
    # Inference only: put layers such as dropout/batch-norm into eval mode.
    agent.model.eval()
    return agent
def render_svg(svg):
    """Wrap SVG markup in an ``<img>`` element that uses a base64 data URI.

    Args:
        svg: SVG document as text. It is UTF-8 encoded before base64 encoding.

    Returns:
        A ``(html, encoded)`` tuple. ``html`` is the ``<img>`` tag string and
        ``encoded`` is the base64 text of the SVG on its own.
    """
    payload = base64.b64encode(svg.encode()).decode()
    img_tag = "<img src='data:image/svg+xml;base64," + payload + "'/>"
    return img_tag, payload
def _rerun():
    """Restart the Streamlit script run so the page reflects updated state.

    ``st.rerun`` is the stable API. ``st.experimental_rerun`` is its
    deprecated predecessor, which newer Streamlit releases no longer ship,
    so it is used only as a fallback on older installs.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _human_side(human_color):
    """Map the ``'white'``/``'black'`` session setting to a ``chess`` colour."""
    return chess.WHITE if human_color == 'white' else chess.BLACK


def main():
    """Run the Streamlit page for one chess game between the user and o2.

    Streamlit re-executes this function from the top on every interaction,
    so the game state lives in ``st.session_state``:

    - ``board``: the ``chess.Board`` being played.
    - ``human_color``: ``'white'`` or ``'black'``. o2 plays the other side.
    - ``game_over``: set to True once ``board.is_game_over()`` holds.
    """
    # set_page_config has to be the first Streamlit command of a script run.
    st.set_page_config(page_title="Play Chess vs o2", layout="wide")
    st.title("Play Chess vs o2 (AlphaZero-inspired)")
    st.write("You can play as White or Black. o2 will play the other side.")
    agent = load_agent()

    # Session state survives reruns, so initialise it only once per session.
    if 'board' not in st.session_state:
        st.session_state.board = chess.Board()
        st.session_state.human_color = 'white'
        st.session_state.game_over = False

    # The colour can only be chosen before the first move has been played.
    if st.session_state.board.fullmove_number == 1 and st.session_state.board.turn == chess.WHITE:
        st.session_state.human_color = st.radio("Choose your color:", ('white', 'black'), horizontal=True)

    board = st.session_state.board
    human_color = st.session_state.human_color
    human_side = _human_side(human_color)

    # Draw the board from the human's side and highlight the last move, so
    # the player can see which move o2 just made.
    svg_board = chess.svg.board(
        board=board,
        orientation=human_side,
        lastmove=board.peek() if board.move_stack else None,
    )
    svg_html, _ = render_svg(svg_board)
    st.markdown(svg_html, unsafe_allow_html=True)

    if board.is_game_over():
        st.session_state.game_over = True
        # outcome() is not None here: it and is_game_over() share the same
        # default claim_draw=False.
        st.success(f"Game over: {board.result()} ({board.outcome().termination.name})")
        if st.button("Restart Game"):
            st.session_state.board = chess.Board()
            st.session_state.game_over = False
            _rerun()
        return

    if board.turn == human_side:
        legal_moves = list(board.legal_moves)
        move_uci = st.selectbox("Your move:", [m.uci() for m in legal_moves])
        if st.button("Make Move"):
            move = chess.Move.from_uci(move_uci)
            if move in legal_moves:
                board.push(move)
                _rerun()
        return

    # o2's turn: it moves automatically, then reruns so the human can reply.
    with st.spinner("o2 is thinking..."):
        move = agent.select_move(board, use_mcts=True, simulations=30)
    # Board.push() does not check legality. Reject a bad agent move instead
    # of silently corrupting the game position.
    if move is None or move not in board.legal_moves:
        st.error(f"o2 returned an illegal move: {move}")
        return
    board.push(move)
    _rerun()
# Start the app when this file is executed as the main script.
if __name__ == "__main__":
    main()