How does a neural network learn


This post was originally published by Gerry Chng at Medium [AI]

Each neuron computes a linear function of the form z = Σᵢ wᵢxᵢ + b, where i runs from 1 to 4096 in our example above. Each result is then passed to a non-linear activation function g(z).

This results in a matrix with shape (8, 209) in our example (one row per hidden neuron and one column per training example), which is fed into the output layer. The output layer in our example is a single node that performs binary classification.
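As a rough sketch of what a single neuron computes, the Python below evaluates z = Σᵢ wᵢxᵢ + b and then applies an activation g(z). It uses only the standard library and a hypothetical 4-input example instead of the 4096 inputs in the article; the `neuron`, `relu` and `sigmoid` names are illustrative, not taken from the original series.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function: output is always between 0 and 1.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def relu(z):
    return max(0.0, z)


def neuron(x, w, b, g):
    """One neuron: linear step z = sum_i w_i * x_i + b, then activation a = g(z)."""
    if len(x) != len(w):
        raise ValueError("x and w must have the same length")
    z = sum(w_i * x_i for w_i, x_i in zip(w, x)) + b
    return g(z)


# Hypothetical toy values: 4 inputs rather than the article's 4096 pixels.
x = [0.5, -1.0, 0.25, 2.0]
w = [0.1, 0.2, -0.4, 0.3]
b = 0.05
print(neuron(x, w, b, relu))
print(neuron(x, w, b, sigmoid))
```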

Let’s now walk through the key concepts to understand how the neural network learns through this architecture.

Linear functions and non-linear activation functions

If we look at each neuron, we see that there is both a linear function, z = wx + b, and a non-linear activation function. This set-up allows the network to fit non-linear relationships while keeping each individual function simple.

As you will see later when we discuss back-propagation, we use derivatives to minimise the prediction loss and adjust the weights and biases accordingly. Pairing a linear function with a non-linear one gives us the desired non-linearity while still letting us use the chain rule to compute the derivatives easily. But we'll come to that later.
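As a small illustration of the chain rule at work, this sketch differentiates the cross-entropy loss of a single sigmoid neuron with respect to its weight by multiplying the three local derivatives dL/da, da/dz and dz/dw, then checks the result against a finite-difference estimate. The values of w, b, x and y are hypothetical, and the function names are illustrative. Multiplying out the first two factors gives a - y, so the gradient simplifies to (a - y)x.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss(w, b, x, y):
    """Cross-entropy loss of a single sigmoid neuron on one example."""
    a = sigmoid(w * x + b)
    return -(y * math.log(a) + (1 - y) * math.log(1 - a))


def grad_chain_rule(w, b, x, y):
    # Chain rule: dL/dw = dL/da * da/dz * dz/dw
    a = sigmoid(w * x + b)
    dL_da = -(y / a) + (1 - y) / (1 - a)  # derivative of the loss w.r.t. the activation
    da_dz = a * (1 - a)                   # derivative of the sigmoid
    dz_dw = x                             # derivative of z = w*x + b w.r.t. w
    return dL_da * da_dz * dz_dw


def grad_numeric(w, b, x, y, h=1e-6):
    # Central finite difference, used only to check the chain-rule result.
    return (loss(w + h, b, x, y) - loss(w - h, b, x, y)) / (2 * h)


w, b, x, y = 0.3, -0.1, 2.0, 1
print(grad_chain_rule(w, b, x, y))
print(grad_numeric(w, b, x, y))
```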

We discussed using the sigmoid function as an activation function in Part 1. This function is defined as:

σ(z) = 1 / (1 + e^(-z))

The sigmoid function is no longer commonly used in hidden layers, as more efficient alternatives such as the ReLU (Rectified Linear Unit), covered below, are available. However, because its output is constrained between 0 and 1, it is still useful for the output of binary classification tasks.

Functions like this also suffer from what is commonly known as the vanishing gradient problem: the slope of the function approaches 0 at extreme input values. Recall that we use derivatives during back propagation to minimise the loss. When the magnitude of the input is large, these activation functions have a derivative (slope) close to zero. This results in very small updates to the weights, meaning very slow learning!
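To see the vanishing gradient numerically, this sketch evaluates the sigmoid's derivative, σ′(z) = σ(z)(1 - σ(z)), at a few inputs. The function names are illustrative. The slope peaks at 0.25 when z = 0 and falls below 0.0001 by z = ±10, which is the "very small updates" problem described above.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_derivative(z):
    # d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z)); its maximum is 0.25 at z = 0.
    s = sigmoid(z)
    return s * (1.0 - s)


for z in [0, 2, 5, 10, -10]:
    print(f"z = {z:>3}: slope = {sigmoid_derivative(z):.6f}")
```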

This is partly why we normally scale our inputs to values that are typically between -1 and 1. When I first started learning this topic, I had no idea why this was needed … a number is a number, right? Most literature only states that the network will perform better when the inputs are scaled to a range of either -1 to 1 or 0 to 1, depending on the scenario.

Understanding the vanishing gradient problem gives some insights into why standardising the inputs is recommended.
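A minimal sketch of two common ways to scale inputs, assuming the raw values are 8-bit pixel intensities (0 to 255); the sample values and function names are hypothetical. Min-max scaling maps values into a chosen range such as 0 to 1 or -1 to 1, while standardisation rescales them to mean 0 and standard deviation 1. Either way, small inputs combined with small initial weights keep z near zero, where the sigmoid's slope is largest.

```python
import statistics


def min_max_scale(values, lo=0.0, hi=1.0):
    """Linearly map values into the range [lo, hi]."""
    v_min, v_max = min(values), max(values)
    span = v_max - v_min
    if span == 0:
        return [lo for _ in values]
    return [lo + (v - v_min) * (hi - lo) / span for v in values]


def standardise(values):
    """Z-score: subtract the mean, then divide by the (population) standard deviation."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    if std == 0:
        return [0.0 for _ in values]
    return [(v - mean) / std for v in values]


pixels = [0, 64, 128, 255]  # hypothetical raw 8-bit pixel intensities
print(min_max_scale(pixels))             # into [0, 1]
print(min_max_scale(pixels, -1.0, 1.0))  # into [-1, 1]
print(standardise(pixels))               # mean 0, standard deviation 1
```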

A more common activation function used today is the Rectified Linear Unit (ReLU). Mathematically, ReLU can be defined as:

g(z) = z if z > 0, and g(z) = 0 otherwise

Fun fact: technically, the gradient does not exist when z is exactly equal to zero. But this is very unlikely to happen, and you can programmatically just set the gradient to 0 if it ever does. In practice, this is not an issue.

Programmatically, it can easily be defined as the max of 0 and the input value.

g(z) = max(0, z)

Another benefit of this activation function during back propagation is that its slope is simply 1 for any positive input (and 0 for any negative input). Though the vanishing gradient problem is avoided for positive values, it can still result in what is known as a “Dying ReLU” if there is a large negative bias on one of the neurons: the neuron is stuck on the negative side and always outputs 0. Since its gradient is then also 0, this neuron is effectively “dead”, as it no longer plays a part in the fitting process.
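The sketch below shows ReLU and its derivative, returning 0 at exactly z = 0 as suggested in the fun fact above. It also illustrates a dying ReLU with a hypothetical neuron whose bias of -10 keeps z negative for every input shown, so both its output and its gradient are always 0 and back propagation cannot move its weights.

```python
def relu(z):
    return max(0.0, z)


def relu_derivative(z):
    # Slope is 1 for positive inputs and 0 for negative ones.
    # At exactly z == 0 the derivative is undefined; we use the convention of returning 0.
    return 1.0 if z > 0 else 0.0


# A "dead" neuron: the large negative bias keeps z below zero for every input here,
# so both the output and the gradient are always 0.
w, b = 0.5, -10.0
inputs = [-2.0, 0.0, 1.0, 3.0]
zs = [w * x + b for x in inputs]
print([relu(z) for z in zs])             # [0.0, 0.0, 0.0, 0.0]
print([relu_derivative(z) for z in zs])  # [0.0, 0.0, 0.0, 0.0]
```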

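Some implementations use a hybrid approach known as Leaky ReLU to solve this issue, where negative inputs keep a small slope α instead of being set to 0:

g(z) = max(αz, z), with α a small constant such as 0.01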
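A comparable sketch of Leaky ReLU, assuming the common default slope α = 0.01 for negative inputs (other values are also used in practice). Unlike plain ReLU, negative inputs still receive a small gradient, so a neuron pushed to the negative side can recover.

```python
def leaky_relu(z, alpha=0.01):
    # alpha is a small slope for negative inputs; 0.01 is a common choice, not a fixed rule.
    return z if z > 0 else alpha * z


def leaky_relu_derivative(z, alpha=0.01):
    # Negative inputs still get a small non-zero gradient, so the weights keep updating.
    return 1.0 if z > 0 else alpha


print(leaky_relu(3.0), leaky_relu(-10.0))
print(leaky_relu_derivative(3.0), leaky_relu_derivative(-10.0))
```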

For this series, we will only use ReLU for the hidden layers, and the sigmoid function for the output node (again, only because we are doing binary classification). If you are interested in reading more about other hybrid ReLU activation functions, there is an interesting Medium post on them by Liu Danqing.

The network’s first attempt

Prior to the first training run, the weight matrices are initialised to very small random numbers.

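It is important that the weights start off as small random numbers. If all the weights were initialised to zero (with zero biases), every z = wx + b would be zero, so every neuron in a layer would produce the same output. More generally, equal weights do not work because the back-propagated errors would also be equal, so all the weights would be updated by the same amount and the neurons would never learn different features. Keeping the random values small also avoids pushing activations such as the sigmoid into their flat regions, where the gradient vanishes.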

The bias vectors can be initialised to 0 for the first run, since the random weights already break the symmetry.
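A minimal initialisation sketch in plain Python, using the layer sizes from the article (4096 inputs, 8 hidden neurons, 1 output). The helper name `initialise_parameters` and the 0.01 scale factor are assumptions for illustration; the article only says the weights should be small and random. Matrices are stored as lists of rows, so W1 has 8 rows of 4096 values and each bias is a column of zeros.

```python
import random


def initialise_parameters(n_x, n_h, n_y, scale=0.01, seed=None):
    """Small random weights and zero biases.

    Shapes: W1 (n_h, n_x), b1 (n_h, 1), W2 (n_y, n_h), b2 (n_y, 1).
    """
    rng = random.Random(seed)
    W1 = [[rng.gauss(0.0, 1.0) * scale for _ in range(n_x)] for _ in range(n_h)]
    b1 = [[0.0] for _ in range(n_h)]
    W2 = [[rng.gauss(0.0, 1.0) * scale for _ in range(n_h)] for _ in range(n_y)]
    b2 = [[0.0] for _ in range(n_y)]
    return {"W1": W1, "b1": b1, "W2": W2, "b2": b2}


params = initialise_parameters(n_x=4096, n_h=8, n_y=1, seed=0)
print(len(params["W1"]), len(params["W1"][0]))  # 8 4096
print(len(params["W2"]), len(params["W2"][0]))  # 1 8
```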

Clearly, the network’s first attempt is really nothing more than a random guess!

Forward propagation

Forward propagation is the process of using the inputs, weights, biases, and activation functions to compute the values at each stage. For example, in our architecture with a single hidden layer, the computation happens in this order:

  • Compute Z1
    Remember that z = wx + b, so we essentially perform a matrix multiplication. For a single training example x (shape (4096, 1)), z1 = W1 x + b1, where W1 has shape (8, 4096) and b1 has shape (8, 1), gives a column of shape (8, 1). Applied to the entire training set X (shape (4096, 209)), Z1 = W1 X + b1 is a matrix with shape (8, 209), with b1 added to every column.
  • Compute A1 (activation function)
    We are using ReLU on the hidden layer, so A1 will simply be computed as the max of 0 and the current value of Z1 (i.e. any negative numbers just become 0).
  • Feed A1 to the output layer
    The matrix A1 (also with shape (8, 209)) is now fed into the output layer as its input.
  • Compute Z[L]
    In this final output layer, the weights matrix W[L] has a shape of (1, 8). Another matrix multiplication, Z[L] = W[L] A1 + b[L], results in a final output with a shape of (1, 209), which is then passed into a sigmoid function.
  • Compute A[L] (final prediction)
    Passing Z[L] through the sigmoid function gives A[L], a (1, 209) matrix of values between 0 and 1: one predicted probability for each training example.
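The steps above can be sketched in plain Python with matrices stored as lists of rows. To keep the example readable it assumes a hypothetical network with 3 input features, 2 hidden neurons and 4 training examples, rather than the article's (4096, 209) input; the shapes follow the same pattern, with each column of X being one example. The helper names (`matmul`, `add_bias`, `forward`) are illustrative.

```python
import math


def matmul(A, B):
    """Multiply an (m, n) matrix by an (n, p) matrix, both stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def add_bias(Z, bias):
    # bias has shape (n, 1) and is added to every column (training example).
    return [[z + b_row[0] for z in z_row] for z_row, b_row in zip(Z, bias)]


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def forward(X, params):
    Z1 = add_bias(matmul(params["W1"], X), params["b1"])   # (n_h, m)
    A1 = [[relu(z) for z in row] for row in Z1]            # (n_h, m)
    Z2 = add_bias(matmul(params["W2"], A1), params["b2"])  # (1, m), i.e. Z[L]
    A2 = [[sigmoid(z) for z in row] for row in Z2]         # (1, m), i.e. A[L]
    return Z1, A1, Z2, A2


# Hypothetical toy data: 3 features (rows) x 4 training examples (columns).
X = [[1.0, 0.0, -1.0, 2.0],
     [0.5, 1.0, 0.0, -1.0],
     [0.0, 2.0, 1.0, 0.5]]
params = {
    "W1": [[0.2, -0.1, 0.4], [-0.3, 0.5, 0.1]],
    "b1": [[0.0], [0.1]],
    "W2": [[0.7, -0.6]],
    "b2": [[0.05]],
}
Z1, A1, Z2, A2 = forward(X, params)
print(len(A1), len(A1[0]))  # 2 4
print(len(A2), len(A2[0]))  # 1 4
```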

Compute the loss function

The prediction from this run depends on the model’s current weights and biases. It is a probability between 0 and 1, which can be thresholded (for example at 0.5) to give a 1 or 0 for this binary classification example. This is compared with the ground truth found in the labels that we extracted as part of the training set (read into the train_set_y variable).

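For binary classification, the loss function used for each training example is as follows:

L(ŷ, y) = -(y log(ŷ) + (1 - y) log(1 - ŷ))

where ŷ is the prediction A[L] for that example and y is its label.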

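Intuitively, you can see how this achieves what we need for both values of the ground truth. When y = 1, only -log(ŷ) remains, which is 0 when ŷ = 1 and grows without bound as ŷ approaches 0. When y = 0, only -log(1 - ŷ) remains, which penalises predictions close to 1. It helps to remember that the log curve cuts the x-axis at x = 1, i.e. log(1) = 0.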

The cost function J is the average of the losses over all m training examples, J = (1/m) Σᵢ L(ŷ⁽ⁱ⁾, y⁽ⁱ⁾), and this is the quantity we want to minimise (i.e. we take its derivative with respect to the trainable parameters).
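A sketch of the loss and cost in Python, following the binary cross-entropy loss L(ŷ, y) = -(y log(ŷ) + (1 - y) log(1 - ŷ)) and averaging it over the training examples. The small `eps` clipping is an assumption added so that `math.log` never receives exactly 0; it is a common numerical safeguard rather than part of the formula.

```python
import math


def binary_cross_entropy(y_hat, y, eps=1e-12):
    """Loss for one example: -(y*log(y_hat) + (1 - y)*log(1 - y_hat))."""
    # Clip y_hat away from 0 and 1 so log() never receives 0.
    y_hat = min(max(y_hat, eps), 1.0 - eps)
    return -(y * math.log(y_hat) + (1 - y) * math.log(1.0 - y_hat))


def cost(Y_hat, Y):
    """Cost J: the average loss over all m training examples."""
    m = len(Y)
    return sum(binary_cross_entropy(y_hat, y) for y_hat, y in zip(Y_hat, Y)) / m


print(binary_cross_entropy(0.9, 1))  # confident and right: small loss
print(binary_cross_entropy(0.1, 1))  # confident and wrong: large loss
print(cost([0.9, 0.2, 0.7], [1, 0, 1]))
```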

Using back propagation to adjust the weights and biases

The network uses back propagation to adjust the weights and biases. To understand this intuitively, consider the following.

  • We want to minimise the cost function with respect to the trainable parameters, i.e. what values should W and b be adjusted to? This diagram shows the functions used in the forward propagation, and the corresponding derivatives.
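  • To minimise the cost, we take its derivative with respect to each trainable parameter. In this case, that means we want the derivatives of the cost J with respect to the weights and biases of both layers: ∂J/∂W1, ∂J/∂b1, ∂J/∂W[L] and ∂J/∂b[L].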

The derivation is outside the scope of this article, though other writers have covered it in full. I enjoyed reading the detailed derivation provided by Patrick David in an article on the topic.

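  • We use the chain rule of derivatives to compute the following, writing dW as shorthand for ∂J/∂W, with m = 209 training examples, * for element-wise multiplication and g′ for the derivative of ReLU:
    dZ[L] = A[L] - Y
    dW[L] = (1/m) dZ[L] A1ᵀ
    db[L] = (1/m) Σ dZ[L] (summed over the training examples)
    dZ1 = W[L]ᵀ dZ[L] * g′(Z1)
    dW1 = (1/m) dZ1 Xᵀ
    db1 = (1/m) Σ dZ1 (summed over the training examples)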
  • We then update each of these weights and biases, for every layer:
    W := W - α ∂J/∂W
    b := b - α ∂J/∂b

Here α is the learning rate, a hyperparameter. We will cover it in detail separately, but it controls how far each update nudges the parameters. Too small, and the learning process will be slow. Too large, and you might overshoot the minimum. There’s an art and science to tuning this.
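The effect of the learning rate is easiest to see on a hypothetical one-parameter cost, J(w) = (w - 3)², whose minimum is at w = 3 and whose derivative is 2(w - 3). This sketch applies the update w := w - α dJ/dw for 20 steps with three illustrative learning rates; the function names are assumptions for the example.

```python
def gradient_descent(grad, w0, alpha, steps):
    """Repeatedly apply the update w := w - alpha * dJ/dw."""
    w = w0
    for _ in range(steps):
        w = w - alpha * grad(w)
    return w


def grad(w):
    # Derivative of the hypothetical cost J(w) = (w - 3)^2.
    return 2.0 * (w - 3.0)


for alpha in (0.01, 0.1, 1.1):
    print(alpha, gradient_descent(grad, w0=0.0, alpha=alpha, steps=20))
```

Starting from w = 0, α = 0.01 is still far from 3 after 20 steps, α = 0.1 gets close, and α = 1.1 overshoots further on every step and diverges.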

Repeat

The process above then repeats for the number of iterations that you have defined. The number of iterations is also a hyperparameter of the model, like the learning rate above. Hyperparameters are not trainable parameters, and tuning them is both an art and a science. We will cover more of this in a later article.
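To tie the steps together, here is a small end-to-end sketch of the loop: forward propagation, cost, back propagation and the update, repeated for a fixed number of iterations. It is a plain-Python illustration under several assumptions: a hypothetical 2-feature toy dataset, 4 hidden neurons, a learning rate of 0.1, 1000 iterations, and an initial weight scale of 0.5 (larger than you would use for 4096 inputs, chosen so the tiny network trains quickly). It loops over examples instead of using matrix operations, but accumulates the standard chain-rule gradients for a ReLU hidden layer and a sigmoid output, including the output error a - y.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def forward(params, x):
    """Forward propagation for a single example x (a list of n_x features)."""
    W1, b1, W2, b2 = params
    z1 = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W1, b1)]
    a1 = [max(0.0, z) for z in z1]                         # ReLU hidden layer
    a2 = sigmoid(sum(w * a for w, a in zip(W2, a1)) + b2)  # sigmoid output
    return z1, a1, a2


def train(X, Y, n_h=4, alpha=0.1, iterations=1000, seed=1):
    """Batch gradient descent: forward pass, cost, back propagation, update, repeat."""
    rng = random.Random(seed)
    n_x, m = len(X[0]), len(X)
    W1 = [[rng.gauss(0.0, 1.0) * 0.5 for _ in range(n_x)] for _ in range(n_h)]
    b1 = [0.0] * n_h
    W2 = [rng.gauss(0.0, 1.0) * 0.5 for _ in range(n_h)]
    b2 = 0.0
    costs = []
    for _ in range(iterations):
        dW1 = [[0.0] * n_x for _ in range(n_h)]
        db1 = [0.0] * n_h
        dW2 = [0.0] * n_h
        db2 = 0.0
        total_loss = 0.0
        for x, y in zip(X, Y):
            z1, a1, a2 = forward((W1, b1, W2, b2), x)
            p = min(max(a2, 1e-12), 1.0 - 1e-12)  # clip before taking logs
            total_loss += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
            # Chain rule: for sigmoid + cross-entropy the output error is a2 - y.
            dz2 = a2 - y
            db2 += dz2
            for j in range(n_h):
                dW2[j] += dz2 * a1[j]
                dz1 = W2[j] * dz2 * (1.0 if z1[j] > 0 else 0.0)  # ReLU slope
                db1[j] += dz1
                for i in range(n_x):
                    dW1[j][i] += dz1 * x[i]
        costs.append(total_loss / m)
        # Update every parameter using the gradient averaged over the m examples.
        for j in range(n_h):
            W2[j] -= alpha * dW2[j] / m
            b1[j] -= alpha * db1[j] / m
            for i in range(n_x):
                W1[j][i] -= alpha * dW1[j][i] / m
        b2 -= alpha * db2 / m
    return (W1, b1, W2, b2), costs


# Hypothetical toy data: label 1 when x0 + x1 > 0, else 0.
X = [[1.0, 1.0], [2.0, 0.5], [0.5, 2.0], [-1.0, -1.0], [-2.0, -0.5], [-0.5, -2.0]]
Y = [1, 1, 1, 0, 0, 0]
params, costs = train(X, Y)
print(f"cost: first iteration {costs[0]:.4f}, last iteration {costs[-1]:.4f}")
predictions = [1 if forward(params, x)[2] > 0.5 else 0 for x in X]
print(predictions)
```

The cost for the last iteration should be lower than for the first; the exact values depend on the random seed and the assumed hyperparameters.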

Conclusion

This is a very heavy topic, but I hope that it shows at a high level how a neural network makes predictions through forward propagation, and then nudges the weights and biases during back propagation by taking the derivatives of the cost function with respect to them.

It is mainly through back propagation that the weights and biases are slowly nudged towards values that minimise the cost, which in turn tends to improve accuracy. That is why you often hear that training a neural network is about minimising the cost rather than directly maximising accuracy: the cost has useful derivatives that back propagation can follow, whereas accuracy only changes when a prediction flips between 0 and 1, so its derivative is zero almost everywhere.

If you do not understand the concepts the first time round, do not despair. It took me a while to understand back propagation, and I must admit that I am still discovering new things while writing this article.

Writing these articles helps me firm up my understanding, and I hope that you have learnt something from them.


