python으로 LSTM 바닥부터 구현하기
Updated:
Related
RNN에 이어서 LSTM을 구현했습니다.
LSTM
LSTM(Long-Short Term Memory)은 RNN의 long-term dependencies 문제를 해결한 모델입니다. 장기의존성(long-term dependencies) 문제란, 과거의 정보가 먼 미래까지 전달되기 어려운 문제를 말합니다.
이러한 long-term dependencies는 gradient vanishinig problem과도 관련이 있습니다. gradient vanishing problem이란, 미분 기울기가 0과 1사이의 값을 가지고 여러번 반복적으로 곱하게 되면 기울기가 0으로 수렴되는 문제를 말합니다. 반대로 1보다 큰 기울기를 반복적으로 곱하면 gradient explodinig problem이라 합니다. 먼 과거 정보에 대해 gradient가 소실되어 weight를 업데이트 할 수 없기 때문에 장기의존성 문제가 발생합니다.
LSTM은 이러한 문제를 cell state를 추가하여 해결합니다. LSTM은 cell state에 정보를 저장하거나 삭제할 수 있는데 cell state를 제어하기 위한 Gate들을 가지고 있습니다.
Forward
pytorch 문서에 있는 수식을 사용했고 선언한 weight 크기 때문에 조금 변형하여 사용했습니다. 사용되는 weight는 $W_{ih},\ W_{hh},\ b_{ih},\ b_{hh}$ 입니다.
$W_{ih}$는 각각 (hidden_size, input_size) 크기를 가진 ($W_{ii}, W_{if}, W_{ig}, W_{io}$)로 구성되며 (4 * hidden_size, input_size) 크기를 가집니다. 만약 멀티레이어 LSTM이라면 두 번째 레이어의 $W_{ih}$ 크기는 (4 * hidden_size, hidden_size)입니다.
$W_{hh}$는 각각 (hidden_size, hidden_size) 크기를 가진 ($W_{hi}, W_{hf}, W_{hg}, W_{ho}$)로 구성되며 (4 * hidden_size, hidden_size) 크기를 가집니다.
$b_{ih}$는 ($b_{ii}, b_{if}, b_{ig}, b_{io}$), $b_{hh}$는 ($b_{hi}, b_{hf}, b_{hg}, b_{ho}$)로 구성되고 둘 다 (4 * hidden_size) 크기를 가집니다.
$\odot$은 hadamard product(아다마르 곱)의 기호이고 두 행렬의 원소끼리 곱하는 element-wise product 연산입니다. $\sigma$는 sigmoid 연산입니다.
input gate
-
$i_t = \sigma(x_t W_{ii}^T + b_ii + h_{t-1} W_{hi}^T + b_hi)$
-
$g_t = \tanh(x_t W_{ig}^T + b_ig + h_{t-1} W_{hg}^T + b_hg)$
forget gate
- $f_t = \sigma(x_t W_{if}^T + b_if + h_{t-1} W_{hf}^T + b_hf)$
output gate
- $o_t = \sigma(x_t W_{io}^T + b_io + h_{t-1} W_{ho}^T + b_ho)$
cell state
- $c_t = f_t \odot c_{t-1} + i_t \odot g_t$
hidden state
- $h_t = o_t \odot \tanh(c_t)$
먼저, forget gate는 이전 hidden state와 입력을 보고 cell state의 정보를 유지할지 제거할지 결정합니다. sigmoid를 사용하기 때문에 0과 1사이의 범위를 가지고 1이라면 완전히 유지, 0이라면 완전히 제거합니다.
다음, cell state에 새로운 정보를 추가합니다. sigmoid를 사용하여 얼마나 반영할지 결정할 $i_t$와 새로운 후보 값이 들어간 벡터를 생성하는 $g_t$를 곱하여 cell state에 정보를 업데이트 합니다.
마지막으로, output은 입력에 sigmoid로 출력하여 cell state의 어느 부분을 출력으로 내보낼지 결정합니다.
구현할때는 weight_ih_l[k]와 weight_hh_l[k]를 분리하지 말고 그대로 곱해도 상관없습니다.
Backward
가장 먼저 hidden state $h_t$와 cell state $c_t$를 미분해야 합니다. input gate, forget gate, output gate 모두 hidden state, cell state 계산에 참여하기 때문입니다. LSTM의 output으로 들어오는 gradient를 $\Delta_{t}$, cost function을 $J$라 하겠습니다.
RNN과 마찬가지로 LSTM 또한 hidden state가 다음 셀로 전달되고 t에서의 hidden state가 output이 됩니다. 따라서 시간의 역순으로 hidden state의 gradient가 전달되어야 합니다. t 시점에서의 gradient는 $\Delta_{t}$와 t+1 시점에서의 hidden state의 gradient와 합치면 됩니다.
${dJ \over dh_t} = \Delta_{t} + {dJ \over dh_{t+1}}$
LSTM은 cell state도 다음 셀로 전달되기 때문에 t+1 시점에서의 cell state의 gradient를 더해줘야 합니다.
${dJ \over dc_t} = {dJ \over dh_t} \ {dh_t \over dc_t} + {dJ \over dc_{t+1}}$
input gate 부터 output gate의 미분은 다음과 같습니다. bias는 따로 계산하지 않겠습니다.
${dJ \over dW_{ii}} = [{dJ \over dc_t} \ {dc_t \over di_t} \ d\sigma]^T \cdot x_t$
${dJ \over dW_{if}} = [{dJ \over dc_t} \ {dc_t \over df_t} \ d\sigma]^T \cdot x_t$
${dJ \over dW_{ig}} = [{dJ \over dc_t} \ {dc_t \over dg_t} \ d\tanh]^T \cdot x_t$
${dJ \over dW_{io}} = [{dJ \over dh_t} \ {dh_t \over do_t} \ d\sigma]^T \cdot x_t$
${dJ \over dW_{hi}} = [{dJ \over dc_t} \ {dc_t \over di_t} \ d\sigma]^T \cdot h_{t-1}$
${dJ \over dW_{hf}} = [{dJ \over dc_t} \ {dc_t \over df_t} \ d\sigma]^T \cdot h_{t-1}$
${dJ \over dW_{hg}} = [{dJ \over dc_t} \ {dc_t \over dg_t} \ d\tanh]^T \cdot h_{t-1}$
${dJ \over dW_{ho}} = [{dJ \over dh_t} \ {dh_t \over do_t} \ d\sigma]^T \cdot h_{t-1}$
$d\sigma$와 $d\tanh$는 다음의 예시로 설명할 수 있습니다.
$i_t = \sigma(Wx+b)$라는 수식이 주어졌다고 가정한다면 $i_t = \sigma(s_t),\ s_t = Wx+b$라는 두 수식으로 분리할 수 있고 chain rule에 따라 ${di_t \over dW} = {di_t \over ds_t} {ds_t \over dW}$로 기울기를 구할 수 있습니다.
sigmoid를 미분한다면 $\sigma \cdot (1 - \sigma)$이므로 ${di_t \over ds_t} = \sigma(s_t) (1 - \sigma(s_t)) = i_t * (1 - i_t)$라는 결과를 얻을 수 있습니다.
마찬가지로 $\tanh$를 미분한다면 $1 - \tanh^2$이므로 $g_t = \tanh(s_t)$에서 ${dg_t \over dW} = 1 - \tanh(s_t)^2 = 1 - g_t^2$라는 결과를 얻을 수 있습니다.
위에서 구한 값들을 모두 대입하면 다음과 같은 결과를 얻을 수 있습니다. transpose를 취한 이유는 forward에서 weight에 transpose를 취하여 계산했기 때문입니다.
${dJ \over dW_{ii}} = [{dJ \over dc_t} \cdot g_t \cdot i_t \ (1 - i_t)]^T \cdot x_t$
${dJ \over dW_{if}} = [{dJ \over dc_t} \cdot c_{t-1} \cdot f_t \ (1 - f_t)]^T \cdot x_t$
${dJ \over dW_{ig}} = [{dJ \over dc_t} \cdot i_t \cdot (1 - g_t^2)]^T \cdot x_t$
${dJ \over dW_{io}} = [{dJ \over dh_t} \cdot \tanh(c_t) \cdot o_t \ (1 - o_t)]^T \cdot x_t$
${dJ \over dW_{hi}} = [{dJ \over dc_t} \cdot g_t \cdot i_t \ (1 - i_t)]^T \cdot h_{t-1}$
${dJ \over dW_{hf}} = [{dJ \over dc_t} \cdot c_{t-1} \cdot f_t \ (1 - f_t)]^T \cdot h_{t-1}$
${dJ \over dW_{hg}} = [{dJ \over dc_t} \cdot i_t \cdot (1 - g_t^2)]^T \cdot h_{t-1}$
${dJ \over dW_{ho}} = [{dJ \over dh_t} \cdot \tanh(c_t) \cdot o_t \ (1 - o_t)]^T \cdot h_{t-1}$
${dJ \over dc_t} = {dJ \over dh_t} \cdot o_t \cdot (1 - \tanh(c_t)^2) + {dJ \over dc_{t+1}}$
다음 셀에 전달할 ${dJ \over dh_{t-1}}$와 ${dJ \over dc_{t-1}}$, 입력에 대한 gradient ${dJ \over dx_t}$는 다음과 같습니다.
${dJ \over dh_{t-1}} = {dJ \over di_t} \ {di_t \over dh_{t-1}} + {dJ \over df_t} \ {df_t \over dh_{t-1}} + {dJ \over dg_t} \ {dg_t \over dh_{t-1}} + {dJ \over do_t} \ {do_t \over dh_{t-1}}$
${dJ \over dc_{t-1}} = {dJ \over dc_t} {dc_t \over dc_{t-1}} = {dJ \over dc_t} \cdot f_t$
${dJ \over dx_t} = {dJ \over di_t} \ {di_t \over dx_t} + {dJ \over df_t} \ {df_t \over dx_t} + {dJ \over dg_t} \ {dg_t \over dx_t} + {dJ \over do_t} \ {do_t \over dx_t}$
구현하기 위해 forward 과정과 마찬가지로 $W_{ih}$와 $W_{hh}$를 $W_{ii}, \ W_{hi}$부터 $W_{io}, \ W_{ho}$로 분리하지 말고 그대로 곱해줍니다.
Result
LSTM과 Linear레이어만으로 이루어진 모델을 사용했고 hidden size는 256으로 사용했습니다. LSTM의 마지막 output을 Linear로 출력했고 이전 포스트들과 똑같이 MNIST 5000장을 훈련 데이터로 사용하고 1000장을 테스트로 사용했습니다.
10 Epochs의 Loss와 Accuracy의 그래프는 다음과 같습니다.
놀라운 점은 RNN에 비해 4 epoch만에 90%의 Accuracy를 기록했고 9 epoch에서 95% 이상의 Accuracy를 기록했습니다. 뿐만 아니라 loss도 잘떨어졌습니다.
Comments