Network Interface

Mandatory Interface

AlphaZero.Network.AbstractNetwork - Type
AbstractNetwork{Game} <: MCTS.Oracle{Game}

Abstract base type for a neural network.

Constructor

Any subtype Network must implement Base.copy along with the following constructor:

Network(hyperparams)

where the expected type of hyperparams is given by HyperParams(Network).

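To make these requirements concrete, here is a minimal sketch of a hypothetical subtype. The names MyNetHP and MyNet are purely illustrative, the qualification AlphaZero.Network.HyperParams is assumed from the module layout of this page, and the remaining interface functions (forward, train!, ...) would still have to be implemented for such a type to be usable:

using AlphaZero

# Hypothetical hyperparameter record for the illustrative network below.
Base.@kwdef struct MyNetHP
  hidden_size :: Int = 64
end

# Hypothetical concrete network type, parametrized by the game it plays.
mutable struct MyNet{Game} <: AlphaZero.Network.AbstractNetwork{Game}
  hyper  :: MyNetHP
  layers :: Vector{Matrix{Float32}}  # placeholder for the actual layers
end

# Declare the hyperparameter type expected by the constructor.
AlphaZero.Network.HyperParams(::Type{<:MyNet}) = MyNetHP

# Required constructor: Network(hyperparams).
MyNet{Game}(hp::MyNetHP) where Game =
  MyNet{Game}(hp, [zeros(Float32, hp.hidden_size, hp.hidden_size)])

# Required Base.copy.
Base.copy(nn::MyNet{G}) where G = MyNet{G}(nn.hyper, deepcopy(nn.layers))
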
AlphaZero.Network.forward - Function
forward(::AbstractNetwork, boards)

Compute the forward pass of a network on a batch of inputs.

Expect a Float32 tensor boards whose batch dimension is the last one.

Return a (P, V) pair where:

  • P is a matrix of size (num_actions, batch_size). It is allowed to put weight on invalid actions (see evaluate).
  • V is a row vector of size (1, batch_size).
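For example, assuming nn is a concrete network instance and the Network module is in scope (e.g. via using AlphaZero: Network), a call might look as follows, where num_actions and batch_size stand for the relevant dimensions:

# boards stacks batch_size positions along its last dimension,
# e.g. size(boards) == (board_dims..., batch_size), with Float32 elements.
P, V = Network.forward(nn, boards)

size(P)  # (num_actions, batch_size); may still put weight on invalid actions
size(V)  # (1, batch_size)
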
AlphaZero.Network.train! - Function
train!(callback, ::AbstractNetwork, opt::OptimiserSpec, loss, data, n)

Update a given network to fit some data.

  • opt specifies which optimiser to use.
  • loss is a function that maps a batch of samples to a tracked real.
  • data is an iterator over minibatches.
  • n is the number of minibatches. If length is defined on data, we must have length(data) == n. However, not all finite iterators implement length and thus this argument is needed.
  • callback(i, loss) is called at each step with the batch number i and the loss on the last batch.
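A hypothetical invocation is sketched below; nn, opt, loss, and batches are assumed to be defined elsewhere, and the do block supplies the callback (which is possible because the callback is the first argument):

n = length(batches)  # pass n explicitly since some finite iterators lack length

Network.train!(nn, opt, loss, batches, n) do i, l
  # Called at each step with the minibatch index and the loss on that batch.
  @info "training" batch=i loss=l
end
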
AlphaZero.Network.set_test_mode! - Function
set_test_mode!(mode=true)

Put a network in test mode or in training mode. This is relevant for networks featuring layers such as batch normalization.


Conversion and Copy

AlphaZero.Network.to_gpu - Function
to_gpu(::AbstractNetwork)

Return a copy of the given network that has been transferred to the GPU if one is available. Otherwise, return the given network untouched.

AlphaZero.Network.to_cpu - Function
to_cpu(::AbstractNetwork)

Return a copy of the given network that has been transferred to the CPU, or return the given network untouched if it is already on the CPU.

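A typical round trip, assuming nn is a concrete network and the Network module is in scope:

nn_gpu = Network.to_gpu(nn)  # returns nn untouched if no GPU is available
# ... run inference or training with nn_gpu ...
nn = Network.to_cpu(nn_gpu)  # bring the parameters back to the CPU
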
AlphaZero.Network.convert_input - Function
convert_input(::AbstractNetwork, input)

Convert an array (or number) to the right format so that it can be used as an input by a given network.

AlphaZero.Network.convert_output - Function
convert_output(::AbstractNetwork, output)

Convert an array (or number) produced by a neural network to a standard CPU array (or number) type.

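These two functions are typically used around a raw forward pass. The helper below is a plausible sketch of how they fit together (an illustration only, not the package's actual code):

# Hypothetical helper: evaluate a CPU batch on a possibly GPU-resident network.
function forward_on_device(nn, boards)
  x = Network.convert_input(nn, boards)   # e.g. transfer to the GPU
  P, V = Network.forward(nn, x)
  return Network.convert_output(nn, P), Network.convert_output(nn, V)
end
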


Derived Functions

Evaluation Function

AlphaZero.Network.evaluate - Function
evaluate(network::AbstractNetwork, boards, actions_mask)

Evaluate a batch of board positions. This function is a wrapper around forward that puts a zero weight on invalid actions.

Arguments

  • boards is a tensor whose last dimension has size batch_size.
  • actions_mask is a binary matrix of size (num_actions, batch_size).

Return

Return a (P, V, Pinv) triple where:

  • P is a matrix of size (num_actions, batch_size).
  • V is a row vector of size (1, batch_size).
  • Pinv is a row vector of size (1, batch_size) that indicates the total probability weight put by the network on invalid actions for each sample.

All tensors manipulated by this function have elements of type Float32.

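In terms of forward, a plausible implementation sketch (not necessarily the package's exact code) zeroes out the weight on invalid actions, renormalizes, and reports the probability mass that was discarded:

function evaluate_sketch(nn, boards, actions_mask)
  P, V = Network.forward(nn, boards)
  Pinv = sum(P .* (1 .- actions_mask), dims=1)  # mass put on invalid actions
  Pmasked = P .* actions_mask                   # zero weight on invalid actions
  return Pmasked ./ sum(Pmasked, dims=1), V, Pinv
end
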

Oracle Interface

All subtypes of AbstractNetwork implement the MCTS.Oracle interface through both a single-sample evaluation function and a batched one.

Since evaluating a neural network on single samples at a time is slow, the batched version should be used whenever possible.

Misc

AlphaZero.Network.copy - Method
copy(::AbstractNetwork; on_gpu, test_mode)

A copy function that also handles CPU/GPU transfers and test/train mode switches.

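For instance, the following call (with nn a concrete network) would be expected to return an inference-ready copy living on the GPU:

nn_eval = Network.copy(nn; on_gpu=true, test_mode=true)
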

Optimiser Specification

AlphaZero.Network.CyclicNesterov - Type
CyclicNesterov(; lr_base, lr_high, lr_low, momentum_low, momentum_high)

SGD optimiser with a cyclic learning rate and cyclic Nesterov momentum.

  • During an epoch, the learning rate goes from lr_low to lr_high and then back to lr_low.
  • The momentum evolves in the opposite way, from high values to low values and then back to high values.
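For illustration, a specification could be constructed as follows, assuming the constructor is in scope (e.g. after using AlphaZero); the numerical values are arbitrary placeholders rather than recommended settings:

opt = CyclicNesterov(
  lr_base=1e-3,
  lr_high=1e-2,
  lr_low=1e-3,
  momentum_high=0.9,
  momentum_low=0.8)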