본문 바로가기
연구실/강화학습 논문 리딩

Towards Interpretable Reinforcement Learning sing Attention Augmented Agents

by 정은진공부해 2020. 6. 21.

논문을 선정한 이유

최근 attention 기법에 대해 흥미를 많이 느끼고 있었는데, 심층 강화학습(Deep Reinforcement Learning, DRL)에 적용된 논문이 있어 찾아 읽어보게 되었다.


논문의 목표

논문 제목에서 볼 수 있듯이, interpretable한 DRL 모델을 만들고자 한 것이다. DNN(Deep Neural Network, DNN)가 강화학습에 접목된 DRL은 DNN과 같이 학습 과정이 블랙 박스라는 단점이 있다. 이 논문에서는 제안 모델이 어떤 시점에 어느 지역에 attention을 가하는지를 불 수 있기 때문에 interpretable RL이라고 하였다.

atari 게임에서의 attention 적용 예

논문의 주된 제안 방법

입력 영상에 대해 attention 기법을 적용하여 DRL의 입력으로 활용


구체적인 제안 방법

제안 모델의 attention 과정에 대한 그림은 아래와 같다.

attention 과정

먼저, 시간 단계 t에서 입력으로 받은 observation X가 있다고 하자. X는 height, width, channel이 각각 H, W, C인 하나의 이미지이다.

1의 과정에서는 이 X의 feature map O를 추출하게 된다. O를 추출하기 위해 이 논문에서는 ConvLSTM과 같이 LSTM layer와 Conv layer가 연결된 네트워크를 이용했다.

2의 과정에서는 추출한 O를 토대로 key tensor K와 value tensor V를 생성한다. 이는 O에 대해 특정 채널에서 split하여 두 tensor로 쪼갠 것이다. 

이에 따라, 아래의 식을 만족한다.

3의 과정에서는 spatial basis라는 vector를 key, value tensor에 concat 시킨다. 이에 대한 내용은 뒤에서 자세히 설명한다.

4의 과정에서는 query 벡터를 생성하는 과정이다. 이 query 벡터는 MLP로 설계된 query network의 출력값이다. query network는 LSTM layer의 hidden state를 입력으로 받는다. LSTM layer에 대한 자세한 설명은 논문에 잘 나와있지 않다. 확인한 바로는, 이 LSTM layer의 입력은 agent의 action, reward, 이전 attention된 벡터 등이 이용되는 것 같다.
그렇다면 이 LSTM layer의 hidden state를 입력으로 받는 query network는 무엇을 의미할까? 나는 이전 환경과 에이전트에 대한 history 정보를 반영하기 위해 LSTM의 hidden state를 이용하는 것이라고 해석했다.
즉, query network는 "에이전트가 지금까지 이런 행동을 했고 이런 보상을 받았고 과거 입력 영상에 대한 attention된 feature는 이랬으니까 어느 부분에 집중을 해야겠구나~!"를 고려해 query vector를 출력하게 되는 것이다.

5의 과정에서는 아래와 같이 앞서 구한 key value와 query network를 pixel-wise inner product 연산을 취한다. 이 연산이 무엇인지는 쉽게 유추할 수 있을 것이다. pixel-wise 연산을 함으로써 query vector의 값을 토대로 Key tensor K에서 중요한 특징은 더 크게, 반대로 덜 중요한 특징은 더 작게 하고자 한 것이다.

6의 과정에서는 이 텐서 A에 대해 spatial softmax를 취한다. spatial softmax는 주어진 feature map에 대해 각 pixel 좌표의 값(ex. activation function을 취한 이전 layer의 결과값)들에 대해 softmax를 취한 것이다. 이를 통해, 6의 과정에서는 5의 과정에서 구한 tensor를 토대로 probability map을 만든 것이다. 이 probability map을 attention map이라고 하자.

7의 과정에서는 attention map과 value tensor를 곱한다. 이 과정도 직관적으로 유추될 수 있다. 본래 입력 영상의 feature의 일부인 value tensor와 attention map을 곱한다는 것은 feature에서 중요한 부분에 더욱 가중을 두겠다는 얘기다. 다시 상기하자면, 이전의 과정들을 통해 attention map은 환경과 agent와 관련된 history 정보를 반영하여 입력 영상에서 어디에 더욱 중점을 둘 것인지에 대한 probability map이었다.
따라서, 이러한 attention map과 value tensor를 곱한다는 것은 환경과 agent 등을 고려했을 때 어디(region or pixel 좌표)에 집중을 하는 것이 좋은지를 반영하는 것이다.

8의 과정에서는 7의 과정에서 구한 attention이 가미된 feature를 축소(reduce sum을 통해)시켜 최종적으로 answer 벡터를 구한다. 이 answer 벡터는 이전 과정들을 통해, 입력 영상의 특징과 더불어 입력 영상의 특징에서 어디를 집중적으로 고려할지를 반영한 벡터이다.  feature를 축소시키는 이유는 단순하다. LSTM의 입력으로 활용하기 위해서다.
하지만, 문제가 있다. 이렇게 축소된 answer 벡터는 과연 입력 영상에서의 공간 정보를 충분히 반영하지 못한다. 아무리 1의 과정에서 Convolution 연산을 통해 공간 정보를 반영했더라도, 최종 feature를 1차원의 벡터로 축소시켰기 때문에 공간 정보를 잃게 된다.

