# Network Interface

`AlphaZero.Network`

— Module

A generic, framework-agnostic interface for neural networks.

## Mandatory Interface

`AlphaZero.Network.AbstractNetwork`

— Type `AbstractNetwork`

Abstract base type for a neural network.

**Constructor**

Any subtype `Network` must implement `Base.copy` along with the following constructor:

`Network(game_spec, hyperparams)`

where the expected type of `hyperparams` is given by `HyperParams(Network)`.

`AlphaZero.Network.HyperParams`

— Function `HyperParams(::Type{<:AbstractNetwork})`

Return the hyperparameter type associated with a given network type.

`AlphaZero.Network.hyperparams`

— Function `hyperparams(::AbstractNetwork)`

Return the hyperparameters of a network.

`AlphaZero.Network.game_spec`

— Function `game_spec(::AbstractNetwork)`

Return the game specification that was passed to the network's constructor.
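As an illustration of the mandatory constructor and accessors, here is a minimal sketch of a toy network type. It does not depend on AlphaZero.jl: `AbstractNetwork` is redeclared locally as a stand-in, and `ToyNetwork`, `ToyHyperParams` and the `(state_dim, num_actions)` game specification are hypothetical names chosen for the example.

```julia
# Hypothetical stand-in for AlphaZero.Network.AbstractNetwork, so that the
# sketch runs without the package.
abstract type AbstractNetwork end

# Hypothetical hyperparameters: a single initialization value.
struct ToyHyperParams
  init :: Float32
end

# Hypothetical network: one weight matrix mapping states to action scores.
struct ToyNetwork <: AbstractNetwork
  gspec :: NamedTuple
  hyper :: ToyHyperParams
  W :: Matrix{Float32}
end

# Mandatory constructor `Network(game_spec, hyperparams)`. The game spec is
# assumed to be a NamedTuple with `state_dim` and `num_actions` fields.
ToyNetwork(gspec::NamedTuple, hp::ToyHyperParams) =
  ToyNetwork(gspec, hp, fill(hp.init, gspec.num_actions, gspec.state_dim))

# Accessors and copy required by the interface.
HyperParams(::Type{ToyNetwork}) = ToyHyperParams
hyperparams(nn::ToyNetwork) = nn.hyper
game_spec(nn::ToyNetwork) = nn.gspec
Base.copy(nn::ToyNetwork) = ToyNetwork(nn.gspec, nn.hyper, copy(nn.W))

gspec = (state_dim = 4, num_actions = 3)
nn = ToyNetwork(gspec, HyperParams(ToyNetwork)(0.1f0))
nn2 = copy(nn)
```

Looking up the hyperparameter type through `HyperParams(ToyNetwork)` lets generic code build a network from a configuration without naming the concrete hyperparameter struct.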

`AlphaZero.Network.forward`

— Function `forward(::AbstractNetwork, states)`

Compute the forward pass of a network on a batch of inputs.

Expect a `Float32` tensor `states` whose batch dimension is the last one.

Return a `(P, V)` pair where:

- `P` is a matrix of size `(num_actions, batch_size)`. It is allowed to put weight on invalid actions (see `evaluate`).
- `V` is a row vector of size `(1, batch_size)`.
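The shape contract of `forward` can be sketched with a toy single-layer network. The function `toy_forward`, the softmax policy head and the `tanh` value head are assumptions made for illustration; the interface itself only fixes the `Float32` element type and the output shapes.

```julia
# Hypothetical toy network: one linear layer per head, no hidden layers.
# W has size (num_actions, state_dim), Wv has size (1, state_dim).
function toy_forward(W::Matrix{Float32}, Wv::Matrix{Float32}, states::AbstractArray{Float32})
  batch_size = size(states, ndims(states))
  x = reshape(states, :, batch_size)             # flatten all but the last (batch) dimension
  logits = W * x                                 # (num_actions, batch_size)
  e = exp.(logits .- maximum(logits; dims=1))    # shift by the column max for stability
  P = e ./ sum(e; dims=1)                        # each column is a distribution over actions
  V = tanh.(Wv * x)                              # (1, batch_size), values in (-1, 1)
  return (P, V)
end

W = Float32[1 0; 0 1; 1 1]                       # 3 actions, 2 input features
Wv = Float32[0.5 -0.5]
states = rand(Float32, 2, 5)                     # batch of 5 states
P, V = toy_forward(W, Wv, states)
P3, V3 = toy_forward(W, Wv, reshape(states, 1, 2, 5))  # same batch as a 3-d tensor
```

Because only the last dimension is treated as the batch, the same states laid out as a `(1, 2, 5)` tensor yield the same outputs.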

`AlphaZero.Network.train!`

— Function `train!(callback, ::AbstractNetwork, opt::OptimiserSpec, loss, data, n)`

Update a given network to fit some data.

- `opt` specifies which optimiser to use.
- `loss` is a function that maps a batch of samples to a tracked real.
- `data` is an iterator over minibatches.
- `n` is the number of minibatches. If `length` is defined on `data`, we must have `length(data) == n`. However, not all finite iterators implement `length`, and thus this argument is needed.
- `callback(i, loss)` is called at each step with the batch number `i` and the loss on the last batch.
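The following sketch mimics the loop that `train!` describes on a one-parameter model, with plain SGD standing in for `opt`. The names `train_sketch!`, `mse` and `mse_grad` are hypothetical, and reporting the loss computed before each update is an assumption. The example feeds a filtered generator, which has no `length`, to show why `n` is passed separately.

```julia
# Hypothetical one-parameter model y ≈ θ * x; each minibatch is a tuple (xs, ys).
mse(θ, (xs, ys)) = sum((θ .* xs .- ys) .^ 2) / length(xs)
mse_grad(θ, (xs, ys)) = 2 * sum((θ .* xs .- ys) .* xs) / length(xs)

# Sketch of the loop shape described for `train!`; plain SGD stands in for `opt`.
function train_sketch!(callback, θ::Base.RefValue{Float64}, lr, loss, grad, data, n)
  i = 0
  for batch in data                   # `data` only needs to be iterable
    i += 1
    l = loss(θ[], batch)              # loss on this minibatch, before the update
    θ[] -= lr * grad(θ[], batch)      # one optimiser step
    callback(i, l)
  end
  i == n || error("expected $n minibatches, got $i")
  return θ
end

batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0]), ([0.0], [0.0])]
data = (b for b in batches if !isempty(b[1]))   # a filtered generator has no `length`
θ = Ref(0.0)
losses = Float64[]
train_sketch!((i, l) -> push!(losses, l), θ, 0.05, mse, mse_grad, data, 3)
```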

`AlphaZero.Network.set_test_mode!`

— Function `set_test_mode!(::AbstractNetwork, mode=true)`

Put a network in test mode (`mode=true`) or in training mode (`mode=false`). This is relevant for networks featuring layers such as batch normalization layers.
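To see why the mode matters, here is a toy batch-normalization layer (the hypothetical `ToyBatchNorm` and `toy_set_test_mode!`, not part of the library). In training mode it normalizes with the statistics of the current batch and updates running averages; in test mode it uses the stored running averages, so the output for a sample no longer depends on the rest of the batch.

```julia
using Statistics

# Hypothetical toy batch-normalization layer (no learned scale or shift).
mutable struct ToyBatchNorm
  running_mean :: Vector{Float32}
  running_var :: Vector{Float32}
  momentum :: Float32
  test_mode :: Bool
end

ToyBatchNorm(d::Int) = ToyBatchNorm(zeros(Float32, d), ones(Float32, d), 0.1f0, false)

# Hypothetical analogue of `set_test_mode!` for this single layer.
toy_set_test_mode!(bn::ToyBatchNorm, mode::Bool=true) = (bn.test_mode = mode; bn)

function (bn::ToyBatchNorm)(x::Matrix{Float32}; ϵ::Float32=1f-5)
  if bn.test_mode
    μ, σ² = bn.running_mean, bn.running_var      # statistics gathered during training
  else
    μ = vec(mean(x; dims=2))                     # statistics of the current batch
    σ² = vec(var(x; dims=2, corrected=false))
    m = bn.momentum
    bn.running_mean .= (1 - m) .* bn.running_mean .+ m .* μ
    bn.running_var .= (1 - m) .* bn.running_var .+ m .* σ²
  end
  return (x .- μ) ./ sqrt.(σ² .+ ϵ)
end

bn = ToyBatchNorm(2)
y = bn(Float32[1 3; 2 6])                        # training mode: batch statistics
single = reshape(Float32[5, 7], 2, 1)
y_train = ToyBatchNorm(2)(single)                # one-sample batch in training mode
toy_set_test_mode!(bn)
y_test = bn(single)                              # test mode: running statistics
```

In training mode a one-sample batch normalizes to all zeros, which is one concrete reason to switch a network to test mode before using it for inference.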

`AlphaZero.Network.params`

— Function `params(::AbstractNetwork)`

Return the collection of trainable parameters of a network.

`AlphaZero.Network.regularized_params`

— Function `regularized_params(::AbstractNetwork)`

Return the collection of regularized parameters of a network. This usually excludes neuron biases.

### Conversion and Copy

`AlphaZero.Network.to_gpu`

— Function `to_gpu(::AbstractNetwork)`

Return a copy of the given network that has been transferred to the GPU if one is available. Otherwise, return the given network untouched.

`AlphaZero.Network.to_cpu`

— Function `to_cpu(::AbstractNetwork)`

Return a copy of the given network that has been transferred to the CPU, or return the given network untouched if it is already on the CPU.

`AlphaZero.Network.on_gpu`

— Function `on_gpu(::AbstractNetwork) :: Bool`

Test whether a network is located on the GPU.

`AlphaZero.Network.convert_input`

— Function `convert_input(::AbstractNetwork, input)`

Convert an array (or number) to the right format so that it can be used as an input by a given network.

`AlphaZero.Network.convert_output`

— Function `convert_output(::AbstractNetwork, output)`

Convert an array (or number) produced by a neural network to a standard CPU array (or number) type.

### Misc

`AlphaZero.Network.gc`

— Function `gc(::AbstractNetwork)`

Perform full garbage collection and empty the GPU memory pool.

## Derived Functions

### Evaluation Functions

`AlphaZero.Network.forward_normalized`

— Function `forward_normalized(network::AbstractNetwork, states, actions_mask)`

Evaluate a batch of vectorized states. This function is a wrapper on `forward` that puts a zero weight on invalid actions.

**Arguments**

- `states` is a tensor whose last dimension has size `batch_size`.
- `actions_mask` is a binary matrix of size `(num_actions, batch_size)`.

**Returned value**

Return a `(P, V, Pinv)` triple where:

- `P` is a matrix of size `(num_actions, batch_size)`.
- `V` is a row vector of size `(1, batch_size)`.
- `Pinv` is a row vector of size `(1, batch_size)` that indicates the total probability weight put by the network on invalid actions for each sample.

All tensors manipulated by this function have elements of type `Float32`.
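A minimal sketch of the masking step, written as a standalone function on the output of a forward pass. The name `mask_policy` is hypothetical, and the computation of `Pinv` assumes that each column of `P` sums to one, so that the weight on invalid actions is one minus the weight left on valid actions. A small epsilon guards against division by zero when a network puts all its weight on invalid actions.

```julia
# Hypothetical helper: mask and renormalize the policy returned by a forward pass.
function mask_policy(P::Matrix{Float32}, V::Matrix{Float32}, actions_mask::AbstractMatrix{Bool})
  P = P .* actions_mask                    # zero out invalid actions
  valid_mass = sum(P; dims=1)              # (1, batch_size)
  P = P ./ (valid_mass .+ eps(Float32))    # renormalize over valid actions
  Pinv = 1f0 .- valid_mass                 # weight the network put on invalid actions
  return (P, V, Pinv)
end

P = Float32[0.5 0.2; 0.3 0.3; 0.2 0.5]     # 3 actions, batch of 2
V = Float32[0.1 -0.4]
mask = Bool[1 1; 1 0; 0 1]
Pn, Vn, Pinv = mask_policy(P, V, mask)
```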

`AlphaZero.Network.evaluate`

— Function

```julia
evaluate(::AbstractNetwork, state)
(nn::AbstractNetwork)(state) = evaluate(nn, state)
```

Evaluate the neural network as an MCTS oracle on a single state.

Note, however, that evaluating states one at a time is slow, so you may want to use a `BatchedOracle` along with an inference server that uses `evaluate_batch`.
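Evaluating a single state can reuse the batched code path by adding a trailing batch dimension of size 1. In this sketch `evaluate_one` and `fake_forward` are hypothetical, and the returned policy is assumed to cover only the legal actions, renormalized to sum to one.

```julia
# Hypothetical: `forward_fn(x)` returns `(P, V)` with the shapes documented
# for `forward`, and `valid` flags the legal actions of `state`.
function evaluate_one(forward_fn, state::AbstractArray{Float32}, valid::AbstractVector{Bool})
  x = reshape(state, size(state)..., 1)   # add a trailing batch dimension of size 1
  P, V = forward_fn(x)
  p = P[valid, 1]                         # keep only the legal actions
  return (p ./ sum(p), V[1, 1])           # renormalize; unwrap the scalar value
end

# Stand-in forward pass with a fixed policy and a zero value for every sample.
fake_forward(x) = (repeat(Float32[0.5, 0.25, 0.25], 1, size(x, ndims(x))),
                   zeros(Float32, 1, size(x, ndims(x))))

p, v = evaluate_one(fake_forward, Float32[1, 0, 0, 1], Bool[1, 0, 1])
```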

`AlphaZero.Network.evaluate_batch`

— Function `evaluate_batch(::AbstractNetwork, batch)`

Evaluate the neural network as an MCTS oracle on a batch of states at once.

Take a list of states as input and return a list of `(P, V)` pairs as defined in the MCTS oracle interface.
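A sketch of the batching idea: stack the states along a new last dimension, run a single forward pass, and split the columns back into one pair per state. `evaluate_many` and `fake_forward` are hypothetical names, and action masking is left out to keep the example short.

```julia
# Hypothetical: `forward_fn(x)` returns `(P, V)` for a batch whose last
# dimension indexes the states. Action masking is left out for brevity.
function evaluate_many(forward_fn, states::Vector{<:AbstractArray{Float32}})
  N = ndims(first(states))
  x = cat(states...; dims=N + 1)          # stack along a new trailing batch dimension
  P, V = forward_fn(x)                    # a single forward pass for the whole batch
  return [(P[:, i], V[1, i]) for i in eachindex(states)]
end

# Stand-in forward pass: policy from the normalized state, value = sum of entries.
fake_forward(x) = (x ./ sum(x; dims=1), sum(x; dims=1))

results = evaluate_many(fake_forward, [Float32[1, 3], Float32[2, 2]])
```

Running one large forward pass instead of many small ones is what makes batched evaluation faster, especially on a GPU.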

### Misc

`AlphaZero.Network.num_parameters`

— Function `num_parameters(::AbstractNetwork)`

Return the total number of parameters of a network.

`AlphaZero.Network.num_regularized_parameters`

— Function `num_regularized_parameters(::AbstractNetwork)`

Return the total number of regularized parameters of a network.

`AlphaZero.Network.mean_weight`

— Function `mean_weight(::AbstractNetwork)`

Return the mean absolute value of the regularized parameters of a network.
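A sketch of these statistics, assuming the parameters are available as plain arrays and that only weight matrices, not biases, are regularized. The names `count_params`, `mean_abs_weight` and `l2_penalty` are hypothetical; the L2 penalty only illustrates a typical use of the regularized set.

```julia
# Hypothetical parameter set of a small two-layer network.
W1 = Float32[1 -2; 3 -4]      # weights (regularized)
b1 = Float32[10, 10]          # biases (not regularized)
W2 = Float32[0.5 -0.5]        # weights (regularized)

all_params = [W1, b1, W2]     # what `params` might return
reg_params = [W1, W2]         # what `regularized_params` might return

count_params(ps) = sum(length, ps)
mean_abs_weight(ps) = sum(p -> sum(abs, p), ps) / count_params(ps)
l2_penalty(ps) = sum(p -> sum(abs2, p), ps)
```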

`AlphaZero.Network.copy`

— Method `copy(::AbstractNetwork; on_gpu, test_mode)`

A copy function that also handles CPU/GPU transfers and test/train mode switches.
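The behavior of this `copy` method can be sketched as a composition of a deep copy, a device transfer and a mode switch. `ToyNet`, `toy_to_gpu`, `toy_to_cpu`, `toy_set_test_mode!` and `toy_copy` are hypothetical stand-ins, and the device transfer here only flips a flag.

```julia
# Hypothetical toy network that only records its device and mode.
mutable struct ToyNet
  weights :: Vector{Float32}
  on_gpu :: Bool
  test_mode :: Bool
end

# Stand-ins for device transfers: they only flip a flag, no data moves.
toy_to_gpu(nn::ToyNet) = (nn.on_gpu = true; nn)
toy_to_cpu(nn::ToyNet) = (nn.on_gpu = false; nn)
toy_set_test_mode!(nn::ToyNet, mode::Bool) = (nn.test_mode = mode; nn)

function toy_copy(nn::ToyNet; on_gpu::Bool, test_mode::Bool)
  nn = deepcopy(nn)                               # leave the original untouched
  nn = on_gpu ? toy_to_gpu(nn) : toy_to_cpu(nn)
  toy_set_test_mode!(nn, test_mode)
  return nn
end

net = ToyNet(Float32[1, 2], false, false)
net2 = toy_copy(net; on_gpu=true, test_mode=true)
```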

### Optimiser Specification

`AlphaZero.Network.OptimiserSpec`

— Type `OptimiserSpec`

Abstract type for an optimiser specification.

`AlphaZero.Network.CyclicNesterov`

— Type `CyclicNesterov(; lr_base, lr_high, lr_low, momentum_low, momentum_high)`

SGD optimiser with a cyclic learning rate and cyclic Nesterov momentum.

- During an epoch, the learning rate goes from `lr_low` to `lr_high` and then back to `lr_low`.
- The momentum evolves in the opposite way, from high values to low values and then back to high values.
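A sketch of such a schedule for step `i` of an `n`-step epoch. The linear ramps and the turning point at mid-epoch are assumptions, and `lr_base` is not modeled because its role is not described here. `one_cycle`, `cyclic_lr` and `cyclic_momentum` are hypothetical names.

```julia
# Hypothetical helper: linear ramp from `start` to `peak` over the first half
# of the epoch, then back to `start` over the second half.
function one_cycle(i, n, start, peak)
  half = n / 2
  t = i <= half ? i / half : (n - i) / half     # 0 → 1 → 0 as i goes 0 → n
  return start + t * (peak - start)
end

# Learning rate goes low → high → low; momentum goes high → low → high.
cyclic_lr(i, n; lr_low, lr_high) = one_cycle(i, n, lr_low, lr_high)
cyclic_momentum(i, n; momentum_low, momentum_high) = one_cycle(i, n, momentum_high, momentum_low)
```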

`AlphaZero.Network.Adam`

— Type `Adam(; lr)`

Adam optimiser.
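For reference, a sketch of one step of the standard Adam update on a parameter vector. The `AdamState` type and `adam_step!` are hypothetical, and the `β1`, `β2` and `ϵ` values are the commonly used defaults rather than values taken from this `Adam` spec, which only exposes `lr`.

```julia
# Hypothetical optimiser state: first/second moment estimates and step count.
mutable struct AdamState
  m :: Vector{Float64}
  v :: Vector{Float64}
  t :: Int
end
AdamState(n::Int) = AdamState(zeros(n), zeros(n), 0)

function adam_step!(θ::Vector{Float64}, g::Vector{Float64}, s::AdamState;
                    lr=1e-3, β1=0.9, β2=0.999, ϵ=1e-8)
  s.t += 1
  s.m .= β1 .* s.m .+ (1 - β1) .* g        # moving average of gradients
  s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2   # moving average of squared gradients
  mhat = s.m ./ (1 - β1^s.t)               # bias-corrected estimates
  vhat = s.v ./ (1 - β2^s.t)
  θ .-= lr .* mhat ./ (sqrt.(vhat) .+ ϵ)
  return θ
end

θ = [1.0, -1.0]
s = AdamState(2)
adam_step!(θ, [0.5, -2.0], s; lr=0.1)
```

After bias correction, the first step moves each coordinate by roughly `lr` against the sign of its gradient, regardless of the gradient's magnitude.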