| import os | |
| import random | |
| from datetime import datetime | |
| import gradio as gr | |
| import chess | |
| import chess.svg | |
| from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipeline | |
# Hub checkpoint holding both the tokenizer and the fine-tuned move classifier.
_CHECKPOINT = 'jrahn/chessv6'

# Access token for the Hugging Face Hub (presumably the checkpoint is private);
# a missing `auth_token` environment variable raises KeyError at import time.
token = os.environ['auth_token']
_hub_kwargs = dict(use_auth_token=token)

# Tokenizer first, then the sequence-classification model, both from the same
# repo; they are combined into the text-classification pipeline that
# predict_move queries for candidate moves.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, **_hub_kwargs)
model = DebertaV2ForSequenceClassification.from_pretrained(_CHECKPOINT, **_hub_kwargs)
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
def predict_move(fen, top_k=3, *, classifier=None):
    """Sample an engine move for ``fen`` from the policy network.

    The classifier scores candidate move labels for the position. One of the
    returned top-``top_k`` labels is drawn at random, weighted by its score.

    Args:
        fen: Position in FEN notation, passed verbatim to the classifier.
        top_k: How many of the best-scoring candidates to sample from. The
            value arrives from a free-form numeric UI field, so it is
            normalized to an ``int`` and clamped to at least 1. ``None`` is
            forwarded to the classifier unchanged.
        classifier: Optional callable used instead of the module-level
            ``pipe``. It is called as ``classifier(fen, top_k=k)`` and must
            return a list of ``{'label': ..., 'score': ...}`` dicts.

    Returns:
        The sampled label. Callers parse it as a UCI move string.
    """
    if classifier is None:
        classifier = pipe
    if top_k is not None:
        # 0 or a negative value cannot mean "the best k moves": with 0 the
        # candidate list comes back empty and random.choices would raise.
        top_k = max(1, int(top_k))
    candidates = classifier(fen, top_k=top_k)
    weights = [c['score'] for c in candidates]
    return random.choices(candidates, weights=weights)[0]['label']
def btn_load(inp_fen):
    """Reset the game to the standard starting position and render it.

    ``inp_fen`` is only written to the log. The board is always rebuilt from
    the initial position, and its SVG rendering is written to ``board.svg``.

    Returns:
        Tuple of (image path, starting-position FEN, empty string). This matches
        the (board image, FEN box, move box) outputs the Load button feeds.
    """
    timestamp = datetime.now().isoformat()
    print(f'** log - load - ts {timestamp}, fen: {inp_fen}')
    start = chess.Board()
    svg_path = 'board.svg'
    with open(svg_path, 'w') as handle:
        handle.write(str(chess.svg.board(start)))
    return svg_path, start.fen(), ''
def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Apply one move (human or engine) to ``inp_fen`` and render the result.

    If the human-move box holds non-blank text, that text is parsed in the
    selected notation. Otherwise the engine chooses a move via
    ``predict_move``.

    Args:
        inp_fen: Current position in FEN notation.
        inp_move: Human move text. Empty, ``None`` or whitespace-only means
            "let the engine move". Surrounding whitespace is ignored.
        inp_notation: 'UCI' or 'SAN', selecting how ``inp_move`` is parsed.
        inp_k: ``top_k`` forwarded to ``predict_move`` for engine moves.

    Returns:
        ('board.svg', new FEN, ''). These are the rendered board image path,
        the updated position, and an empty string that clears the move box.

    Raises:
        ValueError: for an unknown notation, or when the move (human or
            engine) is not legal in the position. Move-parsing errors raised
            by python-chess propagate unchanged.
    """
    print(f'** log - play - ts {datetime.now().isoformat()}, fen: {inp_fen}, move: {inp_move}, notation: {inp_notation}, top_k: {inp_k}')
    board = chess.Board(inp_fen)
    move_text = (inp_move or '').strip()
    if move_text:
        if inp_notation == 'UCI':
            # UCI square names and promotion letters are lowercase; accept "E2E4".
            mv = chess.Move.from_uci(move_text.lower())
        elif inp_notation == 'SAN':
            # SAN is case-sensitive (e.g. 'B' bishop vs. the b-file), so keep case.
            mv = board.parse_san(move_text)
        else:
            # Without this branch `mv` would be unbound (UnboundLocalError).
            raise ValueError(f'Unknown move notation: {inp_notation!r}')
    else:
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
    if mv not in board.legal_moves:
        raise ValueError(f'Illegal Move: {str(mv)} @ {board.fen()}')
    board.push(mv)
    with open('board.svg', 'w') as f:
        f.write(str(chess.svg.board(board, lastmove=mv)))
    return 'board.svg', board.fen(), ''
# Gradio UI: input controls and help text on the left, rendered board on the
# right. Statement order below determines the on-screen layout.
with gr.Blocks() as block:
    # Header text describing the model behind the demo.
    gr.Markdown(
'''
# Play YoloChess - Policy Network v0.6
87M Parameter Transformer (DeBERTaV2-base architecture)
- pre-trained (MLM) from scratch on chess positions in FEN notation
- fine-tuned for text classification (moves) on expert games.
'''
    )
    # NOTE(review): `row` is bound but not referenced anywhere in the visible code.
    with gr.Row() as row:
        # Left column: move entry, position/sampling settings, buttons, help.
        with gr.Column():
            with gr.Row():
                # Text box for the human move; btn_play/btn_load return '' for it to clear it.
                move = gr.Textbox(label='human player move')
                # Selects how btn_play parses the human move.
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            # Current position; initialised to the standard start and updated
            # by both button handlers.
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            # Forwarded to btn_play as inp_k (engine sampling breadth);
            # precision=0 presumably makes Gradio deliver a whole number — confirm.
            top_k = gr.Number(value=3, label='sample from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            # Usage instructions shown under the buttons.
            gr.Markdown(
'''
- Click "Load" button to start and reset board.
- Click "Play" button to get Engine move.
- Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.
- Output "ERROR" generally occurs on illegal moves (Human or Engine).
- Enter "FEN" to start from a custom position.
'''
            )
        # Right column: the board image rendered from board.svg by the handlers.
        with gr.Column():
            position_output = gr.Image(label='board')
    # Each handler returns (image path, FEN, move-box text), matching these outputs.
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen, move])
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])
# Runs at module level (no __main__ guard), so executing this file starts the app.
block.launch()