File size: 2,553 Bytes
a93809d
d8a8f31
 
 
 
 
 
 
 
a93809d
d8a8f31
1328a85
d8a8f31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import streamlit as st
import chess
import chess.svg
import torch
from o2_model import O2Net, board_to_tensor
from o2_agent import O2Agent
from PIL import Image
import io
import base64

# Load model
# NOTE(review): 'FlameF0X/o2' looks like a Hugging Face Hub repo id rather than
# a local checkpoint file; torch.load needs a path to the weights file itself
# (e.g. one downloaded beforehand) — confirm this resolves on the host.
MODEL_PATH = 'FlameF0X/o2'


@st.cache_resource
def load_agent():
    """Build an O2Agent loaded with the weights at MODEL_PATH.

    Decorated with ``st.cache_resource`` so the checkpoint is read once and the
    same agent is reused across Streamlit reruns instead of reloading it on
    every interaction.

    Returns:
        An ``O2Agent`` whose ``model`` holds the loaded state_dict and has been
        switched to eval mode.
    """
    agent = O2Agent()
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, and prevents arbitrary code in the
    # checkpoint from running at load time. map_location lets weights saved
    # on a GPU load on a CPU-only host.
    state_dict = torch.load(
        MODEL_PATH,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    agent.model.load_state_dict(state_dict)
    agent.model.eval()
    return agent

def render_svg(svg):
    """Wrap SVG markup in an HTML ``<img>`` tag using a base64 data URI.

    Args:
        svg: The SVG document as a string.

    Returns:
        A ``(html, encoded)`` pair: the ``<img>`` tag string, and the bare
        base64 payload it embeds.
    """
    raw_bytes = svg.encode('utf-8')
    encoded = base64.b64encode(raw_bytes).decode('utf-8')
    html = "<img src='data:image/svg+xml;base64," + encoded + "'/>"
    return html, encoded

def _rerun():
    """Trigger an immediate rerun of the Streamlit script.

    Newer Streamlit releases provide ``st.rerun`` and deprecate (and later
    remove) ``st.experimental_rerun``; prefer the new name and fall back to
    the old one so the app works on both.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _is_human_turn(board, human_color):
    """Return True when the side to move in *board* belongs to the human.

    Args:
        board: The current ``chess.Board``.
        human_color: ``'white'`` or ``'black'``; any other value means the
            human never has the move.
    """
    sides = {'white': chess.WHITE, 'black': chess.BLACK}
    return sides.get(human_color) == board.turn


def main():
    """Render the o2 chess app and advance the game by at most one move.

    Streamlit re-executes this script on every interaction, so game state is
    kept in ``st.session_state``:
      - ``board`` is a ``chess.Board``.
      - ``human_color`` is ``'white'`` or ``'black'``.
      - ``game_over`` is a bool.

    Each run draws the board, then does one of the following:
      - shows the result if the game is over;
      - waits for the human's move;
      - lets o2 pick and play a move.
    A rerun is triggered after every move so the new position is drawn.
    """
    st.set_page_config(page_title="Play Chess vs o2", layout="wide")
    st.title("Play Chess vs o2 (AlphaZero-inspired)")
    st.write("You can play as White or Black. o2 will play the other side.")

    agent = load_agent()

    # Session state survives reruns; initialise it once per session.
    if 'board' not in st.session_state:
        st.session_state.board = chess.Board()
        st.session_state.human_color = 'white'
        st.session_state.game_over = False

    # The colour can only be chosen before White's first move. Default the
    # radio to the stored choice so "Restart Game" does not silently switch
    # the human back to White.
    if st.session_state.board.fullmove_number == 1 and st.session_state.board.turn == chess.WHITE:
        current_index = 0 if st.session_state.human_color == 'white' else 1
        st.session_state.human_color = st.radio(
            "Choose your color:", ('white', 'black'), index=current_index, horizontal=True
        )

    board = st.session_state.board
    human_color = st.session_state.human_color

    # Draw the board from the human's side so a Black player sees their own
    # pieces at the bottom.
    orientation = chess.BLACK if human_color == 'black' else chess.WHITE
    svg_html, _ = render_svg(chess.svg.board(board=board, orientation=orientation))
    st.markdown(svg_html, unsafe_allow_html=True)

    # Game over message
    if board.is_game_over():
        st.session_state.game_over = True
        # outcome() is non-None whenever is_game_over() is True (both use the
        # default claim_draw=False).
        st.success(f"Game over: {board.result()} ({board.outcome().termination.name})")
        if st.button("Restart Game"):
            st.session_state.board = chess.Board()
            st.session_state.game_over = False
            _rerun()
        return

    # Human move
    if _is_human_turn(board, human_color):
        legal_moves = list(board.legal_moves)
        move_uci = st.selectbox("Your move:", [m.uci() for m in legal_moves])
        if st.button("Make Move"):
            move = chess.Move.from_uci(move_uci)
            # The selectbox only offers legal moves; re-check anyway before
            # mutating the board kept in session state.
            if move in legal_moves:
                board.push(move)
                _rerun()
        return

    # o2 agent move
    with st.spinner("o2 is thinking..."):
        move = agent.select_move(board, use_mcts=True, simulations=30)
    # Board.push() does not validate legality, so an invalid move from the
    # agent would silently corrupt the game; refuse it and report instead.
    if move is None or not board.is_legal(move):
        st.error(f"o2 returned an invalid move ({move}); the game cannot continue.")
        return
    board.push(move)
    _rerun()


if __name__ == "__main__":
    main()