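# Tic-Tac-Toe-MCTS / TicTacToeGame.py
# NOTE(review): the three lines below look like hosting-page residue
# (uploader caption, commit label, short hash) rather than code; kept as comments.
# Yukino666's picture
# e4
# 97b3fb2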
import math
import random
import sys
class TicTacToeState:
    """Snapshot of a tic-tac-toe position.

    The board is a flat list of 9 cells laid out row by row (indices 0-2 are
    the top row, 3-5 the middle row, 6-8 the bottom row). Each cell holds
    0 (empty), 1 (X) or -1 (O). ``player`` is the side to move next:
    1 for X, -1 for O.
    """

    # Every triple of cell indices that forms a winning line. Kept in one
    # place so is_terminal() and get_winner() can never disagree.
    WIN_PATTERNS = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )

    def __init__(self, board=None, player=1):
        """Build a state from an optional board and the side to move.

        board: sequence of 9 cell values (0 empty, 1 X, -1 O); ``None`` means
            an empty board. The sequence is copied into a fresh list, so the
            caller's object is never aliased and tuples are accepted too.
        player: side to move, 1 (X) or -1 (O).
        """
        if board is None:
            self.board = [0] * 9
        else:
            # list(...) rather than board[:] so the stored board is always a
            # mutable list regardless of the sequence type passed in.
            self.board = list(board)
        self.player = player  # 1 or -1

    def get_legal_moves(self):
        """Return the indices of all empty cells, in ascending order."""
        return [i for i, val in enumerate(self.board) if val == 0]

    def is_terminal(self):
        """Return True when the game is over (a line is won or the board is full)."""
        return self.get_winner() is not None

    def get_winner(self):
        """Return the outcome of the position.

        Returns 1 if X has completed a line, -1 if O has, 0 if the board is
        full with no winner (draw), or None while the game is still going.
        """
        for a, b, c in self.WIN_PATTERNS:
            # A line sums to 3 only when all three cells are X, and to -3
            # only when all three are O.
            line_sum = self.board[a] + self.board[b] + self.board[c]
            if line_sum == 3:
                return 1  # X wins
            if line_sum == -3:
                return -1  # O wins
        if all(val != 0 for val in self.board):
            return 0  # draw
        return None  # game not finished yet

    def move(self, idx):
        """Return the new state after the current player marks cell ``idx``.

        This state is left unchanged; the returned state has the other player
        to move. The legality of ``idx`` is not checked here, so callers
        should pick it from get_legal_moves().
        """
        new_board = self.board[:]
        new_board[idx] = self.player
        return TicTacToeState(new_board, -self.player)  # players alternate
class MCTSNode:
    """A search-tree node: one game state plus its visit/win statistics."""

    def __init__(self, state, parent=None):
        """Create a node for ``state`` with zeroed statistics.

        state: game position; it must provide get_legal_moves().
        parent: the node this one hangs under, or None for a root.
        """
        self.state, self.parent = state, parent
        self.children = []
        self.wins = 0      # accumulated reward credited to this node
        self.visits = 0    # number of times this node has been counted
        # Moves that have not been turned into child nodes yet.
        self.untried_moves = state.get_legal_moves()

    def is_fully_expanded(self):
        """Return True once no untried move is left on this node."""
        return not self.untried_moves

    def best_child(self, c=1.4):
        """Pick the child with the highest UCB1 score.

        score = wins / visits + c * sqrt(ln(parent_visits + 1) / visits)

        A tiny constant is added to each child's visit count so a child with
        zero visits does not divide by zero. ``c`` weights the exploration
        term. Ties go to the earliest child; with no children the result is
        None.
        """
        if not self.children:
            return None

        log_parent = math.log(self.visits + 1)
        champion = None
        champion_score = -math.inf
        for candidate in self.children:
            exploitation = candidate.wins / (candidate.visits + 1e-7)
            exploration = c * math.sqrt(log_parent / (candidate.visits + 1e-7))
            score = exploitation + exploration
            if score > champion_score:
                champion, champion_score = candidate, score
        return champion
def tree_policy(node):
    """Selection + expansion: descend from ``node`` to the node to simulate.

    Walks down through best_child() while nodes are fully expanded. Returns
    a freshly expanded child as soon as a node with untried moves is met, or
    the terminal node reached otherwise.
    """
    current = node
    while True:
        if current.state.is_terminal():
            return current
        if not current.is_fully_expanded():
            return expand(current)
        current = current.best_child()
def expand(node):
    """Turn one untried move of ``node`` into a new child node.

    The last entry of node.untried_moves is removed, applied to the node's
    state, and the resulting state is wrapped in an MCTSNode that is appended
    to node.children and returned.
    """
    chosen = node.untried_moves.pop()
    offspring = MCTSNode(node.state.move(chosen), parent=node)
    node.children.append(offspring)
    return offspring
def default_policy(state):
    """Simulation: play uniformly random moves from ``state`` until the game ends.

    Returns whatever get_winner() reports for the final state. The state
    passed in is not modified; each step works on the state returned by
    move().
    """
    rollout = state
    while True:
        if rollout.is_terminal():
            return rollout.get_winner()
        options = rollout.get_legal_moves()
        rollout = rollout.move(random.choice(options))
def backup(node, winner):
    """Backpropagation: push a simulation result from ``node`` up to the root.

    Every node on the path gets one more visit. A draw (winner == 0) adds
    0.5 to its wins; a win adds 1 to the nodes whose last mover is the
    winner.
    """
    cursor = node
    while cursor is not None:
        cursor.visits += 1
        # state.player is the side to move next, so the side that made the
        # move leading into this node is its negation.
        last_mover = -cursor.state.player
        if winner == 0:
            cursor.wins += 0.5
        if winner == last_mover:
            cursor.wins += 1
        cursor = cursor.parent
def mcts(root_state, iter_count=2000):
    """Run Monte Carlo Tree Search from ``root_state`` and return the best child.

    root_state: the position to search from; it must not be terminal.
    iter_count: number of select/expand -> simulate -> backpropagate rounds;
        must be at least 1.

    Returns the child node of the root with the most visits; its ``state``
    is the position after the recommended move.

    Raises ValueError when no child can exist: the root state is already
    terminal, or iter_count is not positive. Without these checks the final
    max() would fail on an empty child list with an unhelpful message (and,
    for a terminal root, only after running every iteration for nothing).
    """
    if iter_count <= 0:
        raise ValueError(
            "iter_count must be at least 1, got {!r}".format(iter_count)
        )
    if root_state.is_terminal():
        raise ValueError("cannot search from a terminal state: no move to choose")

    root_node = MCTSNode(root_state, parent=None)
    for _ in range(iter_count):
        # 1. Selection & expansion
        node = tree_policy(root_node)
        # 2. Simulation
        winner = default_policy(node.state)
        # 3. Backpropagation
        backup(node, winner)
    # Pick the root child with the most visits (win rate would be the
    # alternative criterion).
    best_child_node = max(root_node.children, key=lambda c: c.visits)
    return best_child_node