Long Short-Term Memory
How LSTM Mitigated the Vanishing Gradients But Not the Exploding Gradients
In theory, RNNs (Recurrent Neural Networks) should be able to extract features (hidden states) from long sequential data. In practice, researchers had a hard time training basic RNNs with BPTT (Back-Propagation Through Time).
The main culprits are the vanishing and exploding gradient problems. LSTM (Long Short-Term Memory) mitigated these problems enough to make recurrent networks more trainable, but it did not entirely solve them. So what issues remain with LSTM?
To understand these issues, we first need to know how BPTT works. That will make it clear how vanishing and exploding gradients occur. We can then appreciate why LSTM works better than the basic RNN, especially on long sequences, and finally see why LSTM does not completely solve the problems.
In this article, we discuss the following topics:
- BPTT (Back-Propagation Through Time)
- Vanishing and Exploding Gradients
- LSTM (Long Short-Term Memory)
- BPTT Through LSTM Cells
BPTT (Back-Propagation Through Time)
Let’s use a basic RNN to discuss how BPTT works. Suppose the RNN uses the final hidden state to make a prediction (e.g., for regression or classification), as shown below:
The model processes the input data sequentially:
At the first step, the RNN takes the initial hidden state (all zeros) and the first element of the input sequence, and produces the first step’s hidden state. At the second step, the network takes the first step’s hidden state and the second input element, and produces the second step’s hidden state. In general, at each step the network combines the previous hidden state with the current input element to produce that step’s hidden state.
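The step-by-step process above can be sketched in plain Python. This is a minimal sketch, not the article's implementation: it assumes a tanh activation and uses hypothetical parameter names (`W_xh`, `W_hh`, `b_h` for the recurrent cell and `W_hy`, `b_y` for a linear output layer). It keeps every hidden state so the sequential dependence is visible, but only the last one feeds the prediction.

```python
import math

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def rnn_forward(xs, W_xh, W_hh, b_h):
    """Run a basic RNN over the sequence xs and return every hidden state.

    Assumed update rule (hypothetical names):
        h_t = tanh(W_hh @ h_{t-1} + W_xh @ x_t + b_h)
    """
    hidden_size = len(W_hh)
    h = [0.0] * hidden_size  # initial hidden state: all zeros
    hidden_states = []
    for x in xs:  # one step per input element, in order
        pre = [a + b + c for a, b, c in zip(matvec(W_hh, h), matvec(W_xh, x), b_h)]
        h = [math.tanh(p) for p in pre]  # new hidden state for this step
        hidden_states.append(h)
    return hidden_states

def predict(xs, W_xh, W_hh, b_h, W_hy, b_y):
    """Predict from the final hidden state only (e.g., a regression output)."""
    h_final = rnn_forward(xs, W_xh, W_hh, b_h)[-1]
    return [y + b for y, b in zip(matvec(W_hy, h_final), b_y)]

# Toy example with hypothetical weights: input size 1, hidden size 2, output size 1.
W_xh = [[0.5], [-0.3]]
W_hh = [[0.1, 0.2], [0.0, 0.4]]
b_h = [0.0, 0.0]
W_hy = [[1.0, 1.0]]
b_y = [0.0]
xs = [[1.0], [0.5], [-1.0]]

states = rnn_forward(xs, W_xh, W_hh, b_h)
print(len(states))  # 3: one hidden state per input element
print(predict(xs, W_xh, W_hh, b_h, W_hy, b_y))
```

Because each hidden state is computed from the previous one, the loss at the end depends on the first input only through the whole chain of steps in this loop; BPTT is what propagates the gradient back through that chain.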
We can mathematically summarize the process as follows: