On a few pitfalls in KL divergence gradient estimation for RL

Abstract

We point out a few pitfalls in implementing gradient estimation for the KL divergence in RL training of LLMs, as seen in a number of open-source projects and papers. The first major pitfall is differentiating through a KL estimate used as a loss function in order to minimize the KL divergence. We show that such implementations are generally incorrect and do not produce the desired KL gradient. Second, we show that some implementations do not account for the sequential nature of the estimation problem and produce, at best, a partial gradient. We demonstrate the impact of these issues with illustrative tabular and LLM experiments, and show the correct way to implement the KL gradient.
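The first pitfall can be checked on a single-step tabular problem. The sketch below is an illustration under stated assumptions, not the paper's code: a categorical policy pi_theta = softmax(theta) over three actions, a fixed reference policy ref, and two common per-sample KL estimates, k1 = log pi(a) - log ref(a) and k3 = r - 1 - log r with r = ref(a) / pi(a). The names k1 and k3, and all function names, are hypothetical labels chosen for this example. Each candidate gradient's expectation over a ~ pi_theta is computed exactly by enumerating actions and compared with a finite-difference gradient of the exact reverse KL(pi_theta || ref). Only the non-sequential case is covered.

```python
import numpy as np

# Assumed toy setup: one decision, 3 actions, softmax policy, fixed reference.

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def kl(theta, ref):
    """Exact reverse KL(pi_theta || ref) for a categorical policy."""
    pi = softmax(theta)
    return float(np.sum(pi * (np.log(pi) - np.log(ref))))

def grad_log_pi(theta, a):
    """Gradient of log softmax(theta)[a] w.r.t. theta: one_hot(a) - pi."""
    g = -softmax(theta)
    g[a] += 1.0
    return g

def expected_grad(theta, per_sample_grad):
    """E_{a ~ pi_theta}[per_sample_grad(a)], by exact enumeration of actions."""
    pi = softmax(theta)
    return sum(pi[a] * per_sample_grad(a) for a in range(len(pi)))

def k1_loss_grad(theta, ref):
    # Pitfall: differentiate k1 = log pi(a) - log ref(a) directly as a loss.
    # Only log pi(a) depends on theta, so the per-sample gradient is grad log pi(a).
    return lambda a: grad_log_pi(theta, a)

def k3_loss_grad(theta, ref):
    # Pitfall: differentiate k3 = r - 1 - log r, r = ref(a) / pi(a), as a loss.
    # d r = -r * grad log pi(a) and d(-log r) = +grad log pi(a).
    pi = softmax(theta)
    return lambda a: (1.0 - ref[a] / pi[a]) * grad_log_pi(theta, a)

def score_function_grad(theta, ref):
    # Score-function form: treat k1 as a fixed (stop-gradient) reward
    # and multiply it by grad log pi(a).
    pi = softmax(theta)
    return lambda a: (np.log(pi[a]) - np.log(ref[a])) * grad_log_pi(theta, a)

def finite_diff(f, theta, eps=1e-6):
    """Central finite-difference gradient of a scalar function f."""
    g = np.zeros_like(theta)
    for i in range(len(theta)):
        d = np.zeros_like(theta)
        d[i] = eps
        g[i] = (f(theta + d) - f(theta - d)) / (2 * eps)
    return g

theta = np.array([0.5, -1.0, 2.0])
ref = np.array([0.2, 0.3, 0.5])

true_grad = finite_diff(lambda t: kl(t, ref), theta)
g_k1 = expected_grad(theta, k1_loss_grad(theta, ref))
g_k3 = expected_grad(theta, k3_loss_grad(theta, ref))
g_sf = expected_grad(theta, score_function_grad(theta, ref))

print("true reverse-KL grad:", true_grad)
print("d(k1) as loss       :", g_k1)  # should be ~0
print("d(k3) as loss       :", g_k3)  # should equal pi - ref
print("score-function      :", g_sf)  # should match the true gradient
```

Under these assumptions, differentiating k1 gives a zero gradient in expectation, and differentiating k3 gives pi_theta - ref, which is the gradient of the forward KL(ref || pi_theta) rather than of the reverse KL being targeted. Only the score-function form, with k1 held fixed and multiplied by grad log pi_theta(a), matches the true gradient.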
