Network Interface
AlphaZero.Network — Module

A generic, framework-agnostic interface for neural networks.
Mandatory Interface
AlphaZero.Network.AbstractNetwork — Type

AbstractNetwork

Abstract base type for a neural network.
Constructor
Any subtype Network of AbstractNetwork must implement Base.copy along with the following constructor:

Network(game_spec, hyperparams)

where the expected type of hyperparams is given by HyperParams(Network).
AlphaZero.Network.HyperParams — Function

HyperParams(::Type{<:AbstractNetwork})

Return the hyperparameter type associated with a given network type.

AlphaZero.Network.hyperparams — Function

hyperparams(::AbstractNetwork)

Return the hyperparameters of a network.

AlphaZero.Network.game_spec — Function

game_spec(::AbstractNetwork)

Return the game specification that was passed to the network's constructor.
AlphaZero.Network.forward — Function

forward(::AbstractNetwork, states)

Compute the forward pass of a network on a batch of inputs.
Expect a Float32 tensor states whose batch dimension is the last one.
Return a (P, V) pair where:

- P is a matrix of size (num_actions, batch_size). It is allowed to put weight on invalid actions (see evaluate).
- V is a row vector of size (1, batch_size).
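To make the shape contract concrete, here is a minimal standalone sketch of a network implementing forward. The abstract type and the UniformNet subtype are re-declared locally so the snippet runs without AlphaZero.jl installed; UniformNet is a hypothetical example, not part of the library.

```julia
# Local stand-ins for AlphaZero.Network types, so this sketch is self-contained.
abstract type AbstractNetwork end

# A trivial "network" that assigns uniform probability to every action.
struct UniformNet <: AbstractNetwork
    num_actions::Int
end

# forward contract: `states` is a Float32 tensor whose LAST dimension is the
# batch; return P of size (num_actions, batch_size) and V of size (1, batch_size).
function forward(nn::UniformNet, states::AbstractArray{Float32})
    batch_size = size(states)[end]
    P = fill(1.0f0 / nn.num_actions, nn.num_actions, batch_size)
    V = zeros(Float32, 1, batch_size)
    return (P, V)
end

# Example: a batch of 5 states, each an 8×8 board with 3 channels.
states = rand(Float32, 8, 8, 3, 5)
P, V = forward(UniformNet(10), states)
@assert size(P) == (10, 5)
@assert size(V) == (1, 5)
```

A real implementation would wrap a framework model (Flux, Knet, ...) instead of returning constants, but it must respect the same shape conventions.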
AlphaZero.Network.train! — Function

train!(callback, ::AbstractNetwork, opt::OptimiserSpec, loss, data, n)

Update a given network to fit some data.

- opt specifies which optimiser to use.
- loss is a function that maps a batch of samples to a tracked real.
- data is an iterator over minibatches.
- n is the number of minibatches. If length is defined on data, we must have length(data) == n. However, not all finite iterators implement length, and thus this argument is needed.
- callback(i, loss) is called at each step with the batch number i and the loss on the last batch.
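The calling convention between train! and its callback argument can be sketched as follows. The loss used here (a batch mean) is a placeholder for illustration only; a real implementation would also perform an optimiser step on the network before invoking the callback.

```julia
# Sketch of the train! contract: iterate over the n minibatches of `data`
# and call callback(i, loss_value) once per batch.
losses = Float64[]
data = ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])   # three toy "minibatches"
n = 3                                          # must match the iterator length

callback(i, l) = push!(losses, l)
loss(batch) = sum(batch) / length(batch)       # placeholder loss

for (i, batch) in enumerate(data)
    # ... an optimiser update on the network would happen here ...
    callback(i, loss(batch))
end

@assert length(losses) == n                    # one callback per minibatch
```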
AlphaZero.Network.set_test_mode! — Function

set_test_mode!(mode=true)

Put a network in test mode or in training mode. This is relevant for networks featuring layers such as batch normalization.
AlphaZero.Network.params — Function

params(::AbstractNetwork)

Return the collection of trainable parameters of a network.

AlphaZero.Network.regularized_params — Function

regularized_params(::AbstractNetwork)

Return the collection of regularized parameters of a network. This usually excludes neurons' biases.
Conversion and Copy
AlphaZero.Network.to_gpu — Function

to_gpu(::AbstractNetwork)

Return a copy of the given network that has been transferred to the GPU if one is available. Otherwise, return the given network untouched.

AlphaZero.Network.to_cpu — Function

to_cpu(::AbstractNetwork)

Return a copy of the given network that has been transferred to the CPU, or return the given network untouched if it is already on the CPU.

AlphaZero.Network.on_gpu — Function

on_gpu(::AbstractNetwork) :: Bool

Test whether or not a network is located on the GPU.

AlphaZero.Network.convert_input — Function

convert_input(::AbstractNetwork, input)

Convert an array (or number) to the right format so that it can be used as an input by a given network.

AlphaZero.Network.convert_output — Function

convert_output(::AbstractNetwork, output)

Convert an array (or number) produced by a neural network to a standard CPU array (or number) type.
Misc
AlphaZero.Network.gc — Function

gc(::AbstractNetwork)

Perform full garbage collection and empty the GPU memory pool.
Derived Functions
Evaluation Functions
AlphaZero.Network.forward_normalized — Function

forward_normalized(network::AbstractNetwork, states, actions_mask)

Evaluate a batch of vectorized states. This function is a wrapper around forward that puts a zero weight on invalid actions.

Arguments

- states is a tensor whose last dimension has size batch_size.
- actions_mask is a binary matrix of size (num_actions, batch_size).
Returned value
Return a (P, V, Pinv) triple where:
- P is a matrix of size (num_actions, batch_size).
- V is a row vector of size (1, batch_size).
- Pinv is a row vector of size (1, batch_size) that indicates the total probability weight put by the network on invalid actions for each sample.
All tensors manipulated by this function have elements of type Float32.
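The masking and renormalization step can be sketched with plain array operations. This is an illustration of the described semantics (measure the invalid weight, zero it out, renormalize), not the library's actual implementation.

```julia
# Raw policy matrix P from `forward` and a binary mask of the same shape:
# (num_actions = 3, batch_size = 2).
P    = Float32[0.5 0.2; 0.3 0.3; 0.2 0.5]
mask = Float32[1 1; 1 0; 0 1]   # action 3 invalid for sample 1, action 2 for sample 2

Pinv    = sum(P .* (1 .- mask); dims=1)   # total weight put on invalid actions
Pmasked = P .* mask                        # zero out invalid actions
Pnorm   = Pmasked ./ sum(Pmasked; dims=1)  # renormalize each column

# Each column of Pnorm is again a probability distribution over valid actions.
@assert all(sum(Pnorm; dims=1) .≈ 1)
```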
AlphaZero.Network.evaluate — Function

evaluate(::AbstractNetwork, state)
(nn::AbstractNetwork)(state) = evaluate(nn, state)

Evaluate the neural network as an MCTS oracle on a single state.

Note, however, that evaluating states one at a time is slow, and so you may want to use a BatchedOracle along with an inference server that uses evaluate_batch.
AlphaZero.Network.evaluate_batch — Function

evaluate_batch(::AbstractNetwork, batch)

Evaluate the neural network as an MCTS oracle on a batch of states at once.
Take a list of states as input and return a list of (P, V) pairs as defined in the MCTS oracle interface.
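One way to see the relationship with forward: stack the list of states along a new last (batch) dimension, run a single forward pass, then split the resulting columns back into per-state (P, V) pairs. The forward_stub below is a hypothetical stand-in for a real network's forward pass.

```julia
# Stand-in for forward: uniform policy over 4 actions, zero value,
# for whatever batch size the input tensor carries in its last dimension.
forward_stub(states) = (fill(0.25f0, 4, size(states)[end]),
                        zeros(Float32, 1, size(states)[end]))

states = [rand(Float32, 8, 8) for _ in 1:3]   # three vectorized states
batch  = cat(states...; dims=3)               # stack: (8, 8, batch_size = 3)

P, V = forward_stub(batch)
pairs = [(P[:, i], V[1, i]) for i in 1:length(states)]   # one (P, V) per state
```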
Misc
AlphaZero.Network.num_parameters — Function

num_parameters(::AbstractNetwork)

Return the total number of parameters of a network.

AlphaZero.Network.num_regularized_parameters — Function

num_regularized_parameters(::AbstractNetwork)

Return the total number of regularized parameters of a network.

AlphaZero.Network.mean_weight — Function

mean_weight(::AbstractNetwork)

Return the mean absolute value of the regularized parameters of a network.

AlphaZero.Network.copy — Method

copy(::AbstractNetwork; on_gpu, test_mode)

A copy function that also handles CPU/GPU transfers and test/train mode switches.
Optimiser Specification
AlphaZero.Network.OptimiserSpec — Type

OptimiserSpec

Abstract type for an optimiser specification.
AlphaZero.Network.CyclicNesterov — Type

CyclicNesterov(; lr_base, lr_high, lr_low, momentum_low, momentum_high)

SGD optimiser with a cyclic learning rate and cyclic Nesterov momentum.

- During an epoch, the learning rate goes from lr_low to lr_high and then back to lr_low.
- The momentum evolves in the opposite way, from high values to low values and then back to high values.
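The cyclic schedule described above can be sketched as a triangular ramp over one epoch. The exact interpolation shape is an assumption for illustration; only the low-to-high-to-low pattern (and the momentum moving oppositely) is stated by the documentation.

```julia
# Triangular schedule over one epoch of n steps: t goes 0 → 1 → 0,
# so the learning rate goes lr_low → lr_high → lr_low.
function cyclic_lr(step, n, lr_low, lr_high)
    half = n / 2
    t = step <= half ? step / half : (n - step) / half
    return lr_low + t * (lr_high - lr_low)
end

# Momentum moves in the opposite direction: m_high → m_low → m_high.
momentum(step, n, m_low, m_high) =
    m_high - cyclic_lr(step, n, 0.0, 1.0) * (m_high - m_low)
```

For example, with n = 10, lr_low = 1e-3, lr_high = 1e-2, the learning rate peaks at 1e-2 at step 5 while the momentum reaches its minimum there.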
AlphaZero.Network.Adam — Type

Adam(; lr)

Adam optimiser.