switch model deberta
Browse files
app.py
CHANGED
|
@@ -5,12 +5,12 @@ from datetime import datetime
|
|
| 5 |
import gradio as gr
|
| 6 |
import chess
|
| 7 |
import chess.svg
|
| 8 |
-
from transformers import
|
| 9 |
|
| 10 |
token = os.environ['auth_token']
|
| 11 |
|
| 12 |
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
|
| 13 |
-
model =
|
| 14 |
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
| 15 |
|
| 16 |
empty_field = '0'
|
|
@@ -58,8 +58,8 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
|
|
| 58 |
board = chess.Board(inp_fen)
|
| 59 |
|
| 60 |
if inp_move:
|
| 61 |
-
if inp_notation == 'UCI': mv = chess.Move.from_uci(inp_move)
|
| 62 |
-
elif inp_notation == 'SAN': mv = board.parse_san(inp_move)
|
| 63 |
else:
|
| 64 |
mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
|
| 65 |
|
|
@@ -76,8 +76,10 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
|
|
| 76 |
with gr.Blocks() as block:
|
| 77 |
gr.Markdown(
|
| 78 |
'''
|
| 79 |
-
# Play YoloChess - Policy Network v0.
|
| 80 |
-
110M Parameter Transformer (
|
|
|
|
|
|
|
| 81 |
'''
|
| 82 |
)
|
| 83 |
with gr.Row() as row:
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import chess
|
| 7 |
import chess.svg
|
| 8 |
+
from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipeline
|
| 9 |
|
| 10 |
token = os.environ['auth_token']
|
| 11 |
|
| 12 |
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
|
| 13 |
+
model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/chessv4', use_auth_token=token)
|
| 14 |
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
| 15 |
|
| 16 |
empty_field = '0'
|
|
|
|
| 58 |
board = chess.Board(inp_fen)
|
| 59 |
|
| 60 |
if inp_move:
|
| 61 |
+
if inp_notation == 'UCI': mv = chess.Move.from_uci(inp_move)
|
| 62 |
+
elif inp_notation == 'SAN': mv = board.parse_san(inp_move)
|
| 63 |
else:
|
| 64 |
mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
|
| 65 |
|
|
|
|
| 76 |
with gr.Blocks() as block:
|
| 77 |
gr.Markdown(
|
| 78 |
'''
|
| 79 |
+
# Play YoloChess - Policy Network v0.4
|
| 80 |
+
110M Parameter Transformer (DeBERTaV2-base architecture)
|
| 81 |
+
- pre-trained (MLM) from scratch on FENs
|
| 82 |
+
- fine-tuned for text classification expert games in modified FEN notation.
|
| 83 |
'''
|
| 84 |
)
|
| 85 |
with gr.Row() as row:
|