Proximal Policy Optimization Explained
Every symbol defined. Every intuition built from scratch. A step-by-step animated walkthrough of how PPO trains a language model, from the first token to the final gradient update.
Why do we need RL at all?
Picture the model you have right now. It saw a few trillion tokens during pretraining. It can finish sentences from books, write code, recite Wikipedia, and even hold a conversation if you nudge it the right way. But it has one big limitation: it was only ever taught to predict the next token. It was never taught that this answer is better than that answer.
Supervised fine tuning (SFT) helps a little. You take a dataset of question and answer pairs written by humans, and you fine tune the model to imitate them. After SFT, the model knows the shape of a good answer. But SFT teaches it the average of your demonstrations, not the best of them. It also gives the model no notion of "this response was better than that one." It only sees one target per prompt.
What we actually want is to push the model toward generations that humans prefer. That is a comparison task. Comparison data is cheap to collect: show two answers and click the better one. Imitation data is expensive: write the full perfect answer yourself.
So the question becomes: how do we use a comparison signal to update the weights of a pretrained transformer? The answer is reinforcement learning. PPO is the algorithm behind the best-known RLHF pipelines, including InstructGPT and the original ChatGPT training.
The setup: what you walk in with
By the time you start PPO, you should already have:
- An SFT model. Your pretrained LLM, fine tuned on a dataset of high-quality demonstrations. This is the starting point for the policy that PPO will update. Call it π_SFT.
- A reward model. A separate model that takes a (prompt, response) pair and returns a scalar score. We will see exactly how this is trained in the next section. Call it R(x, y).
- A pile of prompts. Just prompts. No answers needed. This is the dataset PPO will train on. The model will generate its own responses on the fly.
That last point is worth pausing on. PPO does not train on a (prompt, target answer) dataset. It only needs prompts. The model produces candidate answers itself during training, scores them with the reward model, and updates its weights based on those scores. That is the fundamental shift from supervised learning.
We will use one prompt the entire way through this post. The user asks the model:
What are some healthy breakfast options for someone with diabetes?
And the model produces this response (one of several it could sample):
"Good breakfast options for someone with diabetes include steel-cut oatmeal topped with berries and nuts, plain Greek yogurt with chia seeds, or a vegetable omelette with whole-grain toast. These choices release glucose slowly and pair carbs with protein and fibre."
A reward model scores this response R = 0.91. We will trace what PPO does with that signal, all the way down to the parameter update.
Where does the reward come from?
The reward signal is the heart of the whole pipeline. Without it, PPO has nothing to optimize. So how do you build a function that takes any (prompt, response) and tells you how good the response is?
Step 1: collect preferences
Take your SFT model. For each prompt in your dataset, sample k different responses, with the temperature set high enough that you get variation. A typical setup is k = 4 to 8 responses per prompt. Now you have a bag of (prompt, response_1, response_2, ..., response_k) tuples.
Show each pair of responses to a human (or sometimes to a stronger LLM acting as a judge) and ask: which one is better? The human clicks. You record the preference. The dataset looks like:
```json
{
  "prompt": "What are some healthy breakfast options for someone with diabetes?",
  "chosen": "Good options include steel-cut oatmeal with berries, plain Greek yogurt with chia seeds, or a vegetable omelette with whole-grain toast...",
  "rejected": "Just eat whatever you want, breakfast doesn't really matter."
}
{
  "prompt": "Explain why the sky is blue.",
  "chosen": "Sunlight is made of many colors. As it passes through the atmosphere, the shorter blue wavelengths scatter more than the others, so the sky looks blue from the ground.",
  "rejected": "Because of the ocean reflection."
}
```
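If an annotator (or a judge model) ranks all k samples for a prompt instead of picking between two, every pair in the ranking can become one training record, so k responses yield k(k − 1)/2 comparisons. A minimal sketch, assuming the ranking arrives as a best-first list; `ranking_to_pairs` is a hypothetical helper, not part of any library, and the record layout simply mirrors the examples above.

```python
from itertools import combinations


def ranking_to_pairs(prompt, ranked_responses):
    """Turn a best-first ranking of k responses into k*(k-1)/2 preference records."""
    pairs = []
    # combinations() preserves input order, so the first element of each
    # pair is always the better-ranked response.
    for better, worse in combinations(ranked_responses, 2):
        pairs.append({"prompt": prompt, "chosen": better, "rejected": worse})
    return pairs


prompt = "What are some healthy breakfast options for someone with diabetes?"
ranked = [
    "Steel-cut oatmeal with berries and nuts...",
    "Plain Greek yogurt with chia seeds...",
    "A bagel with cream cheese.",
    "Just eat whatever you want.",
]
records = ranking_to_pairs(prompt, ranked)
print(len(records))  # 6 records from k = 4 responses
```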
Step 2: train the reward model
Take a fresh copy of your SFT model. Replace its language modeling head with a single scalar head (one linear layer, output dim 1). Now it takes a (prompt, response) pair and returns one number.
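Train it with the Bradley-Terry loss. Given a chosen response y_w and a rejected response y_l for the same prompt x, the loss wants the chosen one to score higher than the rejected one:

$$\mathcal{L}_{\text{RM}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\Big[\log \sigma\big(R(x, y_w) - R(x, y_l)\big)\Big]$$

where σ is the logistic sigmoid.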
After a few epochs, the reward model has internalized human taste, at least the slice of taste captured by your annotations. It can now score any new response on any prompt.
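For a single pair, the loss is just −log σ(margin), where the margin is the chosen score minus the rejected score. A plain-Python sketch with made-up scores (illustrative numbers, not from a trained reward model), written in a numerically stable form so large margins do not overflow:

```python
import math


def bradley_terry_loss(r_chosen, r_rejected):
    """-log sigmoid(r_chosen - r_rejected) for one preference pair.

    Computed as log(1 + exp(-margin)), split on the sign of the margin
    so that exp() never receives a large positive argument.
    """
    margin = r_chosen - r_rejected
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


print(bradley_terry_loss(0.0, 0.0))    # log 2 ≈ 0.693: the model cannot tell them apart
print(bradley_terry_loss(0.91, -1.5))  # small loss: the chosen response already wins
print(bradley_terry_loss(-1.5, 0.91))  # large loss: the ordering is backwards
```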
How is this any better than just training on demonstrations?
Demonstrations tell you what a good answer looks like. Preferences tell you what makes one answer better than another. The second signal is finer-grained and much cheaper to collect: ranking two answers takes a few seconds; writing a perfect answer takes minutes. The reward model is the trick that turns thousands of cheap rank labels into a continuous score function you can apply to any new response.
Translating LLMs into RL language
RL papers use a vocabulary that does not appear anywhere in a transformer paper. Here is the cheat sheet:
| RL term | What it actually is for an LLM |
|---|---|
| policy π_θ | Your transformer. Outputs a probability distribution over the vocabulary at every position. |
| action a_t | The token chosen at position t. Sampled from π_θ(· \| s_t). |
| state s_t | The prompt x plus every token generated so far: (a_1, ..., a_{t−1}). |
| trajectory τ | One complete (prompt, response) pair. The whole episode. |
| horizon T | How many tokens the model generated before stopping. Variable per episode. |
| reward r_t | Almost always zero for t < T. The reward model only fires at the end of the response. |
| return G_t | Sum of future rewards from step t. For LLMs, basically R(x, y) for every t. |
Two things are worth dwelling on.
The action space is enormous. Most RL papers picture games with maybe a dozen possible actions. For an LLM the action space is the whole vocabulary, 30,000 to 200,000 tokens depending on tokenizer. That changes the math of exploration: random sampling becomes incredibly weak because nearly every random choice is gibberish.
The reward is sparse. You generate hundreds of tokens. The reward model fires exactly once, at the end. So one scalar reward has to be back-propagated as a learning signal across the entire sequence. PPO's job is to figure out which tokens deserve credit (or blame) for that final score. This is called credit assignment, and it is the central technical difficulty.
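To make the sparsity concrete, here is a sketch of the per-token reward vector for a short response and the return G_t computed from it. With γ = 1 and the reward only on the final token, every position gets the same return, which is why the table says G_t is basically R(x, y) for every t. The six-token length is made up for illustration.

```python
def returns_to_go(rewards, gamma=1.0):
    """G_t = r_t + gamma * G_{t+1}, computed backward over one response."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


T = 6                                # pretend the response is 6 tokens long
rewards = [0.0] * (T - 1) + [0.91]   # the reward model fires once, at the last token
print(returns_to_go(rewards))        # [0.91, 0.91, 0.91, 0.91, 0.91, 0.91]
```

Every token receives the same 0.91, so this signal alone cannot say which tokens earned it; that is the credit-assignment problem the rest of the post works on.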
Why not just SFT on high-reward outputs?
A natural reaction at this point: "Why not just sample a thousand responses, keep only the ones that score high, and supervised-fine-tune on those?" This is called rejection sampling fine tuning or best-of-N SFT. It works. It is also simpler than PPO. So why do people bother with PPO?
Two reasons:
- SFT cannot push probability away from bad outputs. When you fine tune on token sequences, the cross-entropy loss only ever increases the probability of the tokens that appear in your training data. It never explicitly punishes the tokens that should not appear. PPO's clipped objective, in contrast, has positive-advantage and negative-advantage cases: it increases good-token probability and decreases bad-token probability in the same step.
- SFT throws away most of your data. If you sample 8 responses and keep only the top 1, you discarded 7/8 of the compute spent generating. PPO uses all 8 responses, weighted by how good or bad they were. This is a much denser learning signal per dollar of inference.
That said, hybrids are common in practice. Many production pipelines do rejection sampling SFT first to get a strong starting point, then run PPO from there. The combination is often more reliable than either method alone.
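For contrast, the rejection-sampling (best-of-N) baseline described at the start of this section fits in a few lines. A sketch with toy stand-ins: `sample_fn` and `reward_fn` are hypothetical placeholders for your policy's sampler and your reward model, and the canned responses and scores are invented.

```python
from itertools import cycle


def best_of_n(prompt, sample_fn, reward_fn, n=8):
    """Sample n responses and keep only the highest-scoring one."""
    candidates = [sample_fn(prompt) for _ in range(n)]
    scores = [reward_fn(prompt, c) for c in candidates]
    best = max(range(n), key=lambda i: scores[i])
    # Only candidates[best] becomes an SFT target; the other n - 1 are discarded.
    return candidates[best], scores[best]


# Toy stand-ins for a policy and a reward model (purely illustrative).
_pool = cycle(["a donut", "skip breakfast", "oatmeal with berries"])
_scores = {"oatmeal with berries": 0.91, "a donut": 0.1, "skip breakfast": -0.5}


def toy_sample(prompt):
    return next(_pool)


def toy_reward(prompt, response):
    return _scores[response]


response, score = best_of_n("healthy breakfast?", toy_sample, toy_reward, n=8)
print(response, score)
```

Seven of the eight scored samples never reach the training set, which is exactly the waste the second reason above describes.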
Policy gradient: the foundation
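The goal is to maximize the expected reward our policy gets across the prompt distribution:

$$J(\theta) = \mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\big[R(x, y)\big]$$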
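We cannot just take ∇J(θ) directly because the expectation is over samples from π_θ itself, and the sampling depends on θ. The policy gradient theorem handles this by moving the gradient inside the expectation:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)\Big]$$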
The intuition is exactly what it looks like: multiply the reward by the gradient of log-probability of the action you took, and that is the direction to push your parameters in. If R is positive, the gradient step makes those actions more likely. If R is negative, less likely.
For the running example, the trajectory τ is the few dozen tokens of the answer about oatmeal, and R(τ) = 0.91. The gradient step nudges the model toward producing those tokens (in that context) more often. Simple in principle.
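The update described above can be sketched for a single decision with a softmax policy over a tiny vocabulary. For logits z, the gradient of log π(a) with respect to z is onehot(a) − π, so a positive reward raises the sampled token's logit and lowers the others. A plain-Python sketch; the three-token vocabulary, the learning rate, and the helper names are made up for illustration.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(logits, action, reward, lr=0.5):
    """One REINFORCE update on the logits: z += lr * R * grad_z log pi(action)."""
    probs = softmax(logits)
    # d/dz_i log softmax(z)[action] = 1[i == action] - probs[i]
    grad_logp = [(1.0 if i == action else 0.0) - p for i, p in enumerate(probs)]
    return [z + lr * reward * g for z, g in zip(logits, grad_logp)]


logits = [0.0, 0.0, 0.0]  # uniform over a 3-token vocabulary
up = reinforce_step(logits, action=1, reward=0.91)
down = reinforce_step(logits, action=1, reward=-0.91)
print(softmax(up)[1], softmax(down)[1])  # above 1/3, then below 1/3
```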
So why is this not enough on its own? Two reasons. Variance: a single trajectory's reward can swing wildly between batches, so the gradient estimate is noisy. You need huge batches or many epochs to average it out. Sample inefficiency: as soon as you take one gradient step, your old samples are technically off-policy and you should throw them out. Generating a fresh batch every step is wildly expensive.
The next two sections fix variance (with baselines and advantages) and sample inefficiency (with importance sampling and clipping). Those two fixes together are PPO.
Baselines and advantages
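Here is the trick. The policy gradient is unchanged in expectation if you subtract from the reward any baseline b(s_t) that does not depend on the action:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\big(R(\tau) - b(s_t)\big)\Big]$$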
Why is the expectation unchanged? Because E_{a∼π_θ(·|s)}[∇_θ log π_θ(a | s)] = 0, so the expectation of ∇_θ log π_θ · b(s) is zero whenever b(s) doesn't depend on the action. This is a standard property called the score function identity. So subtracting any state-only term is a free variance reduction.
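The identity is easy to check numerically for a one-step softmax policy, because the expectation over actions is a finite sum: Σ_a π(a)(onehot(a) − π) = π − π = 0. The sketch below computes the exact expected gradient with and without a baseline and shows they agree up to floating-point noise; the logits, per-action rewards, and baseline value are arbitrary.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_pg(logits, rewards, baseline=0.0):
    """Exact E_a[(R(a) - b) * grad_z log pi(a)] for a one-step softmax policy."""
    probs = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for a in range(n):
        weight = probs[a] * (rewards[a] - baseline)
        for i in range(n):
            grad[i] += weight * ((1.0 if i == a else 0.0) - probs[i])
    return grad


logits = [0.3, -1.2, 2.0]
rewards = [0.91, 0.2, -0.4]
g0 = expected_pg(logits, rewards)
g1 = expected_pg(logits, rewards, baseline=0.5)  # any action-independent constant
print(max(abs(x - y) for x, y in zip(g0, g1)))   # essentially zero
```

The baseline leaves the mean untouched; what it changes is the variance of the sampled estimates, which this exact computation does not show.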
The best baseline is one that makes (R − b) as small as possible while still being state-only. A common, near-optimal choice is the value function V(s): the expected return from state s under the current policy. We train a small network to estimate it.
And then the advantage is the gap between the return you actually got and the return the critic expected:

$$A_t = G_t - V_\phi(s_t)$$
Two networks now. The actor (your LLM with its usual LM head) generates tokens. The critic (an LLM with a scalar head bolted on) estimates V. Many implementations share one transformer body between the two to save memory, so only the heads differ; others keep a separate critic, often initialized from the reward model.
The critic is trained with plain regression against actual returns:

$$L^{V}(\phi) = \mathbb{E}_t\Big[\big(V_\phi(s_t) - G_t\big)^2\Big]$$
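Plugging in the running example: suppose the critic produced the values below along the oatmeal response (the numbers are invented for illustration). The advantage is what remains after subtracting the critic's guess from the return, and the critic loss is the mean squared gap.

```python
def advantages_and_value_loss(returns, values):
    """A_t = G_t - V(s_t), plus the critic's mean-squared regression loss."""
    advantages = [g - v for g, v in zip(returns, values)]
    value_loss = sum(a * a for a in advantages) / len(advantages)
    return advantages, value_loss


returns = [0.91, 0.91, 0.91, 0.91]  # G_t from a single terminal reward, gamma = 1
values = [0.50, 0.60, 0.80, 0.90]   # hypothetical critic predictions V(s_t)
adv, loss_v = advantages_and_value_loss(returns, values)
print(adv, loss_v)
```

Positions where the critic was more pessimistic get a larger positive advantage, and regressing V toward G_t shrinks those gaps over training.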
Off-policy and importance sampling
If you stick to REINFORCE-with-baseline, you must regenerate the entire batch of trajectories after every single gradient step. That is incredibly slow when generating one trajectory costs hundreds of forward passes through a 7B-parameter model.
The workaround is importance sampling. You collect rollouts with a frozen snapshot of the policy, call it π_old. You then run several gradient updates on the current policy π_θ, using the same rollouts. To compensate for the mismatch (the data was sampled from π_old, but you are training π_θ), you weight each token by the ratio:

$$\rho_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{old}}(a_t \mid s_t)}$$
The surrogate objective becomes:

$$L^{\text{IS}}(\theta) = \mathbb{E}_t\big[\rho_t(\theta)\, A_t\big]$$
This is a good approximation as long as π_θ stays close to π_old, so that ρ stays near 1. The problem is that ρ can blow up. If the new policy starts assigning much higher probability to some token than the old one did, ρ could become 5, 10, 100. The estimator's variance explodes. Worse, a few outlier ratios can dominate the gradient.
This is the exact problem PPO solves.
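In practice the ratio is computed from stored log-probabilities as exp(log π_θ − log π_old), which avoids dividing tiny probabilities. The sketch below, with invented probabilities, shows how quickly ρ grows once the new policy moves away from the old one on a token the old policy considered unlikely.

```python
import math


def importance_ratio(new_logprob, old_logprob):
    """rho = pi_new(a|s) / pi_old(a|s), computed in log space."""
    return math.exp(new_logprob - old_logprob)


old_lp = math.log(0.01)  # the old policy gave this token 1% probability
for new_p in (0.01, 0.05, 0.10, 0.50):
    rho = importance_ratio(math.log(new_p), old_lp)
    print(f"new prob {new_p:.2f} -> rho = {rho:.1f}")
# prints rho = 1.0, 5.0, 10.0, 50.0
```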
The clipped surrogate, finally
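The clipped objective:

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\Big[\min\big(\rho_t(\theta)\, A_t,\; \operatorname{clip}(\rho_t(\theta),\, 1-\varepsilon,\, 1+\varepsilon)\, A_t\big)\Big]$$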
That min looks strange at first. Why is it there? Walk through the four cases:
- A_t > 0, ρ ≤ 1 + ε. The new policy hasn't pushed too hard yet. The unclipped term is in effect: the gradient pushes ρ up further, making the action more likely.
- A_t > 0, ρ > 1 + ε. The new policy has already increased this action's probability by more than a factor of 1 + ε. The clip caps the objective. Pushing ρ higher gives zero extra objective, so the gradient stops.
- A_t < 0, ρ ≥ 1 − ε. The new policy hasn't pushed this action down enough yet. The unclipped term lets the gradient continue pushing ρ down.
- A_t < 0, ρ < 1 − ε. The new policy has already decreased this action's probability below a factor of 1 − ε. The clip caps it. The gradient stops driving it lower.
The min makes the clip one-sided: it only removes the incentive to keep moving ρ in the direction the advantage favors once ρ has left [1 − ε, 1 + ε]. We are still happy to take gradient steps that pull the ratio back toward 1 (corrective steps), but we won't take ones that push it further away.
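The four cases can be checked directly by computing, per token, the clipped objective and its derivative with respect to ρ. A plain-Python sketch (in a real implementation autograd does this for you); ε = 0.2 and the test values are arbitrary.

```python
def clipped_objective(rho, adv, eps=0.2):
    """Per-token PPO objective: min(rho * A, clip(rho, 1 - eps, 1 + eps) * A)."""
    clipped_rho = min(max(rho, 1.0 - eps), 1.0 + eps)
    return min(rho * adv, clipped_rho * adv)


def grad_wrt_rho(rho, adv, eps=0.2):
    """d(objective)/d(rho): A when the unclipped term is selected, else 0."""
    clipped_rho = min(max(rho, 1.0 - eps), 1.0 + eps)
    # When min() picks the unclipped term, its slope in rho is adv.
    # When it picks the clipped term, rho is outside [1-eps, 1+eps] and the slope is 0.
    unclipped_active = rho * adv <= clipped_rho * adv
    return adv if unclipped_active else 0.0


cases = {
    "A>0, rho <= 1+eps": (1.10, +1.0),
    "A>0, rho >  1+eps": (1.50, +1.0),
    "A<0, rho >= 1-eps": (0.90, -1.0),
    "A<0, rho <  1-eps": (0.50, -1.0),
}
for name, (rho, adv) in cases.items():
    print(name, clipped_objective(rho, adv), grad_wrt_rho(rho, adv))
```

The corrective directions keep their gradient: with A > 0 and ρ = 0.5, or A < 0 and ρ = 1.5, the unclipped term is selected and the slope is A.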
The KL penalty: anchoring to the SFT model
The reward model is not perfect. It was trained on a finite set of preferences. If PPO is allowed to optimize hard enough, it will find weird outputs that the reward model accidentally rates highly: outputs no human would actually prefer. This is called reward hacking, and it is the most common failure mode in RLHF.
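The fix is to penalize the policy whenever it strays from the original SFT model, which we keep as a frozen reference π_ref = π_SFT. We add a per-token KL term, estimated on the sampled token:

$$\mathrm{KL}_t = \log \pi_\theta(a_t \mid s_t) - \log \pi_{\text{ref}}(a_t \mid s_t)$$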
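The full PPO objective, which we maximize, is then:

$$J^{\text{PPO}}(\theta, \phi) = \mathbb{E}_t\Big[\min\big(\rho_t A_t,\ \operatorname{clip}(\rho_t, 1-\varepsilon, 1+\varepsilon)\, A_t\big) - c_v\big(V_\phi(s_t) - G_t\big)^2 + c_e\,\mathcal{H}\big[\pi_\theta(\cdot \mid s_t)\big] - \beta\,\mathrm{KL}_t\Big]$$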
Four terms, three loss coefficients to tune in practice (the training loop below folds the β term into the per-token rewards rather than adding it to the loss, which is the common implementation):
- c_v: weight on the critic loss. Usually 0.5 to 1.0.
- c_e: tiny entropy bonus that discourages collapse. Usually 0.001 to 0.01.
- β: how hard you pull back toward the SFT model. The most consequential hyperparameter. Too small and the model reward-hacks. Too large and it never learns anything new. Typical range: 0.01 to 0.2.
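Concretely, the KL term usually enters as a per-token penalty on the rewards, with the reward-model score added on the last token; this is also what the training loop below does. A sketch with invented log-probabilities for a four-token response and β = 0.05:

```python
def shaped_rewards(policy_logprobs, ref_logprobs, task_reward, beta=0.05):
    """r_t = -beta * (log pi_theta - log pi_ref), plus R(x, y) on the final token."""
    rewards = [-beta * (lp - lref) for lp, lref in zip(policy_logprobs, ref_logprobs)]
    rewards[-1] += task_reward
    return rewards


policy_lp = [-0.2, -1.0, -0.5, -0.1]  # log pi_theta(a_t | s_t), made up
ref_lp = [-0.4, -1.0, -1.5, -0.1]     # log pi_ref(a_t | s_t), made up
r = shaped_rewards(policy_lp, ref_lp, task_reward=0.91)
print(r)  # approximately [-0.01, 0.0, -0.05, 0.91]
```

Tokens where the policy has become more confident than the reference (positions 0 and 2 here) pay a small penalty, and the 0.91 from the reward model only arrives at the end.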
GAE: smoother advantages
Advantage estimation has a knob: how many future steps do we look at when computing At?
- Look at just one step: A_t ≈ r_t + γ V(s_{t+1}) − V(s_t). Low variance, but biased by V's mistakes.
- Look all the way to the end: A_t = R(τ) − V(s_t). Unbiased, but high variance because R(τ) is one noisy number.
GAE is an exponentially weighted average of all the in-between options. Define the one-step TD error:

$$\delta_t = r_t + \gamma\, V(s_{t+1}) - V(s_t)$$
Then GAE is:

$$A_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{T-t} (\gamma\lambda)^{l}\, \delta_{t+l}$$
For LLMs the typical values are γ = 1.0 (we don't discount future tokens within one response) and λ = 0.95 (near-term TD errors get the most weight, but later ones still count). The recursive form A_t = δ_t + γλ A_{t+1}, starting from A_{T+1} = 0, makes it cheap to compute backward in one pass:
```python
import torch


def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """
    rewards: tensor of shape [B, T], per-token rewards (mostly zero plus KL penalties)
    values:  tensor of shape [B, T+1], V(s_t) for every state; values[:, T] is the
             terminal state and should be 0 when the response actually ended
    returns: advantages [B, T], returns [B, T]
    """
    T = rewards.shape[1]
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(rewards[:, 0])  # A_{T+1} = 0 for every sequence
    # Walk backward in time: A_t = delta_t + gamma * lam * A_{t+1}
    for t in reversed(range(T)):
        delta = rewards[:, t] + gamma * values[:, t + 1] - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    returns = advantages + values[:, :T]  # regression targets for the critic
    return advantages, returns
```
One additional step in practice: normalize the advantages per batch (subtract mean, divide by standard deviation) before plugging them into the clip objective. This decouples the learning rate from the absolute scale of rewards, which is a free stability win.
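A dependency-free version of the same recursion makes the two limits from the bullet list easy to check: λ = 0 gives the one-step TD error, and λ = 1 (with γ = 1) gives the full return minus V(s_t). The per-batch normalization is included at the end, using the sample standard deviation to match the default of PyTorch's `Tensor.std()`. Rewards and values are invented.

```python
import statistics


def gae(rewards, values, gamma=1.0, lam=0.95):
    """GAE for one response; values has len(rewards) + 1 entries (last = terminal)."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


def normalize(advantages, eps=1e-8):
    """Subtract the batch mean and divide by the sample standard deviation."""
    mean = statistics.fmean(advantages)
    std = statistics.stdev(advantages)
    return [(a - mean) / (std + eps) for a in advantages]


rewards = [0.0, 0.0, 0.0, 0.91]
values = [0.50, 0.60, 0.80, 0.90, 0.0]        # V(s_1..s_4) plus terminal V = 0

td_only = gae(rewards, values, lam=0.0)       # delta_t = r_t + V(s_{t+1}) - V(s_t)
monte_carlo = gae(rewards, values, lam=1.0)   # G_t - V(s_t) = 0.91 - V(s_t)
smoothed = normalize(gae(rewards, values))    # lam = 0.95, then standardized
print(td_only, monte_carlo, smoothed)
```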
The full training loop, end to end
This is the loop that most PPO-based RLHF training scripts implement in some form. A few hundred lines of bookkeeping around the math we just walked through. Read it once carefully, and you have the full picture.
```python
import copy

import torch
import torch.nn.functional as F
from torch.optim import AdamW

# load_sft, load_reward_model, ValueHead, sample_prompts, shuffle_minibatches and the
# .log_probs / .entropy methods are placeholders for your own code, not a library API.

# --- Hyperparameters (illustrative values) ---------------------------------
num_steps, batch_size, K_epochs = 1000, 64, 4
beta, eps, c_v, c_e = 0.05, 0.2, 0.5, 0.01

# --- Models -----------------------------------------------------------------
policy     = load_sft()                               # π_θ: being trained
ref        = load_sft().eval().requires_grad_(False)  # π_ref: frozen leash
value_head = ValueHead(policy)                        # V_φ: scalar head on policy body
reward_m   = load_reward_model().eval()               # frozen reward model
old        = copy.deepcopy(policy).eval()             # π_old: frozen during rollout
# Assumes value_head.parameters() yields only the new scalar head, so the
# shared transformer body is not registered twice.
params    = [*policy.parameters(), *value_head.parameters()]
optimiser = AdamW(params, lr=1e-6)

# --- Outer loop -------------------------------------------------------------
for step in range(num_steps):
    # -- Phase 1: rollout ---------------------------------------------------
    prompts = sample_prompts(batch_size)
    with torch.no_grad():
        responses    = old.generate(prompts, max_new_tokens=512, temperature=1.0)
        old_logprobs = old.log_probs(responses)      # [B, T]
        ref_logprobs = ref.log_probs(responses)      # [B, T]
        values       = value_head(responses)         # [B, T+1]
        R            = reward_m(prompts, responses)  # [B] task reward (one number per response)

    kl_token  = old_logprobs - ref_logprobs          # [B, T] per-token KL estimate
    per_tok_r = -beta * kl_token                     # KL shaping at every step
    per_tok_r[:, -1] += R                            # task reward fires only at the end
                                                     # (assumes no padding; otherwise add it
                                                     # at each sequence's last real token)

    # -- Phase 2: advantages via GAE ----------------------------------------
    advantages, returns = compute_gae(per_tok_r, values, gamma=1.0, lam=0.95)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # -- Phase 3: K inner epochs of clipped updates -------------------------
    for epoch in range(K_epochs):                    # K = 1 to 4 typically
        for batch in shuffle_minibatches(...):
            new_logprobs = policy.log_probs(batch.responses)
            ratio        = (new_logprobs - batch.old_logprobs).exp()
            unclipped    = ratio * batch.advantages
            clipped      = ratio.clamp(1 - eps, 1 + eps) * batch.advantages
            loss_pg      = -torch.min(unclipped, clipped).mean()

            # Drop the terminal-state value so the shape matches returns [B, T].
            new_values = value_head(batch.responses)[:, :-1]
            loss_v     = F.mse_loss(new_values, batch.returns)
            loss_ent   = -policy.entropy(batch.responses).mean()

            loss = loss_pg + c_v * loss_v + c_e * loss_ent
            optimiser.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimiser.step()

    # -- Phase 4: sync π_old to current π_θ ---------------------------------
    old.load_state_dict(policy.state_dict())
```
What sizes do these things actually have?
A typical RLHF run on a 7B model looks roughly like this:
- Batch size: 64 to 256 prompts per outer step.
- Max response length: 512 or 1024 tokens.
- Inner epochs K: 1 to 4.
- Clip ฮต: 0.1 or 0.2.
- KL coefficient ฮฒ: 0.01 to 0.1, often adapted on the fly.
- Learning rate: 1e-6 to 5e-6. Much lower than during SFT.
- Number of outer steps: a few thousand, usually less than 100k.
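One common way to adapt β on the fly is a proportional controller: measure the KL between policy and reference after each rollout, then nudge β up when it overshoots a target and down when it undershoots. A minimal sketch; the class name, the target of 6.0, the ±20% error clip, and the gain are illustrative assumptions rather than recommended settings.

```python
class AdaptiveKLController:
    """Proportional controller for the KL coefficient beta (illustrative sketch)."""

    def __init__(self, init_beta=0.05, target_kl=6.0, gain=0.1):
        self.beta = init_beta
        self.target_kl = target_kl
        self.gain = gain

    def update(self, observed_kl):
        # Relative error, clipped so one noisy batch cannot swing beta too far.
        error = max(-0.2, min(0.2, observed_kl / self.target_kl - 1.0))
        self.beta *= 1.0 + self.gain * error
        return self.beta


ctrl = AdaptiveKLController()
for kl in (12.0, 12.0, 3.0):  # KL too high twice, then too low
    print(round(ctrl.update(kl), 5))
```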
The dominant cost is generation in phase 1, not the gradient updates. People work hard on vLLM-style fast inference engines just for the rollout step.
Practical pitfalls and how to spot them
Things that will go wrong, eventually, on every PPO run:
Reward climbs while samples get worse. This is classic reward hacking: the model has discovered a quirk in the reward model that humans would not endorse. Increase β, retrain the reward model on the new failure cases, or both.
KL to the reference model shoots up. β is too small or your learning rate is too high. Narrow ε (try 0.1 instead of 0.2), drop the learning rate by an order of magnitude, or use an adaptive β controller that tightens when KL is high.
Entropy falls toward zero. The policy has collapsed onto one response. Raise c_e, widen ε so the clip stops biting on every token, or raise the sampling temperature during rollout.
Value loss explodes. The value targets are scale-unstable. Standardize advantages, clip the value loss the same way you clip the policy loss, and double-check that your KL shaping isn't producing huge negative per-token rewards.
Learning stalls with a very high clip fraction. The clip is too tight; almost every token already sits at the boundary. Widen ε, or warm up with a few hundred steps of vanilla policy gradient before turning the clip on.
Training is slow. Phase 1 (generation) dominates. Use a fast inference engine for the rollouts, cache the reference and old log-probs once per rollout, and consider whether DPO, which trains directly on preference pairs without generating rollouts, would do the job.
What to log every step: mean reward, KL(π_θ ‖ π_ref), policy entropy, fraction of clipped tokens, mean |advantage|, value loss, gradient norm, longest response in batch. If any of these surprise you, pause and look at twenty random samples before continuing.
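Several of those numbers are cheap to compute from quantities the loop already has. A plain-Python sketch over flat lists of per-token values; the inputs are invented, the function name is hypothetical, and the clip fraction uses the simple |ρ − 1| > ε definition (some libraries instead count tokens where the clipped term was actually selected).

```python
import math
import statistics


def ppo_diagnostics(new_logprobs, old_logprobs, ref_logprobs, advantages, eps=0.2):
    """A few per-step health metrics computed from per-token lists."""
    ratios = [math.exp(n - o) for n, o in zip(new_logprobs, old_logprobs)]
    return {
        # Same simple estimator the loop uses: mean of log pi_theta - log pi_ref.
        "kl_to_ref": statistics.fmean([n - r for n, r in zip(new_logprobs, ref_logprobs)]),
        "clip_fraction": sum(abs(r - 1.0) > eps for r in ratios) / len(ratios),
        "mean_abs_advantage": statistics.fmean([abs(a) for a in advantages]),
    }


stats = ppo_diagnostics(
    new_logprobs=[-0.1, -0.5, -2.0, -0.3],
    old_logprobs=[-0.1, -0.9, -2.0, -0.3],
    ref_logprobs=[-0.2, -1.0, -2.5, -0.3],
    advantages=[0.5, -1.0, 0.2, 0.3],
)
print(stats)
```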
One last sanity check: read the samples
Every metric can look fine while the actual outputs degrade. Set up an evaluation loop that runs the current policy on a held-out set of prompts every few hundred steps and writes the responses to disk. Read them with your own eyes. PPO has a special talent for finding loopholes the reward model never noticed.
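A minimal sketch of that evaluation hook, assuming a `generate_fn(prompt) -> str` callable for the current policy (a hypothetical placeholder) and writing one JSON line per sample so the files are easy to diff across checkpoints. The function name, directory layout, and 200-step interval are illustrative choices.

```python
import json
import tempfile
from pathlib import Path


def dump_eval_samples(step, prompts, generate_fn, out_dir="eval_samples", every=200):
    """Every `every` steps, write held-out prompts and responses to a JSONL file."""
    if step % every != 0:
        return None
    path = Path(out_dir) / f"step_{step:06d}.jsonl"
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for prompt in prompts:
            record = {"step": step, "prompt": prompt, "response": generate_fn(prompt)}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


def fake_generate(prompt):
    """Stand-in for the current policy's sampler."""
    return f"[sample for] {prompt}"


held_out = [
    "What are some healthy breakfast options for someone with diabetes?",
    "Explain why the sky is blue.",
]
# A temporary directory keeps this demo self-cleaning.
with tempfile.TemporaryDirectory() as tmp:
    written = dump_eval_samples(400, held_out, fake_generate, out_dir=tmp)
    lines = written.read_text(encoding="utf-8").splitlines()
    skipped = dump_eval_samples(401, held_out, fake_generate, out_dir=tmp)
```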
References and further reading
- Schulman, J. et al. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347 · the original PPO paper
- Schulman, J. et al. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation. arXiv:1506.02438 · the GAE paper
- Ouyang, L. et al. (2022). Training language models to follow instructions with human feedback. arXiv:2203.02155 · InstructGPT, the first big RLHF result
- Stiennon, N. et al. (2020). Learning to summarize with human feedback. arXiv:2009.01325 · the OpenAI summarisation paper that kickstarted modern RLHF
- Bai, Y. et al. (2022). Training a Helpful and Harmless Assistant with RLHF. arXiv:2204.05862 · the Anthropic HH paper
- Engstrom, L. et al. (2020). Implementation Matters in Deep Policy Gradients. arXiv:2005.12729 · what actually drives PPO's empirical performance
- Huang, S. et al. (2024). The N Implementation Details of RLHF with PPO. HuggingFace blog post · the practical bible if you are going to implement this
- Ahmadian, A. et al. (2024). Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback. arXiv:2402.14740 · the case against PPO's complexity for LLMs