MCTS
AlphaZero.MCTS
— Module
A generic, standalone implementation of Monte Carlo Tree Search. It can be used on any game that implements GameInterface and with any external oracle.
Both a synchronous and an asynchronous version are implemented, which share most of their code. When browsing the sources for the first time, we recommend that you study the synchronous version first.
Oracles
AlphaZero.MCTS.Oracle
— Type
MCTS.Oracle{Game}
Abstract base type for an oracle. Oracles must implement MCTS.evaluate and MCTS.evaluate_batch.
AlphaZero.MCTS.evaluate
— Function
MCTS.evaluate(oracle::Oracle, board)
Evaluate a single board position (assuming white is playing).
Return a pair (P, V) where:
- P is a probability vector on GI.available_actions(Game(board))
- V is a scalar estimating the value or win probability for white.
AlphaZero.MCTS.evaluate_batch
— Function
MCTS.evaluate_batch(oracle::Oracle, boards)
Evaluate a batch of board positions.
Expect a vector of boards and return a vector of (P, V) pairs.
A default implementation is provided that calls MCTS.evaluate sequentially on each position.
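As a rough illustration of this interface, the sketch below re-creates it in a standalone script instead of loading AlphaZero. The names `SketchOracle` and `ConstantOracle`, the string boards and the three-action prior are all hypothetical; only the `(P, V)` return shape and the sequential batch fallback follow the description above.

```julia
# Standalone mock of the oracle interface described above (not the library's source).
abstract type SketchOracle end

# Hypothetical oracle: uniform prior over `nactions` actions and a fixed value.
struct ConstantOracle <: SketchOracle
    nactions::Int
    value::Float64
end

# Single-position evaluation: returns a pair (P, V).
function evaluate(oracle::ConstantOracle, board)
    P = fill(1 / oracle.nactions, oracle.nactions)
    return (P, oracle.value)
end

# Fallback batch evaluation: call `evaluate` sequentially on each board.
evaluate_batch(oracle::SketchOracle, boards) = [evaluate(oracle, b) for b in boards]

oracle = ConstantOracle(3, 0.0)
results = evaluate_batch(oracle, ["board-a", "board-b"])
```

An oracle backed by a neural network would presumably override the batch method to evaluate all boards in one forward pass, which is where batching pays off.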
AlphaZero.MCTS.RolloutOracle
— Type
MCTS.RolloutOracle{Game} <: MCTS.Oracle{Game}
This oracle estimates the value of a position by simulating a random game from it (a rollout). Moreover, it puts a uniform prior on available actions. Therefore, it can be used to implement the "vanilla" MCTS algorithm.
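To make the rollout idea concrete, here is a minimal sketch on a hypothetical toy game (a single pile of stones), written without AlphaZero. The functions `available_actions`, `rollout` and `rollout_evaluate` are illustrative names defined here, not the library's implementation.

```julia
using Random

# Hypothetical toy game: players alternately remove 1 or 2 stones from a pile;
# whoever takes the last stone wins. The board is the number of stones left
# (at least 1), and white is assumed to be the player to move.
available_actions(stones::Int) = stones >= 2 ? [1, 2] : [1]

# Play uniformly random moves until the pile is empty.
# Returns +1.0 if white takes the last stone and -1.0 otherwise.
function rollout(rng::AbstractRNG, stones::Int)
    white_to_move = true
    while true
        stones -= rand(rng, available_actions(stones))
        if stones == 0
            return white_to_move ? 1.0 : -1.0
        end
        white_to_move = !white_to_move
    end
end

# Uniform prior over legal actions plus a single-rollout value estimate.
function rollout_evaluate(rng::AbstractRNG, stones::Int)
    actions = available_actions(stones)
    P = fill(1 / length(actions), length(actions))
    V = rollout(rng, stones)
    return (P, V)
end

rng = MersenneTwister(0)
P, V = rollout_evaluate(rng, 5)
```

A single rollout is a noisy estimate; the averaging over many simulations is left to the tree search itself.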
Environment
AlphaZero.MCTS.Env
— Type
MCTS.Env{Game}(oracle; <keyword args>) where Game
Create and initialize an MCTS environment with a given oracle.
Keyword Arguments
- nworkers=1: number of asynchronous workers (see below)
- fill_batches=false: if true, a constant batch size is enforced for evaluation requests, by completing batches with dummy entries if necessary
- cpuct=1.: exploration constant in the UCT formula
- noise_ϵ=0., noise_α=1.: parameters for the Dirichlet exploration noise (see below)
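The UCT formula itself is not spelled out here. The sketch below assumes the PUCT-style score commonly used in AlphaZero-like systems, Q + cpuct · P · √(ΣN) / (N + 1), to show where `cpuct` enters. The function `uct_scores` and its argument layout are hypothetical and may differ from what the module actually computes.

```julia
# PUCT-style score for each action at a node (assumed form, see the text above).
# W: total value accumulated per action, N: visit counts, P: prior probabilities.
function uct_scores(W::Vector{Float64}, N::Vector{Int}, P::Vector{Float64}, cpuct::Float64)
    sqrt_total = sqrt(sum(N))
    return map(eachindex(N)) do i
        Q = W[i] / max(N[i], 1)                     # mean value, 0 for unvisited actions
        Q + cpuct * P[i] * sqrt_total / (N[i] + 1)  # exploitation term + exploration bonus
    end
end

# Two actions with equal priors: the first visited 4 times, the second never.
scores = uct_scores([2.0, 0.0], [4, 0], [0.5, 0.5], 1.0)
```

Under this assumed formula, a larger `cpuct` increases the bonus given to rarely visited actions relative to their observed mean value.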
Asynchronous MCTS
If nworkers == 1, MCTS is run in a synchronous fashion and the oracle is invoked through MCTS.evaluate.
If nworkers > 1, nworkers asynchronous workers are spawned, along with an additional task to serve board evaluation requests. Such requests are processed in batches of size nworkers using MCTS.evaluate_batch.
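The effect of `fill_batches=true` can be pictured with the following sketch, which pads a short list of pending requests up to a fixed batch size. The helper `pad_batch` and the string boards are hypothetical; how the module builds and discards dummy entries internally is not documented here.

```julia
# Pad a vector of pending evaluation requests up to `batch_size` with a dummy board.
# Returns the padded batch and the number of real entries, so that results for
# the dummy entries can be discarded afterwards.
function pad_batch(requests::Vector{T}, batch_size::Int, dummy::T) where T
    nreal = length(requests)
    @assert nreal <= batch_size
    padded = vcat(requests, fill(dummy, batch_size - nreal))
    return padded, nreal
end

# Three real requests padded to a batch of four.
padded, nreal = pad_batch(["b1", "b2", "b3"], 4, "dummy")
```

A constant batch size can be convenient when the oracle runs on hardware or compiled code that is tuned for one input shape.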
Dirichlet Noise
A naive way to ensure exploration during training is to adopt an ϵ-greedy policy: at every turn, with probability ϵ, a random move is played instead of one drawn from the policy prescribed by MCTS.policy. The problem with this naive strategy is that it may lead the player to make terrible moves at critical moments, thereby biasing the policy evaluation mechanism.
A superior alternative is to add a random bias to the neural prior for the root node during MCTS exploration: instead of considering the policy $p$ output by the neural network in the UCT formula, one uses $(1-ϵ)p + ϵη$ where $η$ is drawn once per call to MCTS.explore! from a Dirichlet distribution of parameter $α$.
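The mixing step $(1-ϵ)p + ϵη$ can be sketched as follows. To stay within Julia's standard library, the example only samples the flat case $α = 1$ (the default value of `noise_α`), by normalizing exponential variates; a general $α$ would need a gamma sampler. The helpers `sample_flat_dirichlet` and `mix_noise` are hypothetical names, not part of the module.

```julia
using Random

# Draw η from a flat Dirichlet distribution (α = 1) over n actions by
# normalizing independent Exp(1) samples.
function sample_flat_dirichlet(rng::AbstractRNG, n::Int)
    x = randexp(rng, n)
    return x ./ sum(x)
end

# Mix the prior p with the noise η: (1 - ϵ) p + ϵ η.
mix_noise(p, η, ϵ) = (1 - ϵ) .* p .+ ϵ .* η

rng = MersenneTwister(42)
p = [0.7, 0.2, 0.1]
η = sample_flat_dirichlet(rng, length(p))
noisy = mix_noise(p, η, 0.25)
```

Because both `p` and `η` sum to one, the mixture is again a probability vector, and setting ϵ to zero recovers the original prior.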
AlphaZero.MCTS.explore!
— Function
MCTS.explore!(env, state, nsims)
Run nsims MCTS simulations from state.
AlphaZero.MCTS.policy
— Function
MCTS.policy(env, state; τ=1.)
Return the recommended stochastic policy on state, with temperature parameter equal to τ. If τ is zero, all the weight goes to the action with the highest visit count.
A call to this function must always be preceded by a call to MCTS.explore!.
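The role of the temperature can be illustrated with the sketch below. It assumes the usual AlphaZero convention that the policy is proportional to the visit counts raised to the power 1/τ, which this page does not state explicitly. The function `visits_to_policy` is a hypothetical stand-in working on a plain vector of counts.

```julia
# Turn visit counts into a policy with temperature τ
# (assumed rule: policy proportional to N^(1/τ)).
function visits_to_policy(N::Vector{Int}, τ::Float64)
    if iszero(τ)
        # All the weight goes to the most visited action.
        policy = zeros(length(N))
        policy[argmax(N)] = 1.0
        return policy
    end
    w = Float64.(N) .^ (1 / τ)
    return w ./ sum(w)
end

counts = [10, 30, 60]
proportional = visits_to_policy(counts, 1.0)  # proportional to the visit counts
greedy = visits_to_policy(counts, 0.0)        # one-hot on the most visited action
```

Under this assumption, values of τ between 0 and 1 sharpen the distribution toward the most visited action, while values above 1 flatten it.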
AlphaZero.MCTS.reset!
— Function
MCTS.reset!(env)
Empty the MCTS tree.
Profiling Utilities
AlphaZero.MCTS.inference_time_ratio
— Function
MCTS.inference_time_ratio(env)
Return the ratio of time spent by MCTS.explore! on position evaluation (through functions MCTS.evaluate or MCTS.evaluate_batch) since the environment's creation.
AlphaZero.MCTS.memory_footprint_per_node
— Function
MCTS.memory_footprint_per_node(env)
Return an estimate of the memory footprint of a single node of the MCTS tree (in bytes).
AlphaZero.MCTS.approximate_memory_footprint
— Function
MCTS.approximate_memory_footprint(env)
Return an estimate of the memory footprint of the MCTS tree (in bytes).
AlphaZero.MCTS.average_exploration_depth
— Function
MCTS.average_exploration_depth(env)
Return the average number of nodes that are traversed during an MCTS simulation, not counting the root.