Python으로 Diffusion 바닥부터 구현하기[1] (ResidualBlock, AttentionBlock, UpsampleBlock)

Updated:

Text-to-Image

최근 Stable Diffusion이라는 모델이 발전하면서 많은 사람들이 이미지 생성 AI를 사용하고 있습니다. Diffusion 모델은 노이즈를 기존 이미지에 더한 뒤에 노이즈를 예측하고 지워가며 복원하는 학습과정을 거치게 됩니다.

저는 Diffusion 모델을 구현하여 MNIST를 학습시키고 label이라는 condition을 주어 생성해보려고 합니다.

Diffusion 모델 중 Huggingface의 UNet2DModel을 선택했는데 UNet2DModel 내부에 class_embedding 레이어가 있어 class condition을 넣을 수 있도록 제공하고 있습니다.

이 페이지에선 ResidualBlock, AttentionBlock, UpsampleBlock을 구현하고 UNet의 구현은 다음 글에서 이어서 작성하겠습니다.

ResidualBlock

ResidualBlock은 UNet에서 Down, Mid, Up 레이어 모두에서 사용됩니다. ResidualBlock은 컨볼루션 레이어를 이용해 이미지의 특징을 추출하거나 채널을 변경할 수 있습니다. 또한, 시간 임베딩(Time Embedding) 정보도 입력으로 받아 통합하는 역할을 수행합니다.

Forward

아래는 ResidualBlock의 순전파를 도식화한 것입니다.

01
ResidualBlock Forward Flow

먼저, 입력 X를 연산하는 (Conv2d, Norm, Actv) 과정을 두 번 반복하여 out_channels를 가진 이미지가 됩니다. 중간에 time을 인코딩한 time embedding 벡터를 이미지에 더해주어 ResidualBlock을 사용하는 모든 연산들에 대해 timestep 정보를 알도록 합니다.

그리고 in_channels 채널을 가진 이미지를 kernel size 1인 Conv2d를 이용해 out_channels 채널을 가진 이미지로 변환하여 skip connection을 수행합니다.

Backward

아래는 ResidualBlock의 역전파를 도식화한 것입니다.

02
ResidualBlock Backward Flow

Upstream gradient dz는 (Conv2d, Norm, Actv) 레이어와 kernel size 1인 Conv2d레이어에 동일하게 전파됩니다. 마찬가지로 time embedding 또한 동일하게 전파됩니다. Forward에서 입력 X가 두 Conv2d에 동일하게 전파되었으므로 Backward에선 두 Conv2d의 기울기가 더해져 dx가 됩니다.

AttentionBlock

AttentionBlock은 노이즈가 섞인 이미지로부터 정보의 중요도를 재조정하고, 관계를 파악하는데 집중합니다. 이를 이용해 추론단계에서 노이즈를 예측하는데 중요한 역할을 담당합니다.

Forward

아래는 AttentionBlock의 순전파를 도식화한 것입니다.

03
AttentionBlock Forward Flow

입력 X는 정규화를 진행한 뒤에 Query, Key, Value로 만드는 Linear 레이어를 통과하게 됩니다. 이 Query, Key, Value는 다시 Scaled-dot Product Attention 연산에 의해 hidden_states가 생성되고 이 벡터를 to_out레이어를 거쳐 AttentionBlock의 출력으로 사용됩니다.

연산 중간마다 reshape을 통해 변환을 수행하고 있는데 주석으로 어떤 형태가 되어야 하는지 추가로 적어두었습니다.

Backward

아래는 AttentionBlock의 역전파를 도식화한 것입니다.

04
AttentionBlock Backward Flow

Upstream gradient는 to_out을 거쳐 Scaled-dot Product Attention에 전파됩니다. 입력을 to_q, to_k, to_v에 똑같이 들어갔기 때문에 SDPA의 역전파 결과인 dQ, dK, dV를 to_q, to_k, to_v로 전파된 후 합쳐서 전파됩니다.

UpsampleBlock

UpsampleBlock은 UNet의 Up layer에서 사용되어 Down layer에서 낮아진 resolution을 복원하는 역할을 합니다. 사용자 정의에 따라 interpolate + Conv2d 또는 ConvTranspose2d 레이어를 선택할 수 있습니다.

Forward

Forward 과정은 간단하게 ConvTranspose2d 레이어를 사용하여 해상도와 채널을 이전 상태로 복원하거나 interpolate로 해상도 복원, Conv2d로 채널 복원을 시도하는 과정을 수행합니다.

interpolate 함수는 다음과 같이 nearest 모드를 기준으로 작성했습니다. nearest 모드는 추가해야 할 공간에 주변 픽셀을 복사하는 방법으로 수행됩니다.

예를 들어, [0.1, 0.2, 0.3]인 (1,3) 배열을 (1,5)로 보간한다면 [0.1, 0.1, 0.2, 0.2, 0.3]의 결과를 받을 수 있습니다.

Backward

Backward의 경우도 Forward의 역순으로 수행합니다.

interpolate의 backward 함수는 Upstream gradient를 원래의 위치에 누적해야 합니다.

예를 들어, [0.1, 0.2, 0.3]인 (1,3) 배열을 보간하여 [0.1, 0.1, 0.2, 0.2, 0.3]이 되었고, upstream gradient [1,2,3,4,5]를 역전파시킨다면 입력에 대한 기울기는 [1+2,3+4,5] = [3,7,5]가 됩니다.

Reference

Comments