Network Interface

Mandatory Interface

AlphaZero.Network.AbstractNetwork (Type)
AbstractNetwork

Abstract base type for a neural network.

Constructor

Any subtype Network must implement Base.copy along with the following constructor:

Network(game_spec, hyperparams)

where the expected type of hyperparams is given by HyperParams(Network).

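For illustration, here is a minimal sketch of what satisfying this contract might look like, using a hypothetical DummyNetwork type with a hypothetical DummyHyperParams struct (neither is part of the library):

    import AlphaZero.Network: AbstractNetwork, HyperParams

    struct DummyHyperParams
      width :: Int
    end

    mutable struct DummyNetwork <: AbstractNetwork
      hyper :: DummyHyperParams
      # framework-specific layers would be stored here
    end

    # Tell callers which hyperparameter type this network expects.
    HyperParams(::Type{DummyNetwork}) = DummyHyperParams

    # The required constructor: build a network for a given game.
    DummyNetwork(game_spec, hyper::DummyHyperParams) = DummyNetwork(hyper)

    Base.copy(nn::DummyNetwork) = DummyNetwork(nn.hyper)
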
AlphaZero.Network.forward (Function)
forward(::AbstractNetwork, states)

Compute the forward pass of a network on a batch of inputs.

Expect a Float32 tensor states whose batch dimension is the last one.

Return a (P, V) pair where:

  • P is a matrix of size (num_actions, batch_size). It is allowed to put weight on invalid actions (see evaluate).
  • V is a row vector of size (1, batch_size).
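
To make the shape contract concrete, here is an illustrative check; it assumes nn is a concrete network for a game with 7 actions and 3×3 vectorized states (the numbers are made up):

    states = rand(Float32, 3, 3, 16)  # batch of 16 states, batch dimension last
    P, V = Network.forward(nn, states)
    @assert size(P) == (7, 16)        # one action distribution per sample
    @assert size(V) == (1, 16)        # one scalar value per sample
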
AlphaZero.Network.train! (Function)
train!(callback, ::AbstractNetwork, opt::OptimiserSpec, loss, data, n)

Update a given network to fit some data.

  • opt specifies which optimiser to use.
  • loss is a function that maps a batch of samples to a tracked real.
  • data is an iterator over minibatches.
  • n is the number of minibatches. If length is defined on data, then length(data) == n must hold. However, not all finite iterators implement length, which is why this argument is needed.
  • callback(i, loss) is called at each step with the batch number i and the loss on the last batch.
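
As an illustration, the calling convention (including the do-block form for the callback) might look as follows; nn, opt, loss and the hypothetical minibatch_iterator are assumed to be defined elsewhere:

    batches = collect(minibatch_iterator)  # any iterator over minibatches works
    Network.train!(nn, opt, loss, batches, length(batches)) do i, l
      println("batch $i: loss = $l")
    end
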
AlphaZero.Network.set_test_mode! (Function)
set_test_mode!(::AbstractNetwork, mode=true)

Put a network in test mode or in training mode. This is relevant for networks featuring layers, such as batch normalization, that behave differently during training and inference.


Conversion and Copy

AlphaZero.Network.to_gpu (Function)
to_gpu(::AbstractNetwork)

Return a copy of the given network that has been transferred to the GPU if one is available. Otherwise, return the given network untouched.

AlphaZero.Network.to_cpu (Function)
to_cpu(::AbstractNetwork)

Return a copy of the given network that has been transferred to the CPU, or return the given network untouched if it is already on the CPU.

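A typical transfer pattern, sketched under the assumption that nn is a network living on the CPU:

    nn_gpu = Network.to_gpu(nn)      # no-op if no GPU is available
    # ... run inference or training on nn_gpu ...
    nn_cpu = Network.to_cpu(nn_gpu)  # e.g. before serializing the weights
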
AlphaZero.Network.convert_input (Function)
convert_input(::AbstractNetwork, input)

Convert an array (or number) to the right format so that it can be used as an input by a given network.

AlphaZero.Network.convert_output (Function)
convert_output(::AbstractNetwork, output)

Convert an array (or number) produced by a neural network to a standard CPU array (or number) type.

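These two helpers are typically used to bracket a call to forward. The following sketch shows the pattern; it is not the library's actual implementation:

    x = Network.convert_input(nn, states)  # e.g. transfer the batch to the GPU
    p, v = Network.forward(nn, x)
    P = Network.convert_output(nn, p)      # back to standard CPU arrays
    V = Network.convert_output(nn, v)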

Derived Functions

Evaluation Functions

AlphaZero.Network.forward_normalized (Function)
forward_normalized(network::AbstractNetwork, states, actions_mask)

Evaluate a batch of vectorized states. This function is a wrapper around forward that puts a zero weight on invalid actions.

Arguments

  • states is a tensor whose last dimension has size batch_size.
  • actions_mask is a binary matrix of size (num_actions, batch_size).

Returned value

Return a (P, V, Pinv) triple where:

  • P is a matrix of size (num_actions, batch_size).
  • V is a row vector of size (1, batch_size).
  • Pinv is a row vector of size (1, batch_size) that indicates the total probability weight put by the network on invalid actions for each sample.

All tensors manipulated by this function have elements of type Float32.

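One way to realize this contract on top of forward is sketched below; the library's actual implementation may differ in its details:

    P, V = Network.forward(nn, states)
    Pm = P .* actions_mask               # zero the weight on invalid actions
    total = sum(Pm, dims=1)              # probability weight left on valid actions
    Pinv = 1f0 .- total                  # weight the network put on invalid actions
    P = Pm ./ max.(total, eps(Float32))  # renormalize over valid actions
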
AlphaZero.Network.evaluate (Function)
evaluate(::AbstractNetwork, state)

(nn::AbstractNetwork)(state) = evaluate(nn, state)

Evaluate the neural network as an MCTS oracle on a single state.

Note, however, that evaluating states one at a time is slow, so you may want to use a BatchedOracle along with an inference server that uses evaluate_batch.

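Both call forms below are equivalent; nn and state are assumed to be defined:

    P, V = Network.evaluate(nn, state)
    P, V = nn(state)   # functor syntax, as defined above
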
AlphaZero.Network.evaluate_batch (Function)
evaluate_batch(::AbstractNetwork, batch)

Evaluate the neural network as an MCTS oracle on a batch of states at once.

Take a list of states as input and return a list of (P, V) pairs as defined in the MCTS oracle interface.

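A sketch of the contract, with s1, s2 and s3 standing for arbitrary states:

    results = Network.evaluate_batch(nn, [s1, s2, s3])
    P1, V1 = results[1]   # one (P, V) pair per input state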

Misc

AlphaZero.Network.copy (Method)
copy(::AbstractNetwork; on_gpu, test_mode)

A copy function that also handles CPU/GPU transfers and test/train mode switches.

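For example, one might duplicate a network onto the GPU in test mode before running an evaluation (an illustrative call, not a prescribed recipe):

    nn_eval = copy(nn, on_gpu=true, test_mode=true)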

Optimiser Specification

AlphaZero.Network.CyclicNesterov (Type)
CyclicNesterov(; lr_base, lr_high, lr_low, momentum_low, momentum_high)

SGD optimiser with a cyclic learning rate and cyclic Nesterov momentum.

  • During an epoch, the learning rate goes from lr_low to lr_high and then back to lr_low.
  • The momentum evolves in the opposite way, from high values to low values and then back to high values.
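
An illustrative construction (the values are placeholders, not recommended settings):

    opt = Network.CyclicNesterov(
      lr_base=1e-3, lr_low=1e-4, lr_high=1e-2,
      momentum_low=0.85, momentum_high=0.95)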