Network Interface
AlphaZero.Network — Module
A generic, framework-agnostic interface for neural networks.
Mandatory Interface
AlphaZero.Network.AbstractNetwork — Type
AbstractNetwork
Abstract base type for a neural network.
Constructor
Any subtype Network must implement Base.copy along with the following constructor:
Network(game_spec, hyperparams)
where the expected type of hyperparams is given by HyperParams(Network).
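As an illustration, here is a minimal sketch of a conforming subtype. The names SimpleNet and SimpleNetHP are hypothetical, and AbstractNetwork and HyperParams are stubbed locally rather than imported from the library:

```julia
# Hypothetical sketch of a conforming network type. `SimpleNet` and
# `SimpleNetHP` are illustrative names; `AbstractNetwork` and `HyperParams`
# are local stand-ins for the AlphaZero.Network definitions.
abstract type AbstractNetwork end

struct SimpleNetHP
  width :: Int
end

struct SimpleNet <: AbstractNetwork
  gspec
  hyper :: SimpleNetHP
  W :: Matrix{Float32}
end

# The mandatory constructor: Network(game_spec, hyperparams)
SimpleNet(gspec, hp::SimpleNetHP) =
  SimpleNet(gspec, hp, zeros(Float32, hp.width, hp.width))

# The mandatory Base.copy
Base.copy(nn::SimpleNet) = SimpleNet(nn.gspec, nn.hyper, copy(nn.W))

# HyperParams(SimpleNet) must return the hyperparameter type
HyperParams(::Type{SimpleNet}) = SimpleNetHP
```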
AlphaZero.Network.HyperParams — Function
HyperParams(::Type{<:AbstractNetwork})
Return the hyperparameter type associated with a given network type.
AlphaZero.Network.hyperparams — Function
hyperparams(::AbstractNetwork)
Return the hyperparameters of a network.
AlphaZero.Network.game_spec — Function
game_spec(::AbstractNetwork)
Return the game specification that was passed to the network's constructor.
AlphaZero.Network.forward — Function
forward(::AbstractNetwork, states)
Compute the forward pass of a network on a batch of inputs.
Expect a Float32 tensor states whose batch dimension is the last one.
Return a (P, V) pair where:
- P is a matrix of size (num_actions, batch_size). It is allowed to put weight on invalid actions (see evaluate).
- V is a row vector of size (1, batch_size).
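To make the shape contract concrete, here is a stub forward pass over plain arrays (forward_stub and its dummy outputs are illustrative, not library code):

```julia
# Illustrates the shape contract of `forward` with a stub: the input is a
# Float32 tensor whose batch dimension is last; the output is a (P, V) pair
# of sizes (num_actions, batch_size) and (1, batch_size).
const NUM_ACTIONS = 5

function forward_stub(states::Array{Float32})
  batch_size = size(states)[end]
  P = fill(1f0 / NUM_ACTIONS, NUM_ACTIONS, batch_size)  # uniform dummy policy
  V = zeros(Float32, 1, batch_size)                     # dummy value estimates
  return P, V
end

states = rand(Float32, 3, 8)   # a batch of 8 vectorized states of dimension 3
P, V = forward_stub(states)
```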
AlphaZero.Network.train! — Function
train!(callback, ::AbstractNetwork, opt::OptimiserSpec, loss, data, n)
Update a given network to fit some data.
- opt specifies which optimiser to use.
- loss is a function that maps a batch of samples to a tracked real.
- data is an iterator over minibatches.
- n is the number of minibatches. If length is defined on data, we must have length(data) == n. However, not all finite iterators implement length and thus this argument is needed.
- callback(i, loss) is called at each step with the batch number i and the loss on the last batch.
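The contract above can be sketched as follows (train_stub! is a hypothetical stand-in that omits the actual gradient step):

```julia
# Hypothetical sketch of the `train!` contract: iterate over minibatches,
# compute the loss, and report progress through the callback.
function train_stub!(callback, loss, data, n)
  i = 0
  for batch in data
    i += 1
    l = loss(batch)
    # ... the optimiser step using `opt` would happen here ...
    callback(i, l)
  end
  # `n` must agree with the actual number of minibatches.
  @assert i == n
end

losses = Float64[]
data = (fill(Float64(k), 4) for k in 1:3)   # 3 minibatches of 4 samples each
train_stub!((i, l) -> push!(losses, l), batch -> sum(batch), data, 3)
```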
AlphaZero.Network.set_test_mode! — Function
set_test_mode!(mode=true)
Put a network in test mode or in training mode. This is relevant for networks featuring layers such as batch normalization.
AlphaZero.Network.params — Function
params(::AbstractNetwork)
Return the collection of trainable parameters of a network.
AlphaZero.Network.regularized_params — Function
regularized_params(::AbstractNetwork)
Return the collection of regularized parameters of a network. This usually excludes neurons' biases.
Conversion and Copy
AlphaZero.Network.to_gpu — Function
to_gpu(::AbstractNetwork)
Return a copy of the given network that has been transferred to the GPU if one is available. Otherwise, return the given network untouched.
AlphaZero.Network.to_cpu — Function
to_cpu(::AbstractNetwork)
Return a copy of the given network that has been transferred to the CPU or return the given network untouched if it is already on CPU.
AlphaZero.Network.on_gpu — Function
on_gpu(::AbstractNetwork) :: Bool
Test whether or not a network is located on GPU.
AlphaZero.Network.convert_input — Function
convert_input(::AbstractNetwork, input)
Convert an array (or number) to the right format so that it can be used as an input by a given network.
AlphaZero.Network.convert_output — Function
convert_output(::AbstractNetwork, output)
Convert an array (or number) produced by a neural network to a standard CPU array (or number) type.
Misc
AlphaZero.Network.gc — Function
gc(::AbstractNetwork)
Perform full garbage collection and empty the GPU memory pool.
Derived Functions
Evaluation Functions
AlphaZero.Network.forward_normalized — Function
forward_normalized(network::AbstractNetwork, states, actions_mask)
Evaluate a batch of vectorized states. This function is a wrapper on forward that puts a zero weight on invalid actions.
Arguments
- states is a tensor whose last dimension has size batch_size
- actions_mask is a binary matrix of size (num_actions, batch_size)
Returned value
Return a (P, V, Pinv) triple where:
- P is a matrix of size (num_actions, batch_size).
- V is a row vector of size (1, batch_size).
- Pinv is a row vector of size (1, batch_size) that indicates the total probability weight put by the network on invalid actions for each sample.
All tensors manipulated by this function have elements of type Float32.
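The masking step can be sketched as follows (normalize_policy is an illustrative re-implementation of the normalization described above, not the library code; it assumes each raw policy column sums to one and that every sample has at least one valid action):

```julia
# Sketch of the normalization behind `forward_normalized`: zero out invalid
# actions, renormalize the policy, and report the probability mass that the
# raw network put on invalid actions.
function normalize_policy(P::Matrix{Float32}, actions_mask::Matrix{Float32})
  masked = P .* actions_mask        # zero weight on invalid actions
  total = sum(masked, dims=1)       # mass left on valid actions, per sample
  Pnorm = masked ./ total           # renormalized policy
  Pinv = 1f0 .- total               # mass the raw policy put on invalid actions
  return Pnorm, Pinv
end
```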
AlphaZero.Network.evaluate — Function
evaluate(::AbstractNetwork, state)
(nn::AbstractNetwork)(state) = evaluate(nn, state)
Evaluate the neural network as an MCTS oracle on a single state.
Note, however, that evaluating state positions one at a time is slow and so you may want to use a BatchedOracle along with an inference server that uses evaluate_batch.
AlphaZero.Network.evaluate_batch — Function
evaluate_batch(::AbstractNetwork, batch)
Evaluate the neural network as an MCTS oracle on a batch of states at once.
Take a list of states as input and return a list of (P, V) pairs as defined in the MCTS oracle interface.
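The input/output shape of a batched oracle call can be sketched as follows (evaluate_batch_stub returns dummy uniform policies and zero values; it is not the library implementation):

```julia
# Sketch of the `evaluate_batch` contract: a list of states in, one (P, V)
# pair out per state (dummy outputs; not library code).
function evaluate_batch_stub(batch; num_actions = 3)
  return [(fill(1f0 / num_actions, num_actions), 0f0) for _ in batch]
end

results = evaluate_batch_stub([rand(Float32, 4) for _ in 1:6])
```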
Misc
AlphaZero.Network.num_parameters — Function
num_parameters(::AbstractNetwork)
Return the total number of parameters of a network.
AlphaZero.Network.num_regularized_parameters — Function
num_regularized_parameters(::AbstractNetwork)
Return the total number of regularized parameters of a network.
AlphaZero.Network.mean_weight — Function
mean_weight(::AbstractNetwork)
Return the mean absolute value of the regularized parameters of a network.
AlphaZero.Network.copy — Method
copy(::AbstractNetwork; on_gpu, test_mode)
A copy function that also handles CPU/GPU transfers and test/train mode switches.
Optimiser Specification
AlphaZero.Network.OptimiserSpec — Type
OptimiserSpec
Abstract type for an optimiser specification.
AlphaZero.Network.CyclicNesterov — Type
CyclicNesterov(; lr_base, lr_high, lr_low, momentum_low, momentum_high)
SGD optimiser with a cyclic learning rate and cyclic Nesterov momentum.
- During an epoch, the learning rate goes from lr_low to lr_high and then back to lr_low.
- The momentum evolves in the opposite way, from high values to low values and then back to high values.
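The cyclic schedule can be sketched as a triangular wave (cyclic_schedule is illustrative, not the library's implementation, and the lr_base parameter is omitted for simplicity):

```julia
# Illustrative sketch of the cyclic schedule: within one epoch, the learning
# rate ramps lr_low -> lr_high -> lr_low while momentum moves oppositely.
# (`lr_base` from the constructor is omitted in this simplified sketch.)
function cyclic_schedule(t, lr_low, lr_high, m_low, m_high)
  @assert 0 <= t <= 1            # t is the position within the epoch
  x = t < 0.5 ? 2t : 2(1 - t)    # triangular wave: 0 -> 1 -> 0
  lr = lr_low + x * (lr_high - lr_low)
  momentum = m_high - x * (m_high - m_low)
  return lr, momentum
end
```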
AlphaZero.Network.Adam — Type
Adam(;lr)
Adam optimiser.