File size: 3,908 Bytes
88cd6e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
import chess
import tensorflow as tf
from keras.models import load_model

class Model_handler:
    """Wrap a saved Keras chess model and turn its policy output into a move.

    A position ``state`` is a single string of FEN records joined by ``"_"``.
    The last record is the position to move from; the records before it
    (presumably earlier positions) are encoded as extra piece boards. An
    empty record is encoded as an all-zero board.
    """

    # Number of records stacked in front of the current position.
    _HISTORY_LENGTH = 4

    def __init__(self, model_path="model.h5"):
        """Load the Keras model stored at *model_path* (default ``"model.h5"``)."""
        self.model = load_model(model_path)
        print("Model loaded and is ready to use !!!")

    def predict(self, state, color):
        """Return the UCI string of the highest-scoring legal move.

        Args:
            state: ``"_"``-joined FEN records; the last one is the position
                to move from.
            color: side to move; ``color == 1`` selects the white-perspective
                piece mapping, anything else the black one.
        """
        outputs = self.model(self.convert_state_to_input(state, color))
        # The second model output, first (only) batch item, is read as the
        # policy and reshaped to 8x8x10 downstream.
        return self.convert_output_to_probs(state, color, outputs[1][0])

    def convert_state_to_input(self, state, color):
        """Encode a ``"_"``-joined FEN history as the model's two inputs.

        Returns:
            ``(board_input, move_input)``, both float16:
            board_input has shape ``(1, 8, 8, 60)`` -- five 12-channel piece
            boards, the oldest record first and the current position last;
            move_input has shape ``(1, 8, 8, 10)`` and holds the legal-move
            slots of the current position.

        Raises:
            TypeError: if *state* is not a string.
        """
        if not isinstance(state, str):
            # Fail here with a clear message rather than handing the model
            # an invalid input.
            raise TypeError(f"state must be a str, got {type(state).__name__}")

        records = state.split("_")
        current_planes, move_planes = self.convert_board_to_input(records[-1], color)

        # Use at most the last _HISTORY_LENGTH earlier records; pad missing
        # ones with empty records, which encode as all-zero boards.
        history = records[:-1][-self._HISTORY_LENGTH:]
        history = [""] * (self._HISTORY_LENGTH - len(history)) + history

        board_planes = [
            self.convert_board_to_input(fen, color, False) for fen in history
        ]
        board_planes.append(current_planes)
        board_input = np.concatenate(board_planes, axis=2)
        return (
            np.expand_dims(board_input, axis=0),
            np.expand_dims(move_planes, axis=0),
        )

    def convert_board_to_input(self, state, color, current=True):
        """Encode one FEN record as 8x8 float16 planes.

        python-chess square index ``s`` maps to array cell ``[s // 8, s % 8]``.

        Args:
            state: a FEN string. With ``current=False`` an empty string is
                allowed and yields an all-zero board.
            color: side to move; forwarded to ``get_piece_to_value`` and, for
                the current position, assigned to ``board.turn``.
            current: whether to also encode the legal moves.

        Returns:
            With ``current=False``: an ``(8, 8, 12)`` array holding 1 on each
            occupied square in that piece's channel.
            With ``current=True``: ``(pieces, moves)``. *pieces* is the array
            above with 0.5 also written on every legal-move target square in
            the moving piece's channel. *moves* is an ``(8, 8, 10)`` array
            with 1 in each legal move's policy slot (see ``_move_slot``).
        """
        if not current:
            if not state:
                return np.zeros((8, 8, 12), dtype=np.float16)
            return self._piece_planes(chess.Board(state), self.get_piece_to_value(color))

        board = chess.Board(state)
        board.turn = color
        piece_to_value = self.get_piece_to_value(color)
        slot_of_piece = self.get_piece_to_value(color, False)

        pieces = self._piece_planes(board, piece_to_value)
        moves = np.zeros((8, 8, 10), dtype=np.float16)
        for move in board.legal_moves:
            target = move.to_square
            symbol = board.piece_at(move.from_square).symbol()
            pieces[target // 8, target % 8, piece_to_value[symbol]] = 0.5
            moves[self._move_slot(board, move, slot_of_piece)] = 1
        return pieces, moves

    def convert_output_to_probs(self, state, color, policy_output):
        """Pick the legal move whose policy slot has the highest score.

        Despite the name, this returns a single move, not probabilities.

        Args:
            state: ``"_"``-joined FEN records; only the last one is used.
            color: side to move, assigned to ``board.turn``.
            policy_output: a tensor that can be reshaped to ``(8, 8, 10)``.

        Returns:
            The UCI string of the best move. Ties go to the move that
            python-chess yields first. The winning score is printed.

        Raises:
            IndexError: if the position has no legal moves.
        """
        policy = tf.reshape(policy_output, [8, 8, 10])
        board = chess.Board(state.split("_")[-1])
        board.turn = color
        slot_of_piece = self.get_piece_to_value(color, False)

        scores = {}
        for move in board.legal_moves:
            # _move_slot advances slot_of_piece, so moves must be visited in
            # the same order as in convert_board_to_input.
            scores[move.uci()] = policy[self._move_slot(board, move, slot_of_piece)]

        if not scores:
            # IndexError kept for backward compatibility with existing callers.
            raise IndexError("no legal moves in this position")
        best = max(scores, key=scores.get)
        print(scores[best])
        return best

    @staticmethod
    def _piece_planes(board, piece_to_value):
        """Return an (8, 8, 12) float16 array with 1 at each piece's square and channel."""
        planes = np.zeros((8, 8, 12), dtype=np.float16)
        for square, piece in board.piece_map().items():
            planes[square // 8, square % 8, piece_to_value[piece.symbol()]] = 1
        return planes

    @staticmethod
    def _move_slot(board, move, slot_of_piece):
        """Return the ``(row, col, channel)`` policy slot of a legal *move*.

        Policy channels: pawn 0, knight 1/2, bishop 3/4, rook 5/6, queen 7,
        king 8, promotion 9. *slot_of_piece* is the mutable mapping returned
        by ``get_piece_to_value(color, False)`` and is advanced in place.
        """
        if move.promotion is not None:
            # Row is the promotion piece type minus 2 (python-chess KNIGHT..QUEEN
            # are 2..5, giving rows 0..3). Only the from-file is kept, so
            # promotions from one pawn to different squares share a slot.
            return move.promotion - 2, move.from_square % 8, 9

        symbol = board.piece_at(move.from_square).symbol()
        channel = slot_of_piece[symbol]
        # NOTE(review): the counter advances after the first *move* of a
        # knight/bishop/rook, not once per piece. So only that first move
        # lands on the "first piece" channel (1/3/5), and every later move of
        # that piece type lands on the "second piece" channel (2/4/6). Kept
        # as-is because the trained model presumably expects exactly this
        # encoding; confirm against the training pipeline before changing.
        if channel in (1, 3, 5):
            slot_of_piece[symbol] += 1
        target = move.to_square
        return target // 8, target % 8, channel

    def get_piece_to_value(self, color, inp=True):
        """Return a freshly built ``{piece symbol: channel}`` mapping.

        ``color == 1`` means the side to move is white (uppercase symbols);
        any other value means black. "Own" pieces belong to the side to move.

        Args:
            color: side to move, compared with ``1``.
            inp: when true, the 12-channel board-input mapping: own
                P, N, B, R, Q, K -> 0..5, the opponent's -> 6..11. When false,
                the policy mapping for own pieces only: P 0, N 1, B 3, R 5,
                Q 7, K 8. Channels 2, 4 and 6 are reached when ``_move_slot``
                advances the counter.

        A new dict is built on every call because ``_move_slot`` mutates it.
        """
        own = "PNBRQK" if color == 1 else "pnbrqk"
        if inp:
            return {symbol: i for i, symbol in enumerate(own + own.swapcase())}
        return dict(zip(own, (0, 1, 3, 5, 7, 8)))