Proximal Policy Optimiztion (PPO)

Paper Link: https://arxiv.org/abs/1707.06347

  • Policy Gradient 방법 중 하나로 실험 결과에서 대부분의 다른 Policy Gradient 보다 좋은 성능을 보임.
  • Sthocastic gradient ascent
  • 다른 Policy gradient 알고리즘들은 minibatch 하나당 한번의 gradient 업데이트를 하고 끝내는 반면 PPO는 minibatch 하나를 여러번의 epoch에 걸쳐 사용하여 gradient를 업데이트 함.
  • TRPO에 비해서 구현이 비교적 쉬움.

다른 Policy Gradient에 비해 나은 점?

  • Vanilla Policy Gradient 방법들은 data efficiency 측면과 안정성 측면에서 좋지 않은 모습을 보임.
  • TRPO 같은 경우는 내용이 너무 복잡함.
  • TRPO와 달리 first-order optimization을 통해 gradient update를 진행함.

Policy Optimization이란?

  • Advantageous Actor-Critic의 loss function은 다음과 같음. Advantage function은 보통 Q(s,a)-V(s)이다. 아래 loss function을 maximize하는 방향으로 학습이 진행. 왜 minimize가 아니고 maximize하는지 모르겠다면 REINFORCE 게시물 참조. 간단하게 말하면 아래 loss function은 value function과 같기 때문에 최대로 만들어줘야 한다.

    image

  • 위 loss function을 따라서 on-policy 업데이트를 계속 진행하면 REINFORCE 알고리즘에서 살펴 봤듯이 학습이 굉장히 불안정하다. 논문에서는 ‘destructively large policy updates’라고 표현하고 있는데, gradient가 너무 급격하게 바뀌는 경우가 많기 때문인 것 같다.

TRPO에서 차용한 loss function

TRPO의 loss function (‘surrogate’ objective)

image

결과론적으로 위 loss function에 gradient를 취해준 값은 Advantageous Actor-Critic의 loss function에 gradient를 취해준 값과 같은데, importance sampling을 통해 같음을 증명할 수 있다.

image

Importance sampling은 위와 같다. 간단히 말해서 P 분포에서 sampling한 x를 input으로 가지는 f 함수의 기댓값은 P와 다른 Q 분포에서 x를 sampling 함으로써 구할 수 있다는 것이다.

컨셉은 이렇고, 직접적으로 왜 같은지는 아래에서 확인할 수 있다.

image

KL이 무엇인가?

loss function에 붙은 constraint를 살펴보면 KL이라는 것을 볼 수 있다. KL은 Kullback-Leibler divergence로 두 확률분포의 차이를 계산하는 데에 사용하는 함수이다. 즉, 업데이트 전의 θold와 업데이트 후의 θ의 차이나는 정도를 제한함으로써 급격한 변화를 막겠단 뜻이다. 하지만 위와 같은 constraint가 걸린 optimization 문제는 풀기에 상당히 복잡하므로 TRPO에서는 constraint form 대신 아래와 같은 penalty form을 사용해서 optimization을 진행하는 것을 제시했다.

image

하지만 TRPO는 위와 같은 penalty form 대신에 hard constraint form을 사용했는데, 이는 다양한 문제에 적용 가능한 하나의 β를 찾기 어려울 뿐 아니라 하나의 문제에서도 고정된 β 값은 성능이 좋지 못한 것을 발견했다.

Clipped Surrogate Objective

위와 같은 문제를 해결하기 위해 PPO에서는 KL divergence를 쓰는 대신에 Clipped Surrogate Objective라는 modified된 loss 함수를 제시했다.

image

여기서 rt(θ)는 다음과 같다.

image

만약 ε = 0.2 라고 한다면 clip(rt(θ), 1-ε, 1+ε)은 항상 0.8에서 1.2 사이 값을 가지게 된다. 즉 업데이트 전 policy와 업데이트 후 policy의 probability ratio를 일정 범위 안에 속하도록 고정하겠다는 의미이다. 이렇게 clip된 값과 원래의 rt(θ) 값 중 더 작은 값을 쓰도록 하여 update가 너무 급격하게 일어나지 않게 만들어준다.

image

