# Play-with-o2 / src/streamlit_app.py
# Author: FlameF0X
# Commit 1328a85 (verified): "Update src/streamlit_app.py" (2.55 kB)
import streamlit as st
import chess
import chess.svg
import torch
from o2_model import O2Net, board_to_tensor
from o2_agent import O2Agent
from PIL import Image
import io
import base64
# Model checkpoint location, passed to torch.load() in load_agent().
# NOTE(review): 'FlameF0X/o2' looks like a Hugging Face Hub repo id rather than
# a local file path; torch.load() needs a path or file-like object, so this
# presumably has to be resolved/downloaded to a local file first — confirm.
MODEL_PATH = 'FlameF0X/o2'
@st.cache_resource
def load_agent():
    """Build the o2 agent and load its trained weights, cached by Streamlit.

    ``st.cache_resource`` keeps the returned agent alive across script reruns,
    so the checkpoint is read only once rather than on every interaction.

    Returns:
        An ``O2Agent`` whose ``model`` holds the state dict loaded from
        ``MODEL_PATH`` (mapped onto the CPU) and has been put in eval mode.
    """
    agent = O2Agent()
    # NOTE(review): torch.load needs a local path or file-like object;
    # MODEL_PATH looks like a Hub repo id, so confirm it resolves to a file.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code when loaded.
    state_dict = torch.load(
        MODEL_PATH,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    agent.model.load_state_dict(state_dict)
    # Inference only: switch the network out of training mode.
    agent.model.eval()
    return agent
def render_svg(svg):
    """Wrap SVG markup in an HTML ``<img>`` tag backed by a base64 data URI.

    Args:
        svg: SVG document as text; it is UTF-8 encoded before base64 encoding.

    Returns:
        A ``(html, b64)`` tuple: the ``<img>`` tag (suitable for
        ``st.markdown(..., unsafe_allow_html=True)``) and the bare base64 payload.
    """
    raw_bytes = svg.encode('utf-8')
    payload = base64.b64encode(raw_bytes).decode('utf-8')
    img_tag = f"<img src='data:image/svg+xml;base64,{payload}'/>"
    return img_tag, payload
def _rerun():
    """Restart the Streamlit script run, preferring ``st.rerun``."""
    # st.experimental_rerun is deprecated in favour of st.rerun (and removed in
    # recent releases); fall back to it only on old Streamlit versions.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _init_session_state():
    """Create the per-session game state on the first run of a session."""
    if 'board' not in st.session_state:
        st.session_state.board = chess.Board()
        st.session_state.human_color = 'white'
        st.session_state.game_over = False


def _render_board(board, human_side):
    """Draw the board from the human's side, highlighting the last move played."""
    last_move = board.peek() if board.move_stack else None
    svg_board = chess.svg.board(board=board, orientation=human_side, lastmove=last_move)
    svg_html, _ = render_svg(svg_board)
    st.markdown(svg_html, unsafe_allow_html=True)


def _show_game_over(board):
    """Announce the final result and offer a button that starts a new game."""
    st.session_state.game_over = True
    st.success(f"Game over: {board.result()} ({board.outcome().termination.name})")
    if st.button("Restart Game"):
        st.session_state.board = chess.Board()
        st.session_state.game_over = False
        _rerun()


def _play_human_move(board):
    """Let the human pick one of the legal moves and apply it on submit."""
    legal_moves = list(board.legal_moves)
    move_uci = st.selectbox("Your move:", [m.uci() for m in legal_moves])
    if st.button("Make Move"):
        move = chess.Move.from_uci(move_uci)
        if move in legal_moves:
            board.push(move)
            _rerun()


def _play_agent_move(board, agent):
    """Ask the o2 agent for a move (MCTS, 30 simulations) and apply it."""
    with st.spinner("o2 is thinking..."):
        move = agent.select_move(board, use_mcts=True, simulations=30)
    # Board.push() does not check legality; refuse an illegal move instead of
    # silently corrupting the stored game.
    if move is None or not board.is_legal(move):
        st.error(f"o2 returned an illegal move: {move}")
        return
    board.push(move)
    _rerun()


def main():
    """Render one pass of the "Play Chess vs o2" app.

    Streamlit re-executes this function on every interaction. The game lives in
    ``st.session_state`` (``board``, ``human_color``, ``game_over``). Each pass
    draws the board, then either reports the result, waits for the human's
    move, or lets the o2 agent move — triggering a rerun after any move.
    """
    st.set_page_config(page_title="Play Chess vs o2", layout="wide")
    st.title("Play Chess vs o2 (AlphaZero-inspired)")
    st.write("You can play as White or Black. o2 will play the other side.")
    agent = load_agent()

    _init_session_state()

    # The colour can only be chosen while the game is still at its first move.
    start = st.session_state.board
    if start.fullmove_number == 1 and start.turn == chess.WHITE:
        st.session_state.human_color = st.radio("Choose your color:", ('white', 'black'), horizontal=True)

    board = st.session_state.board
    human_color = st.session_state.human_color
    human_side = chess.WHITE if human_color == 'white' else chess.BLACK

    _render_board(board, human_side)

    if board.is_game_over():
        _show_game_over(board)
        return

    if board.turn == human_side:
        _play_human_move(board)
    else:
        _play_agent_move(board, agent)
# Script entry point: build the page whenever this file is executed as the
# main module (presumably via `streamlit run`).
if __name__ == "__main__":
    main()