이 포스트는 Do it 딥러닝 교과서 (윤성진 저)를 참고하여 만들어졌음!
LSTM (Long Short Term Memory)
LSTM에 들어가기 전에 기존 RNN의 문제점에 대해 살펴본다. 이전까지 배운 RNN은 Vanilla RNN으로, 순환신경망의 기초 동작원리를 설명한다고 볼 수 있다. Vanilla RNN은 다음과 같은 문제점을 갖는다.
1. 장기의존성 (long-term dependency) : 시간 상 멀리 떨어진 입력의 영향이 약해진다. 즉 어떤 입력 데이터가 실제로 멀리 떨어진 입력에 대해 장기 의존성이 있음에도 불구하고 Vanilla RNN으로는 이를 파악할 수 없다. 따라서 순차열이 길어질수록 데이터의 정보가 서서히 사라진다. → 모델의 기억력이 짧다!
2. Gradient Vanishing and Exploding : Vanilla RNN의 구조는 아래와 같다.
입력데이터 $x_i$가 입력될 때 마다 매번 가중치 행렬 $W$가 곱해진다. 이렇게 행렬이 반복적으로 곱해지면 행렬의 거듭제곱이 되는데, 이때 각 차원의 고윳값 크기가 1 초과이면 발산, 1 미만이면 0으로 수렴하게 된다. 따라서 가중치 행렬을 여러 번 곱하는 작업을 피하는 방향으로 개선이 이루어져야 한다.
3. 사실 Gradient Exploding은 간단히 막을 수 있다? → Gradient clipping
용어를 더 잘 이해하기 위해 Clip이 무엇인지 먼저 알아본다!
(사진출처 : clip (【동사】깎다 ) 뜻, 용법, 그리고 예문 | Engoo Words)
Clip은 '깎다'를 의미한다. numpy.clip(array, min, max)라는 넘파이 문법도 존재하는데, min보다 작은 값을 min으로 깎고, max보다 큰 값을 max로 깎는다고 이해하면 편하다. 이것을 클리핑(Clipping)이라고 부른다.
Gradient clipping은 위와 마찬가지로 gradient가 일정 크기 이상으로 커지지 못하게 클리핑 한다.
$$\vec{g} \leftarrow \frac{\vec{g}}{||\vec{g} ||}, \quad if || \vec{g}|| > v $$
일정 크기 ($v$)보다 gradient 값이 커지면 그 크기를 $v$로 맞추어 gradient exploding을 막을 수 있다.
LSTM 구조
LSTM은 Vanilla RNN의 gradient vanishing 문제를 해결하기 위해 개발되었다. (gradient exploding은 clipping으로 해결한다.) LSTM이 어떤 구조를 가졌기에 gradient vanishing을 해결할 수 있는지 알아보도록 한다.
Vanilla RNN에서 Gradient vanishng을 발생하는 원인은 바로 가중치 행렬 $W$를 연속으로 곱하는 것이었다. LSTM은 cell state를 연결하는 과정에서의 $W$ 연산을 제거하여 gradient vanishing, 장기의존성 문제를 해결하였다. (Hidden state $\textbf{H}$와 input $x$를 곱하는 과정에서는 $W$이 존재한다.)
위 LSTM cell 그림을 더 자세하게 살펴본다.
- Cell state (맨 위 $\textbf{C}$) : 셀 상태 (장기 기억)
- Hidden state ($\textbf{H}$) : 은닉 상태 (단기 기억)
- Forget gate ($\textbf{F}$)
- Input gate ($\textbf{I}$)
- Candidate memory ($\tilde{\textbf{C}}$) : 새로운 기억
- Output gate ($\textbf{O}$) : 출력
- $\sigma$ = Sigmoid function
LSTM cell의 동작 과정
아래 설명된 내용은 이 글을 참고하였다. Long Short Term Memory | Architecture Of LSTM (analyticsvidhya.com)
1. Forget gate
예를 들어 Bob is a nice person. Dan on the other hand is evil 이라는 문장이 있다고 해보자. 단어가 순차적으로 입력되다가 person이 입력되었을 때 문장이 끝난다. 뒤이어 새로운 문장이 시작될 것이므로 forget gate는 'Bob ~~' 문장을 잊어야 하는데, 이 과정이 forget gate에서 수행된다.
$$f_t = \sigma(W_{hf}\vec{h}_{t-1} + W_{xf}\vec{x}_t)$$
Forget gate는 이전 시간의 hidden state $\vec{h}_{t-1}$과 현재 시간의 입력데이터 $\vec{x}_t$를 입력받은 후, 이들을 가중합산한 결과를 sigmoid에 입력한다. Sigmoid function의 출력은 [0,1]이므로 만약 forget gate의 출력이 0이면 Cell state에 t-1의 정보가 아예 들어가지 않게 되고, 출력이 1이면 t-1 정보를 그대로 현재 Cell state에 추가하게 된다.
t-1 시간의 입력이 다음 시간 t의 cell state에 포함되어야 하는지 아닌지는 $W_{hf}, W_{xf}$가 학습되며 결정하는 것이다.
2. Input gate & Candidate memory
Input gate, candidate memory는 입력된 데이터를 cell state에 저장하는 역할을 수행한다.
예를들어 Bob knows swimming, He told me over the phone that he had served the navy for 4 long years 라는 문장이 입력되었다고 해보자. 이 문장에서는 Bob know swimming that he has served the Navy for four years 가 핵심 정보이고, he told all this over the phone은 그다지 중요하지 않아보인다. 즉 한 문장이 입력되었을 때, 그 중 중요한 정보는 Cell state에 넘기고 나머지는 걸러내는 작업이 필요한데, 이것이 Input gate에서 수행된다.
$$\begin{align} &i_t = \sigma(W_{hi}\vec{h}_{t-1}+W_{xi}\vec{x}_i) \\ &g_t = \text{tanh}(W_{hg}\vec{h}_{t-1}+W_{xg}\vec{x}_t) \end{align}$$
- Input gate는 $\vec{h}_{t-1}, \vec{x}_t$를 입력받은 후 이를 가중합산($W_{hi}, W_{xi}$)하여 sigmoid를 통과한다. → forget gate와 아주 같은 작업을 수행한다.
- $\vec{h}_{t-1}, \vec{x}_t$는 $W_{hg}, W_{xg}$를 이용하여 가중합산한 후 tanh를 거친다.
- 위 두 결과를 요소별로 곱한 결과는 그대로 Cell state에 남아서 이번 시간 t에 기억할 새로운 값이 된다.
3. Output gate
LSTM의 어떤 Cell에서 output이 출력되어야 하는 경우가 있다. 이때 Output gate를 사용한다.
예를 들어 다음과 같은 문장이 있다고 해보자. 'Bob fought single handedly with the enemy and died for his country. For his contributions brave ____'
현재 입력된 단어가 brave라고 했을 때, brave는 형용사이므로 그 뒤에 나올 단어는 높은 확률로 명사라는 것을 알 수 있다. 여기서는 Bob이 가장 옳은 output이 될 것이다. 이처럼 현재 입력된 단어 $x_t$를 보고 그 다음 나올 수 있는 단어 $x_{t+1}$를 추측하는 과정이 output gate에서 수행된다.
- Cell state에 tanh를 적용하여 [-1, 1] 범위에서 정의되도록 scaling 한다.
- $\vec{h}_{t-1}, x_{t}$를 가중합산한 결과에 sigmoid를 통과시킨다.
- 위 두 결과를 요소별로 곱한 것이 바로 현재시간 $t$의 hidden state = output이 된다.
이때 Cell state $C$를 장기 기억, hidden state $h$를 단기 기억이라고 생각할 수 있다.
Cell state (장기 기억) : Cell state는 forget gate, input gate의 매 time에서의 결과를 단순히 더한 것으로 생성된다. 따라서 가중치 $W$에 의해 정보가 뭉게지지 않으므로 오래 전에 입력되었던 정보를 비교적 길게 저장할 수 있다.
새로운 사건이 입력될 때마다 조금씩 강화되거나 약화될 수 있으며 자주 사용하지 않으면 정보를 잊을 수 있다.
Hidden state (단기 기억) : 새로운 정보가 입력될 때마다 현재 Cell state와 t-1의 hidden state $h_{t-1}$, 그리고 현재 입력을 전부 결합하여 현재 시간의 hidden state $h_t$를 생성한다. $h_t$를 생성하는 연산은 단순 덧셈이 아니기 때문에 최근에 일어난 사건은 빠르게 기억하지만, 상황이 전환되면 이전 정보를 빠르게 잊는다.
Gradient vanishing이 생기지 않는 이유
Gradient vanishing을 발생시키는 범인은 바로 가중치 행렬 $W$의 반복된 곱 연산이었다. 그렇다면 LSTM에서는 이것이 사라졌을까?
$C_t$를 $C_{t-1}$에 대해 지역미분을 구해보자. $C_t=f_t \odot C_{t-1} + i_t \odot g_t $이므로 지역미분은 다음과 같이 얻어진다.
$$\frac{\partial C_t}{\partial C_{t-1}}=f_t $$
현재 시간 $t$부터 초기 시점 $0$까지의 연쇄미분은 다음과 같이 구할 수 있다.
$$\frac{\partial C_t}{\partial C_0} = \frac{\partial C_t}{\partial C_{t-1}} \cdot \frac{\partial C_{t-1}}{\partial C_{t-2}} \cdot \cdots \cdot \frac{\partial C_1}{\partial C_0} = f_t \cdot f_{t-1} \cdot f_1 = \Pi^{t}_{i-1}f_i $$
따라서 Cell state의 역방향 미분은 forget gate $f_i$의 곱으로 구성되며 $W$를 포함하지 않으므로 gradient vanishing 문제가 해결되었다. (forget gate $f_i$ 값은 각 cell마다 다르게 출력되므로 $f_i$를 곱은 gradient vanishing을 발생할 가능성이 작다.)
GRU (Gated Recurrent Unit)
GRU는 LSTM의 장점을 유지하면서 gate 구조를 단순하게 만든 순환 신경망이다. LSTM은 Cell state가 있어 약간 복잡했는데 GRU는 Cell state를 없애고 hidden state가 장기 기억, 단기 기억을 모두 기억하도록 만들었다.
t-1 Hidden state $H_{t-1}$이 위, 아래 두 경로로 나뉘어진 후 다시 합쳐지는데, 위 경로는 $W$ 연산이 없으므로 장기 기억을 담당하고, 아래 경로는 무수한 $W$ 연산을 통해 단기 기억을 담당한다.
더하여 LSTM에서는 3가지 gate를 사용했지만 GRU는 Reset gate, Update gate 두 가지를 사용한다.
Reset gate $\vec{r}_t$와 Update gate $\vec{z}_t$는 위와 같이 계산된다. Reset gate와 Update gate는 LSTM의 forget gate, Input gate와 수식이 동일하다. 이들은 sigmoid function을 거치며 중요한 정보만을 남긴다.
새로운 입력으로 만들어진 새로운 hidden state $tilde{h}_t$는 위와 같이 계산된다. 과거의 기억 $\vec{h}_{t-1}$에 reset gate를 요소별 곱하여 $\vec{h}_{t-1}$에서 필요한 부분을 선택한다.
과거의 기억 $\vec{h}_{t-1}$이 입력되었고, 새롭게 추가된 기억 $\tilde{h}_t$를 얻었으므로 이들을 이용하여 새롭게 남길 기억(현재 기억)을 만든다. Update gate $\vec{z}_t$를 가중치로 사용하여 이전 상태 $h_{t-1}$와 새로운 상태 $\tilde{h}_t$의 반영비율울 결정한다.
'딥러닝' 카테고리의 다른 글
[LLM] 2. 허깅페이스 트랜스포머 모델 학습하기 (0) | 2024.11.01 |
---|---|
[LLM] 1. 임베딩, 어텐션, 트랜스포머 모델들 (1) | 2024.11.01 |
13. RNN 코드실습 (0) | 2023.10.31 |
12. RNN (Recurrent Neural Network) 이론 (0) | 2023.10.06 |
VGG net 논문리뷰 + 실습 (0) | 2023.10.04 |