지난 시간에 RNN에 대해 간단히 리뷰해보았는데, 오늘은 LSTM에 대해 리뷰하고자 한다.
LSTM은 RNN처럼 시계열 데이터를 처리할 때 사용하는 신경망으로 RNN의 단점을 극복하기 위해 등장한 신경망이다.
RNN의 약점
RNN은 시계열 데이터를 처리함에 있어서 '장기 의존성'에 대한 약점이 존재했다.
예를 들어 아래와 같은 RNN이 존재한다고 가정하자.
아래와 같이 Sequence가 짧은 경우에는 이전 정보들이 잘 업데이트가 되어갈 수 있다.
그러나, 만약 Sequence가 점점 길어지게 된다면, 초기정보들은 반복적인 곱하기 연산(Chain-Rule)으로 인해 기울기가 0으로 소실되어 간다. 이를 Vanishing Gradient 현상이라고 하는데, 이러한 현상은 시간적으로 먼 입력값일수록 학습에 미치는 영향이 줄어듦을 의미한다.
그래서 이러한 RNN의 약점을 극복하기 위해 LSTM이 등장했다.
LSTM 기본 구조
LSTM의 기본적인 구조는 아래와 같다.
우선적으로 말하자면, LSTM은 Cell state를 통해 장기 의존성 문제를 극복하며 4개의 Gate를 가지고 있다. 먼저 각각의 Gate가 어떻게 작동하는지 알아보자.
- Forget Gate
Forget Gate는 어떤 정보를 지울 것인지를 결정하는 Gate다. Forget Gate로 들어오는 정보는 지난 은닉상태 ($h_{t-1}$)와 현재 입력값($x_t$)를 concatenate한 값이다. 이렇게 concatenate된 값($X_t$)은 이전 은닉상태와 현재 입력값이 한데 묶여진 일종의 단기기억처럼 된다. (이 $X_t$는 LSTM내의 모든 Gate의 입력값이 된다는 것을 알고 가자.)
Forget Gate에서 주목해야 할 점은 시그모이드 함수이다. 시그모이드 함수는 어떤 입력이 들어와도 0~1 사이의 값을 리턴하는 함수인데, Forget Gate는 들어오는 값($X_t$)을 입력값으로 받아 가중치($W_f$)를 곱한 뒤, 0~1 사이의 값으로 바꿔주는 역할을 한다. (이때, 마이너스에 가까운 값은 0이 될 것이고 양수는 1에 가까워질 것이다.
그 후 0~1 사이로 바뀐 값들은 셀상태($C_{t-1}$)의 값을 만나 Element-wise 곱셈연산을 하여 원소가 1인 값들은 남기고 0인 값들은 망각한다. 즉, 셀상태는 Forget Gate를 지나면서 잊어버려야 할 것들을 잊어버리게 된다.
- Input Gate / Candidate Gate
Input Gates는 현재 input에서 얼마만큼의 정보들을 반영할 것인가를 계산한다.
Input Gate도 Forget Gate와 같이 시그모이드 함수를 사용한다. 다만, 부여되는 가중치($W_i$)가 다르다.
Forget/Input/Output/Candidate Gate 모두 현재 입력된 text하고 과거에서부터 처리되어온 Hidden state 정보를 반영하여 Weight가 결정된다.
더하여 Input Gate는 Candidate Gate와 같이 연산하여 셀상태($C_{t-1}$)를 기억해야할 것들로 업데이트 한다. Candidate Gate는 내부연산이 시그모이드가 아닌 tanh함수이며 이는 -1~1 사이의 값으로 정규화해준다.
그 다음 Input Gate에서 나온 값($i_t$)들과 Candidate Gate에서 나온 값($C_t$)를 element-wise 연산을 통해서 Candidate Gate에서 나온 값들 중 어떤 값들은 0에 가깝게 만들고 어떤 값들은 그대로 놔두는 역할을 한다. 이것들이 현재 입력중에 기억할 정보를 의미한다. 그 다음에 이 값들을 셀상태($C_{t-1}$)에 더해서 업데이트 하게 된다.
- Output Gate
그 다음 이러한 장기기억 상태($C_{t-1}$)를 tanh를 통해서 정규화한 후, Output Gate에서 나온 값과 element-wise 곱을 하여 새로운 히든상태($h_t$)를 만들어낸다.
Output Gate와 tanh의 콜라보는 업데이트된 셀상태($C_t$)에서 현재 입력값($X_t$)의 특성을 더 반영하는 새로운 히든 상태를 만들어내는 것이다.
이렇게 되면 새로운 히든 상태 $h_t$는 $C_t$에 비해 좀 더 Short-term한 특성을 보이게 될 것이다. 즉, $C_t$가 좀 더 Long-term한 정보를 많이 담는다면, $H_t$는 같은 입력으로 좀 더 Short-term에 가까운 정보를 담게 되므로 RNN의 단점(장기의존성 문제)을 어느정도 극복할 수 있다.
Reference
경희대학교 산업경영공학과 머신러닝
https://www.youtube.com/watch?v=HXa7Ah87_gM&t=1572s
'Analytics' 카테고리의 다른 글
[논문 리뷰] SHAP (1) | 2024.08.26 |
---|---|
[논문 리뷰] Prophet 시계열 예측 (0) | 2024.08.20 |
[Causal Inference] Causal Inference for The Brave and True 리뷰 (0) | 2024.08.12 |
[ML/DL] Recurrent Neural Network(RNN) (0) | 2024.08.09 |
[NLP] Text Representation (0) | 2024.06.14 |