본문 바로가기
연구실/디버깅

[Pytorch] loss가 줄어들지 않을 때

by 정은진공부해 2020. 8. 4.

Pytorch로 neural net을 설계/구현한 뒤 학습을 했을 때 loss가 줄어들지 않는 이유는 굉장히 많을 것이다.
목표 문제를 풀기 위한 neural net의 설계(layer 수, activation function 등)가 잘못되었을 수도 있고 자잘한 구현 상의 오류가 있을 수도 있기 때문이다.

이번 글에서는 최근에 내가 regression model과 적절한 loss function을 설계하고 이를 학습했을 때 loss가 전혀 줄어들지 않았던 이유와 해결 방법을 기록하고자 한다.


먼저, 아래 코드는 정상적으로 돌아가는 코드의 일부분이다. dynamic이라는 간단한 mlp-based neural net과 dynamic_optim이라는 adma optimizer가 선언되어 있다. dynamic model은 입력 데이터(states,actions)와 ground truth 데이터인 next_states를 통해 loss를 계산하며 이를 토대로 dynamic model의 weight를 갱신한다.

여기서 주의깊게 봐야할 점은 입력 데이터와 ground truth 데이터가 모두 batch size만큼 각 요소 데이터를 포함하는 numpy array이다. 

    def get_dynamic_loss(self, states, actions, next_states):
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        next_states = torch.FloatTensor(next_states)
        pred_next_states = self.dynamic(states, actions)
        
        dynamic_loss = F.mse_loss(next_states, pred_next_states, reduction='sum')
        return dynamic_loss
    def dynamic_update_parameters(self, dynamic_loss):
        self.dynamic_optim.zero_grad()
        dynamic_loss.backward()
        self.dynamic_optim.step() 

반면, 아래 코드는 문제가 되었던 코드이다. 나는 어떠한 이유 batch size만큼의 데이터를 model에 한번에 통과시키고 싶지 않았다. 따라서 위의 코드와는 다르게, 입력 데이터 states와 actions의 각 요소 state, action을 dynamic model에 통과시킨 뒤 각각의 출력값 pred_next_state를 출력하고 이를 pred_next_states라는 list에 추가했다. 그 후, 이 list를 tensor로 변환시키고 loss를 구했다.

이 코드를 실행하면 loss값은 구해지지만 loss가 줄지 않는다.

    def get_dynamic_loss(self, states, actions, next_states):
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        next_states = torch.FloatTensor(next_states)
        pred_next_states = []
        for i in range(len(states)):
            pred_next_state = self.dynamic(states[i], actions[i])
        pred_next_states = torch.FloatTensor(np.array(pred_next_states))

        dynamic_loss = F.mse_loss(next_states, pred_next_states, reduction='sum')
        return dynamic_loss

그 이유는 pred_next_states가 dynamic model에서 나온 output값을 참조(?)하고 있지 않아서였다.
위 코드에서 pred_next_states를 numpy array로 변환을 하는데 이 때 pred_next_states의 value만 copy된다.
따라서 pytorch는 value 값만 copy된 pred_next_states를 일종의 상수로 치부해버리는 것 같다.

이를 해결하기 위해 나는 아래와 같이 간단한 custom mse loss function을 만들어 각 요소간의 loss를 구한 뒤 전체 데이터의 loss를 더해 위 문제를 해결할 수 있었다.

    def mse_loss_fn(self, output, target):
        return torch.mean((output-target)**2)

정리하자면, loss function이 model의 output의 reference를 참조하고 있는지 확인해야 한다.

pytorch가 직관적이고 간편한 만큼 데이터를 다룰 때 많이 신중해야할 것으로 보인다.