이를 해결하기 위해 이 논문에서는 3의 과정에서 spatial basis를 두 tensor에 추가한 것이다. 이 논문에서는 Fourie basis representation을 이용해 아래와 같은 spatial basis vector를 생성한다.

이 식은 h*w의 크기를 갖는 이미지에서 픽셀 좌표 (i,j)가 갖는 공간 상의 위치 정보를 나타낸다. (이를 위해 fourie basis representation을 활용하여 주기 함수들의 조합으로 나타낸 것이다.)

정리하자면, spatial basis는 주어진 전체 이미지(또는 feature) 대비 (i,j) 위치의 인코딩 된 정보를 나타낸다. 이러한 spatial basis 벡터를 입력 영상의 feature인 key와 vector의 텐서에 concat 하게되면 앞서 언급한 문제가 해결된다. 
이는 answer 벡터가 1차원의 벡터로 축소되더라도 feature map을 이루는 각각의 값들에 대한 인코딩 된 위치 정보가 존재하기 때문이다. 

이 논문에서 재밌는건 spatial basis를 얼만큼 가중할 것인지에 따라서 attention 효과가 다르다는 것이다. query 벡터와 key 벡터간의 pixel-wise inner product 연산을 취할 때 입력 영상의 feature 부분과 spatial basis 부분에 각각 다른 가중치 상수를 곱해줄 수 있다. 

가중치를 어떻게 주느냐는 아래와 같이 크게 두가지로 분류될 수 있다.

  • 이미지 특징 부분에 더욱 가중치를 줄 경우 : "입력 영상에서 어디(무엇)에 집중할 것인가"
  • spatial basis 부분에 더욱 가중치를 줄 경우 : "어느 픽셀 좌표에 가중치를 둘 것인가"

나는 이 부분을 이해하고 나서 많이 놀랐다. 다른 논문에서 이미 쓰인 방법인지는 잘 모르겠지만 나는 이렇게 본래 입력 영상의 feature 부분과 spatial basis에 서로 다른 가중치를 두어 위와 같은 attention 전략을 만들 수 있다는 것이 너무 참신하게 다가왔다.

마지막으로, 이렇게 attention 된 벡터는 LSTM layer의 입력으로 활용되며 이 LSTM layer의 output은 별도의 MLP를 통해 강화학습을 위한 network(policy net, value net)의 입력으로 활용된다. 강화학습 알고리즘은 IMPALA를 이용했다.


실험 환경

atari 게임

실험

이 논문에서는 각 atari game마다 많은 실험을 진행했다.
(실험이 워낙에 많다보니 이 글에서는 모든 실험 결과를 담지 않고 정성 평가를 위주로 담았다.)

첫 번째 실험은 강화학습 성능에 대한 정량 평가이다. policy net, value net의 신경망 구조를 Feedforward network일 때와 LSTM network일 때 그리고 LSTM network를 기반으로 attention을 적용했을 때의 실험 결과이다. 실험 1을 보면, LSTM보다 Feedfowrd network를 적용했을 때의 성능이 더욱 높다. LSTM network는 부분 관측 가능한 환경일수록 더욱 높은 성능을 보이지만 atari 게임과 같이 입력 데이터에 noise가 적은 환경에서는 Feedforward network가 더욱 높은 성능을 갖는 것으로 알려져 있다. 반면에, attention을 적용한 제안 모델이 가장 높은 성능을 보였다.

두 번째 실험은 제안 모델에서 novel state에 대해 잘 적응할 수 있는지를 보인 정성 실험이다. 이 실험을 위해, 게임 상에서는 존재하지 않지만 입력 영상에 별도의 물고기를 갑작스럽게 추가하여 이 물고기에 대해 attention이 잘 적용되는지를 보였다.

실험 2

세 번째 실험에서는 게임 환경에서 agent가 위협 요소를 잘 회피하고 미래의 보상들을 잘 획득하는 방향으로 attention이 적용되는지를 보인 실험이다. 실험 3에서 볼 수 있듯이, 위험 요소인 유령과 agent가 먹어야 할 먹이들에 대해 attention이 적용된 것을 볼 수 있다.

실험 3

네 번째 실험에서는 게임 전략에 따라 attention 전략도 유동적으로 변화하는가를 보인 실험이다. 실험 4에서 볼 수 있듯이, 공이 막대기에 가까울 때는 공과 막대기에 대해 attention이 적용된 것을 볼 수 있다. 반면에, 공이 블럭에 가까울 때는 게임 전체 화면에 attention이 넓게 적용된 것을 확인할 수 있다.

 

실험 4


개인적인 논문에 대한 생각

교수님의 말씀을 빌리자면, 사람의 가장 대단한 능력 중 하나는 attention 능력이다.
사람은 매 순간 들어오는 시각,청각 등의 정보들 중 중요한 부분 정보들만을 고려한다. 사람은 일종의 attention 기능을 갖고 있는 것이다. 이를 통해, 사람은 주어진 이벤트에 대해 적은 에너지로 작업을 수행한다.

이에 따라, 최근에는 Deep learning 기반의 자연어 처리, 영상 처리 분야에서 attention 기법이 각광받고 있다.

반면, 내가 아는 한에서 attention 기법이 강화학습 분야에 많이 적용되지 않았다. 사실 이 논문에서 제안하는 attention 기법은 기존에 있는 attention 방법과 크게 다르진 않지만 강화학습 분야에서 문제를 효과적으로 해결하고 interpretable한 모델을 만들 수 있다는 점에서 좋은 논문임에는 틀림없다고 생각한다.