RNN 정리

Updated:

시퀀스 데이터

  • 소리, 문자열, 주가 등의 데이터를 시퀀스(sequence) 데이터로 분류한다.
  • 시퀀스 데이터는 독립동등분포(Independent and Identically Distributed, i.i.d.)가정을 잘 위배하기 때문에 순서를 바꾸거나 과거 정보에 손실이 발생하면 데이터의 확률분포도 바뀌게 된다.

시퀀스 데이터 핸들링

  • 이전 시퀀스 정보를 가지고 앞으로 발생할 데이터의 확률분포를 다루기 위해 조건부확률을 이용할 수 있다.
    • $P(X_1,…,X_t) = P(X_t|X_1,…,X_{t-1})P(X1,…,X_{t-1})$ (베이즈 정리)
    • 다시, $P(X_t|X_1,…,X_{t-1})P(X_{t-1}|X_1,…,X_{t-2})P(X_1,…,X_{t-2})$ 로 만들 수 있다.
    • 정리하면 $\prod_{s=1}^{t}P(X_s|X_{s-1},…,X_1)$
  • 시퀸스 데이터를 다루기 위해선 길이가 가변적인 데이터를 다룰 수 있는 모델이 필요하다.
    • $X_t \sim P(X_t|X_{t-1},…,X_1)$
    • $X_{t+1} \sim P(X_{t+1}|X_{t},X_{t-1}…,X_1)$
  • 몇 문장만이나 가까운 과거까지의 데이터만으로 모델링이 가능하다면 고정된 길이 $\tau$만큼의 시퀀스를 사용할 수 있다.
    • 이런 경우를 $AR(\tau)$, 자기회귀모델(Autoregressive Model)이라고 한다.
    • $\tau$는 하이퍼 파라미터로서 모델링을 하기전에 정해줘야 하는 값이다.
  • 다른 방법으로는 이전 정보를 제외한 나머지 정보들을 $H_t$라는 잠재변수로 인코딩해서 활용하는 잠재 AR 모델이 있다.
    • $X_t \sim P(X_t|X_{t-1},H_t)$
    • $X_{t+1} \sim P(X_{t+1}|X_{t},H_{t+1})$
    • $H_t = Net_\theta(H_{t-1},X_{t-1})$
    • 잠재변수 $H_t$를 신경망을 통해 반복하여 사용하여 시퀀스 데이터의 패턴을 학습하는 모델이 RNN이다

RNN(Recurrent Nerual Network)

  • 기본적인 RNN 모형은 MLP와 유사한 모양이다.
    • $O = HW^{(2)}+b^{(2)}$
    • $H = \sigma(XW^{(1)}+b^{(1)})$
    • $W^{(1)}, W^{(1)}$은 시퀀스와 상관없이 불변인 행렬이다.
    • 이 모델은 과거의 정보를 다룰 수 없다.
  • 잠재변수인 $H_t$를 복제해서 다음 순서의 잠재변수를 인코딩하는데 사용한다.
    • $O_t = H_tW^{(2)}+b^{(2)}$
    • $H_t = \sigma(X_tW_x^{(1)}+H_{t-1}W_H^{(1)}+b^{(1)})$
  • RNN의 역전파는 잠재변수의 연결그래프에 따라 순차적으로 계산한다. 이를 Backpropagation Through Time(BPTT)라 한다.
    • 다음 시점에서의 잠재변수로부터 오는 그레디언트 벡터와 출력에서 들어오는 그레디언트 벡터를 통해 현재 잠재변수에 전달이 된다.
    • 잠재변수의 그레디언트 벡터를 입력과 이전 시점의 잠재변수로 다시 전달한다.

BPTT

  • $\partial_{w_h}h_t = \partial_{w_h}f(x_t,h_{t-1},w_h)+\sum_{i=1}^{t-1}(\prod_{j=i+1}^{t}\partial_{h_{j-1}}f(x_j,h_{j-1},w_h))\partial_{w_h}f(x_i,h_{i-1},w_h)$
  • $\prod_{j=i+1}^{t}\partial_{h_{j-1}}f(x_j,h_{j-1},w_h)$은 i+1시점에서부터 t시점까지 모든 히든변수(잠재변수)에 대한 미분텀이 곱해지면서 더해진다.
    • 만약 시퀀스 길이가 길어지면(즉, 현재시점부터 예측이 끝나는 t시점까지) 곱해지는 텀들이 불안정해지기 쉬워진다. 만약 이 값이 1보다 크게 되면 굉장히 크게 커지고 1보다 작게되면 굉장히 작은 값으로 떨어진다. 미분값이 엄청 커지거나 엄청 작아지게 될 확률이 높아진다.
    • 그래서 일반적으로 BPTT를 모든 시점에 적용하게되면 RNN의 학습이 굉장히 불안정해지기 쉬워진다.
  • 시퀀스 길이가 길어지는 경우 역전파 알고리즘의 계산이 불안정해질 수 있으므로 길이를 끊는 것이 중요해진다. 이를 truncated BPTT라 한다.
    • 이런 문제를 해결하기 위해 등장한 RNN 네트워크가 LSTM과 GRU이다.

Comments