머신러닝&딥러닝/자연어처리

Transformer(2) - Multi head attention

Like_Me 2019. 12. 25. 21:52

 Multi-head-attention은 앞에서 알아본 Scaled Dot-Product Attention을 여러 개 만들어 다양한 특징에 대한 어텐션을 볼 수 있게 한 방법이다. 즉, 입력받은 query, key, value를 헤드 수만큼 나누어 병렬적으로 계산해주는 것이 핵심이다!

 

 과정은 다음과 같다. 내적 셀프 어텐션에서 본 query, key, value를 헤드 수만큼 나누어 Linear layer를 통과시키고 내적 어텐션을 구해 합치는 과정을 거친다. 마지막으로 선형층을 거쳐 나오면 멀티 헤드 어텐션이 끝나게 된다.

총구조를 보면 아래와 같다.

 

 

멀티-헤드 어텐션 구조

처음 Linear Layer를 이용해서 Q, K, V의 차원을 감소하며 Query와 Key의 차원이 맞지 않을 경우 맞추는 역할을 한다. 

마지막 Linear Layer는 attention value의 차원을 필요에 따라 변경시킬 수 있게 해 준다.

 

다음은 tensorflow로 구현해본 코드이다.

def multi_head_attention(query,key,value,num_units,heads,masked=False):
    query = keras.layers.Dense(num_units,activation='relu')(query)
    key = keras.layers.Dense(num_units,activation='relu')(key)
    value = keras.layers.Dense(num_units,activation='relu')(value)

    query = tf.concat(tf.split(query,heads,axis=-1),axis=0)
    key = tf.concat(tf.split(key,heads,axis=-1),axis=0)
    value = tf.concat(tf.split(value,heads,axis=-1),axis=0)

    attention = scaled_dot_product_attention(query,key,value,masked)

    output = tf.concat(tf.split(attention,heads,axis=0),axis=-1)
    output = keras.layers.Dense(num_units,activation='relu')(output)
    return output

query, key, value를 Dense층을 거치게 한 후 axis=-1로 분해한 후, axis=0으로 다시 concatenate 해준다.

이는 (배치 사이즈, 시퀀스 길이, 피쳐 길이)의 행렬이 있을 때, 피쳐 길이를 헤드수만큼 분리해준 후 배치 차원에 헤드 차원만큼 늘려주는 것을 말한다. 그래서 결국 (배치*헤드, 시퀀스, 피쳐/헤드)가 된다. 이는 3차원 행렬 모양으로 셀프 어텐션 연산을 해야 하기 때문이다. 다음으로 전에 구현했던 함수로(이전 글) scaled dot product attention 연산을 해주고 다시 원래의 행렬 모양으로 돌려주기 위해 axis=0으로 분리 후, axis=-1로 concatenate 해준다. 마지막으로, Dense층을 거치면 연산이 끝난다.