Stopping Computation for Converged Tokens in Masked Diffusion-LM Decoding

Abstract

Masked Diffusion Language Models generate sequences by iterative sampling that progressively unmasks tokens. However, they still recompute the attention and feed-forward blocks for every token position at every step, even when many unmasked tokens are essentially fixed, which wastes substantial compute. We propose SureLock: when the posterior at an unmasked position has stabilized across steps (our sure condition), we lock that position, thereafter skipping its query projection and feed-forward sublayers while caching its attention keys and values so that other positions can continue to attend to it. This reduces the dominant per-iteration computational cost from O(N^2 d) to O(MNd), where N is the sequence length, M is the number of unlocked token positions, and d is the model dimension. In practice, M decreases as the iterations progress, yielding substantial savings. On LLaDA-8B, SureLock reduces algorithmic FLOPs by 30 to 50% relative to the same sampler without locking, while maintaining comparable generation quality. We also provide a theoretical analysis that justifies the design of SureLock: monitoring only the local KL at the lock step suffices to bound the deviation in final token probabilities. Our project page is available at https://daioba.github.io/surelock.
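The locking rule described above can be illustrated with a minimal sketch. It is not the authors' implementation: the abstract does not give the exact form of the sure condition, so the KL direction (current posterior against the previous step's), the threshold `kl_threshold`, and the function names `kl_divergence`, `update_locks`, and `attention_cost_ratio` are all assumptions made for illustration.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(
        pi * (math.log(max(pi, eps)) - math.log(max(qi, eps)))
        for pi, qi in zip(p, q)
    )


def update_locks(prev_probs, curr_probs, unmasked, locked, kl_threshold=1e-3):
    """Return the new per-position lock flags after one sampling step.

    Assumed rule (hypothetical): a position locks once it is unmasked and its
    posterior moved by less than kl_threshold between consecutive steps.
    Locked positions stay locked.
    """
    new_locked = []
    for prev, curr, is_unmasked, is_locked in zip(
        prev_probs, curr_probs, unmasked, locked
    ):
        if is_locked:
            new_locked.append(True)
            continue
        stable = kl_divergence(curr, prev) < kl_threshold
        new_locked.append(is_unmasked and stable)
    return new_locked


def attention_cost_ratio(n_total, n_unlocked):
    """Ratio of O(M N d) to O(N^2 d), which simplifies to M / N."""
    return n_unlocked / n_total


# Three positions over a toy vocabulary of size 3.
prev_probs = [[0.7, 0.2, 0.1], [0.98, 0.01, 0.01], [0.5, 0.3, 0.2]]
curr_probs = [[0.2, 0.7, 0.1], [0.98, 0.01, 0.01], [0.5, 0.3, 0.2]]
unmasked = [True, True, False]   # position 2 is still masked
locked = [False, False, False]

locked = update_locks(prev_probs, curr_probs, unmasked, locked)
```

In this toy input, position 0 is unmasked but its posterior is still moving, position 1 is unmasked and stable, and position 2 is stable but still masked, so only position 1 would lock under the assumed rule. In a full sampler, locked positions would skip their query projection and feed-forward sublayers while their cached keys and values remain visible to the M unlocked positions, which is where the M / N cost ratio comes from.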