위와 같은 clipped surrogate objective를 사용하여 update하는 과정을 살펴보면 아래와 같다. 우선 advantage가 0보다 크다는 의미는 해당 action이 좋은 action이므로 다음 번에는 해당 action을 할 확률을 올려야 된다는 소리이다. 그래서 r 값은 증가하게 되는데 1+ε 지점에서 clip이 일어나게 된다. 아무리 정책을 좋게 만든다 하더라도 너무 급하게 바꾸지 않는다는 뜻 같다. 어찌됐던 ε이 0.2이고, advantage가 0보다 크다면 새로운 정책의 확률은 기존 정책의 확률 보다 1.2배 이상 높아지지 못한다.

마찬가지로 advantage가 0보다 크다는 의미는 해당 action이 나쁜 action이므로 해당 action을 할 확률을 낮춰야 한다는 의미이다. 여기서도 똑같이 1-ε에서 clip 되고, 이보다 크게 정책을 변화 시킬 수 없다.

Adaptive KL Penalty Coefficient

Clipped Surrogate Objective를 loss function으로 쓰는 게 아니라 TRPO에서 제시한 것처럼 KL penalty를 쓰되, 고정된 β 값이 아니라 adaptive한 β값을 사용한다. 하지만 논문에서 말하길, clipped surrogate objective를 쓰는 게 성능이 더 좋다고 한다. 고로 그냥 이런 게 있구나 하고 넘어가도록 하자.

image

d<dtarg/1.5 일 경우 β를 감소시키는 것을 볼 수 있는데, d가 작다는 뜻은 기존 정책과 바뀐 정책 간의 차이가 크지 않다는 뜻이므로 penalty를 조금만 줘야한다는 뜻이 된다. 반대로 d가 너무 크면 정책이 급격하게 바뀌었다는 뜻이므로 penalty를 크게 주어 loss function을 감소(다시 짚고 가면 우린 현재 loss function을 maximize하길 원한다)시켜 반대방향(d가 작아지도록)으로 학습이 되게 한다.

Algorithm

최종적인 loss function은 아래와 같이 정의된다.

image

Loss function을 다 합쳐서 학습을 진행하는 이유는 Pytorch 등에서 제공하는 automatic differentiation을 사용하기 위해서이다. 다시 한 번 강조하자면 위 loss function을 maximize하기 위한 stochastic gradient algorithm을 사용한다. 맨 마지막 term은 entropy로 sufficient exploration을 위해 더해진 term이다.

Pseudocode for training

image

  1. 우선 T (episode length 보다 훨씬 작음) timestep 동안 update를 위한 sample을 모으고 advantage를 계산 image image

  2. N개의 parallel actor을 사용하여 T timestep 동안 데이터를 모음
  3. 모은 데이터에서 M 크기의 minibatch를 만들어서 K번의 epoch 동안 stochastic gradient ascent 반복함
  4. θold을 업데이트 된 θ로 바꿈

Implementation

마찬가지로 팡요랩의 소스코드를 사용.

class PPO(nn.Module):
    def __init__(self):
        super(PPO, self).__init__()
        self.data = []

        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim=0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v

    def put_data(self, transition):
        self.data.append(transition)

    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], []
        for transition in self.data:
            s, a, r, s_prime, prob_a, done = transition

            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            done_mask = 0 if done else 1
            done_lst.append([done_mask])

        s, a, r, s_prime, done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
                                              torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
                                              torch.tensor(done_lst, dtype=torch.float), torch.tensor(prob_a_lst)
        self.data = []
        return s, a, r, s_prime, done_mask, prob_a

    def train_net(self):
        s, a, r, s_prime, done_mask, prob_a = self.make_batch()

        for i in range(K_epoch):
            td_target = r + gamma * self.v(s_prime) * done_mask
            delta = td_target - self.v(s)
            delta = delta.detach().numpy()

            advantage_lst = []
            advantage = 0.0
            for delta_t in delta[::-1]:
                advantage = gamma * lmbda * advantage + delta_t[0]
                advantage_lst.append([advantage])
            advantage_lst.reverse()
            advantage = torch.tensor(advantage_lst, dtype=torch.float)

            pi = self.pi(s, softmax_dim=1)
            pi_a = pi.gather(1, a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s), td_target.detach())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

Comments