Spaces:
Sleeping
Sleeping
import math | |
import random | |
import sys | |
class TicTacToeState:
    """A tic-tac-toe position: a 3x3 board plus the side to move.

    The board is a flat list of 9 cells in row-major order (index 0 is the
    top-left cell, index 8 the bottom-right). ``move()`` returns a new state
    instead of mutating this one.
    """

    # Every line of three cells (by flat board index) that wins the game.
    WIN_PATTERNS = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )

    def __init__(self, board=None, player=1):
        """Create a state.

        Args:
            board: sequence of 9 cells; 0 = empty, 1 = X, -1 = O. It is
                copied into a new list, so later changes to the caller's
                sequence do not affect this state. ``None`` means an empty
                board.
            player: side to move next; 1 = X, -1 = O.
        """
        # list() (rather than board[:]) also accepts tuples and always
        # yields a mutable copy that move() can write into.
        self.board = [0] * 9 if board is None else list(board)
        self.player = player  # 1 or -1

    def get_legal_moves(self):
        """Return the indices of all empty cells, in ascending order."""
        return [i for i, val in enumerate(self.board) if val == 0]

    def is_terminal(self):
        """Return True if someone has won or the board is full (a tie)."""
        return self.get_winner() is not None

    def get_winner(self):
        """Return the game result.

        Returns:
            1 if X has three in a row, -1 if O has three in a row, 0 for a
            tie (board full, no winner), or None if the game is still in
            progress.
        """
        for a, b, c in self.WIN_PATTERNS:
            line_sum = self.board[a] + self.board[b] + self.board[c]
            # A sum of +3 / -3 means all three cells belong to X / O.
            if line_sum == 3:
                return 1   # X wins
            if line_sum == -3:
                return -1  # O wins
        if all(val != 0 for val in self.board):
            return 0  # tie
        return None  # game not finished

    def move(self, idx):
        """Place the current player's mark at ``idx`` and return the new state.

        The returned state has the other player to move; ``self`` is left
        unchanged.

        Raises:
            ValueError: if ``idx`` is off the board or the cell is occupied.
        """
        # Reject illegal moves instead of silently overwriting a mark (or
        # wrapping around with a negative index).
        if not 0 <= idx < len(self.board) or self.board[idx] != 0:
            raise ValueError(f"illegal move: {idx!r}")
        new_board = list(self.board)
        new_board[idx] = self.player
        return TicTacToeState(new_board, -self.player)  # players alternate
class MCTSNode:
    """One node of the Monte Carlo search tree, wrapping a single game state."""

    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        # Moves from this state that have not yet been expanded into children.
        self.untried_moves = state.get_legal_moves()
        # Statistics accumulated during backpropagation.
        self.visits = 0
        self.wins = 0

    def is_fully_expanded(self):
        """Return True once every legal move has been turned into a child."""
        return not self.untried_moves

    def best_child(self, c=1.4):
        """Return the child with the highest UCB1 score, or None if childless.

        UCB1 = Q + c * sqrt(ln(N + 1) / n), where Q is the child's average
        reward, N the parent's visit count and n the child's visit count.
        A tiny epsilon keeps the division defined for unvisited children.
        Ties go to the earliest child in ``self.children``.
        """
        if not self.children:
            return None

        def ucb1(candidate):
            n = candidate.visits + 1e-7
            exploit = candidate.wins / n
            explore = c * math.sqrt(math.log(self.visits + 1) / n)
            return exploit + explore

        return max(self.children, key=ucb1)
def tree_policy(node):
    """Selection + expansion phase of MCTS.

    Walk down from ``node`` by UCB1 while nodes are fully expanded. Stop at
    the first node that still has untried moves and expand it (returning
    the new child), or stop at a terminal node and return it.
    """
    current = node
    while not current.state.is_terminal():
        if current.is_fully_expanded():
            current = current.best_child()
            continue
        return expand(current)
    return current
def expand(node):
    """Turn one untried move of ``node`` into a new child node and return it.

    The move taken is the last element of ``node.untried_moves``, which is
    removed from that list.
    """
    chosen_move = node.untried_moves.pop()
    child = MCTSNode(node.state.move(chosen_move), parent=node)
    node.children.append(child)
    return child
def default_policy(state):
    """Simulation phase of MCTS: a uniformly random playout.

    Starting from ``state``, play random legal moves until the game ends,
    then return that final state's ``get_winner()`` result.
    """
    position = state
    while True:
        if position.is_terminal():
            return position.get_winner()
        position = position.move(random.choice(position.get_legal_moves()))
def backup(node, winner):
    """Backpropagation phase of MCTS.

    Walk from ``node`` up to the root, incrementing every node's visit count
    and crediting its ``wins``: +0.5 for a tie (winner == 0), and +1 when the
    winner is the player who made the move into that node.
    """
    cursor = node
    while cursor is not None:
        cursor.visits += 1
        # state.player is the side to move next, so the side that just moved
        # into this node is its negation.
        last_mover = -cursor.state.player
        if winner == 0:
            cursor.wins += 0.5
        if winner == last_mover:
            cursor.wins += 1
        cursor = cursor.parent
def mcts(root_state, iter_count=2000):
    """Run Monte Carlo Tree Search from ``root_state``.

    Args:
        root_state: the position to search from; it must not be terminal.
        iter_count: number of selection/expansion, simulation and
            backpropagation rounds to run; must be at least 1.

    Returns:
        The root's child node with the highest visit count. Its ``state`` is
        the position reached by the recommended move.

    Raises:
        ValueError: if ``iter_count`` < 1 or ``root_state`` is already
            terminal; in either case the root would have no children to
            choose from.
    """
    # Without these checks the final max() fails with an opaque
    # "max() arg is an empty sequence".
    if iter_count < 1:
        raise ValueError(f"iter_count must be at least 1, got {iter_count!r}")
    if root_state.is_terminal():
        raise ValueError("cannot search from a terminal state: no moves are available")

    root_node = MCTSNode(root_state, parent=None)
    for _ in range(iter_count):
        # 1. Selection & expansion
        node = tree_policy(root_node)
        # 2. Simulation
        winner = default_policy(node.state)
        # 3. Backpropagation
        backup(node, winner)

    # Pick the most-visited child (the "robust child" criterion) rather than
    # the one with the highest win rate.
    return max(root_node.children, key=lambda child: child.visits)