
【ML】Understanding Batch Gradient Descent

Published on 2024/09/23

1. What is batch gradient descent?

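Batch gradient descent is a technique for training machine learning models in which the training data is divided into multiple batches. (Strictly speaking, this is usually called mini-batch gradient descent; the name "batch gradient descent" is often reserved for the full-batch case described in 1.1.)
Simply put, a fixed number of samples (e.g., 32) is passed through the model, the gradient of the loss is computed for those samples, and the average of these gradients is used to update the model weights.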
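
The procedure above can be sketched in Python with NumPy. This is a minimal illustration, not the author's code: it assumes a toy linear-regression model with mean squared error loss on synthetic data, and the names `minibatch_step` and `true_w`, as well as the learning rate of 0.1, are made up for the example. Each call averages the gradient over one batch of 32 samples and applies a single weight update.

```python
import numpy as np

def minibatch_step(w, X_batch, y_batch, lr=0.1):
    """One mini-batch gradient descent update for linear regression (MSE loss).

    w: weights, shape (d,); X_batch: shape (B, d); y_batch: shape (B,).
    """
    B = X_batch.shape[0]
    residual = X_batch @ w - y_batch  # prediction error for each sample in the batch
    # Gradient of the batch loss (1/B) * sum(residual**2) with respect to w.
    # The 1/B factor is the "average of the gradient information" over the batch.
    grad = (2.0 / B) * X_batch.T @ residual
    return w - lr * grad

# Synthetic, noise-free data generated from known weights (for illustration only).
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
X = rng.normal(size=(1024, 2))
y = X @ true_w

w = np.zeros(2)
batch_size = 32
for epoch in range(20):
    for start in range(0, len(X), batch_size):
        w = minibatch_step(w, X[start:start + batch_size], y[start:start + batch_size])

print(w)  # should be close to [2.0, -1.0]
```

With 1,024 samples and a batch size of 32, each epoch performs 32 weight updates, each based on the averaged gradient of one batch.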

With batch gradient descent, it is important to choose an appropriate batch size (the number of samples in each batch). The two extreme cases are described below.

1.1 Full Batch Learning

Full batch learning means that the batch size equals the size of the entire training set.
In this case, the gradient is computed over all of the data, and the weights are updated only once per epoch.

The disadvantages of this technique are as follows:
Increased memory consumption: The values needed to compute the gradient (e.g., intermediate activations) must be kept in memory for every sample at once.
Less frequent weight updates: Because the weights are updated only once per epoch, training takes longer.
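
To make the "less frequent updates" point concrete, the number of weight updates per epoch is the dataset size divided by the batch size, rounded up. The dataset size of 50,000 below is a hypothetical value chosen for illustration.

```python
import math

def updates_per_epoch(num_samples, batch_size):
    """Number of weight updates in one pass over the data (a final partial batch counts as one)."""
    return math.ceil(num_samples / batch_size)

N = 50_000  # hypothetical dataset size
for b in (N, 256, 32, 1):
    print(f"batch_size={b:>6}: {updates_per_epoch(N, b):>6} updates per epoch")
```

For 50,000 samples this gives 1 update per epoch with full batch learning, 196 with a batch size of 256, 1,563 with 32, and 50,000 with a batch size of 1.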

1.2 Online Learning

Online learning (sequential learning) refers to training with a batch size of 1.
In this case, the model weights are updated every time a single sample is passed through the model.

The disadvantages of this method are as follows:
Unstable gradient: The gradient computed from a single sample varies greatly from one update to the next, which can make training unstable and cause oscillations.
Computational inefficiency: Processing one sample per update cannot take advantage of the GPU's ability to compute many samples in parallel, so training is slow.
Batch normalization cannot be used: Batch normalization, which is widely used to stabilize training and can help reduce overfitting and improve model performance, does not work properly, because the batch mean and variance cannot be meaningfully estimated from a single sample.
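
The first and third drawbacks can be checked numerically with a small NumPy experiment. This is an illustrative sketch on synthetic 1-D regression data, and all names and values are assumptions: part 1 measures how much the averaged gradient fluctuates between random batches of size 1 and size 32, and part 2 applies a simplified training-mode batch normalization (no learned scale or shift) to a batch that contains only one sample.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- 1) Gradient noise: batch size 1 vs. batch size 32 ---
# Toy 1-D linear regression with noisy targets (synthetic data, for illustration only).
X = rng.normal(size=(10_000, 1))
y = 3.0 * X[:, 0] + rng.normal(scale=0.5, size=10_000)
w = np.array([0.0])  # evaluate gradients at a fixed, untrained weight

# Per-sample gradient of the squared error (x*w - y)^2 with respect to w, shape (N, 1).
per_sample_grad = 2.0 * (X @ w - y)[:, None] * X

def batch_grad_std(batch_size, trials=2_000):
    """Spread (standard deviation) of the averaged gradient over randomly drawn batches."""
    estimates = [
        per_sample_grad[rng.choice(len(X), size=batch_size, replace=False)].mean()
        for _ in range(trials)
    ]
    return float(np.std(estimates))

std_1 = batch_grad_std(1)
std_32 = batch_grad_std(32)
print(f"std (batch=1): {std_1:.2f}, std (batch=32): {std_32:.2f}")
# Averaging B samples divides the variance by roughly B, so the ratio should be near sqrt(32).

# --- 2) Batch normalization with a batch of one sample ---
def batch_norm_train(x, eps=1e-5):
    """Simplified training-mode batch norm: normalize each feature by the batch's own statistics."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

single = np.array([[0.7, -1.2, 3.4]])  # a "batch" containing only one sample
print(batch_norm_train(single))  # all zeros: the sample is its own batch mean
```

The spread for batch size 1 should come out several times larger than for batch size 32, and the single-sample batch is normalized to all zeros, so per-feature batch statistics carry no useful information at batch size 1.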

However, unlike full batch learning, online learning is the right choice in some situations, such as:
Real-time sequential learning: When new data is generated continuously (e.g., streaming data) and should be reflected in the model immediately, learning with a batch size of 1 is appropriate. Because data keeps arriving quickly, updating the model sample by sample is effective.
Memory savings: If memory is so limited that the model cannot be trained even with a small batch, training with a batch size of 1 may be used.
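
A streaming setup of this kind can be sketched as follows. This is a minimal example under assumed conditions: `sample_stream` is a hypothetical stand-in for a real data feed, the model is a noise-free linear regression, and the learning rate is chosen for this toy problem only. The weights are updated immediately after each sample arrives, so no batch ever has to be stored.

```python
import numpy as np

def sample_stream(rng, true_w, n):
    """Hypothetical live data source: yields one (x, y) pair at a time."""
    for _ in range(n):
        x = rng.normal(size=true_w.shape)
        yield x, float(x @ true_w)

def online_sgd(stream, dim, lr=0.02):
    """Online learning: update the weights right after each incoming sample (batch size 1)."""
    w = np.zeros(dim)
    for x, y in stream:
        grad = 2.0 * (x @ w - y) * x  # gradient of (x.w - y)^2 for this single sample
        w -= lr * grad
    return w

rng = np.random.default_rng(1)
true_w = np.array([0.5, 1.5, -2.0])
w = online_sgd(sample_stream(rng, true_w, 5_000), dim=3)
print(w)  # should end up close to [0.5, 1.5, -2.0]
```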

1.3 Summary

From the above, we can see that it is generally important to choose a batch size that is neither too large nor too small. In practice, my impression is that values from 32 to 256 are common.

2. Purpose

If the gradient is computed and the weights are updated after every single sample, without using batches, two main problems arise:

  • Increased training cost and time
    Computing gradients by backpropagation is computationally demanding, and doing it separately for every sample forgoes processing many samples in parallel, so training time increases greatly.
  • Locally optimal solution
    Updating the weights after every sample makes it more likely that the weights are fitted to each individual sample, which can lead to a locally optimal solution.

To avoid these problems, batch learning is used.
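
The link between batching and parallel computation can be seen in a small NumPy sketch (synthetic data, with a linear-regression gradient assumed for illustration). Averaging per-sample gradients computed one by one in a Python loop gives the same result as a single matrix product over the whole batch; the batched form is what lets numerical libraries and GPUs process many samples at once. Timing is not measured here.

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.normal(size=(256, 4))  # one batch of 256 samples with 4 features
y = rng.normal(size=256)
w = rng.normal(size=4)

# Per-sample approach: compute each sample's gradient of (x.w - y)^2 separately.
loop_grads = [2.0 * (x_i @ w - y_i) * x_i for x_i, y_i in zip(X, y)]
loop_avg = np.mean(loop_grads, axis=0)

# Batched approach: the same averaged gradient from a single vectorized matrix product.
batch_avg = (2.0 / len(X)) * X.T @ (X @ w - y)

print(np.allclose(loop_avg, batch_avg))  # True
```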

3. Description

Batch optimization is a method in which the model weights are updated using the average of the gradients computed over the data inside the batch.
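
Written as an update rule (the symbols are introduced here for illustration and do not appear in the original text): with weights θ_t at step t, learning rate η, a batch 𝓑_t of B samples, and a per-sample loss ℓ, one update is

```latex
\theta_{t+1} = \theta_t - \eta \, \frac{1}{B} \sum_{i \in \mathcal{B}_t} \nabla_\theta \, \ell(x_i, y_i; \theta_t)
```

Setting B = 1 gives online learning (1.2), and setting B equal to the dataset size N gives full batch learning (1.1).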

The goal of a machine learning model is to “minimize the loss function on real-world data,” and achieving this exactly would require collecting all of the data in the real-world environment. Since that is not realistic, the idea of batch learning is to expose the model to as many patterns of data as possible.

The shape of the loss function changes depending on the data used to update the weights. By using the average over a fixed number of training samples for each update and repeating this process, the model sees a much greater variety of data patterns, and overfitting to individual samples can be prevented.
Because the shape of the loss function changes with every update, the model is less likely to get stuck in a local optimum, and because the gradient is averaged over multiple samples, the resulting parameters tend to generalize well.
This is a major advantage of batch learning.
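
One common way to get different batch compositions on every pass over the data is to reshuffle the training set at the start of each epoch and then cut it into consecutive batches. The article does not specify how batches are formed, so this NumPy sketch (`epoch_batches` is a hypothetical helper) shows only that assumed approach.

```python
import numpy as np

def epoch_batches(num_samples, batch_size, rng):
    """Yield index arrays for one epoch: reshuffle all indices, then split into consecutive batches."""
    order = rng.permutation(num_samples)
    for start in range(0, num_samples, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(3)
for epoch in range(2):
    batches = list(epoch_batches(10, 4, rng))
    print(f"epoch {epoch}:", [b.tolist() for b in batches])
```

With 10 samples and a batch size of 4, each epoch yields batches of sizes 4, 4, and 2 that together cover every sample exactly once, in a freshly shuffled order.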

4. Summary

Batch gradient descent is a method for improving model performance, generalization, and computational efficiency, and it is important to choose an appropriate batch size based on the available memory and other factors.
Personally, my impression is that, unless the dataset is small, increasing the batch size as far as possible up to about 256 improves both computation speed and generalization.

That's all for this article. Thank you for reading.
