Dispatch-Aware Ragged Attention for Pruned Vision Transformers
Abstract
Token pruning methods for Vision Transformers (ViTs) promise quadratic reductions in attention FLOPs by dropping uninformative patches. Yet standard variable-length attention APIs -- including FlashAttention-2's varlen and PyTorch's NestedTensor SDPA -- fail to translate these savings into proportional wall-clock gains at the short post-pruning sequence lengths typical of ViTs (≤197 tokens). We identify a dispatch-overhead bottleneck: at these lengths, host-side kernel dispatch consumes 50 μs regardless of workload, exceeding the actual GPU compute time at moderate-to-high pruning rates. We present a lightweight bidirectional Triton attention kernel whose dispatch floor is 24 μs -- roughly 2.17× lower than FlashAttention-2 varlen -- allowing pruning savings to become visible in wall-clock time. Integrated into a complete pack-attend-unpack pipeline and evaluated on an NVIDIA RTX 4000 Ada Generation GPU, our system achieves 1.88× end-to-end throughput over padded PyTorch SDPA at standard 224×224 inputs, scaling to 2.51× at 384×384. Against FlashAttention-2 varlen -- the strongest baseline -- our kernel delivers 9-12% higher throughput at serving batch sizes (BS=1-4), and 2.17× lower kernel latency at 80% token pruning. Numerical correctness is verified with max absolute logit difference <0.004 and bit-exact top-1 predictions.
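To make the pack and unpack stages concrete, here is a minimal pure-Python sketch of ragged packing with cumulative offsets, the layout convention that variable-length attention interfaces such as FlashAttention's varlen path take as `cu_seqlens`. It is an illustration under stated assumptions, not the paper's implementation: the function names `pack`, `unpack`, and `attention_flops_ratio` are hypothetical, tokens are plain list items rather than tensors, and the attention step itself is omitted.

```python
from itertools import accumulate


def pack(sequences):
    """Concatenate variable-length token lists and record cumulative offsets.

    Hypothetical helper: returns the flat token list plus a cu_seqlens-style
    offset list of length len(sequences) + 1.
    """
    lengths = [len(seq) for seq in sequences]
    cu_seqlens = [0] + list(accumulate(lengths))
    packed = [tok for seq in sequences for tok in seq]
    return packed, cu_seqlens


def unpack(packed, cu_seqlens, pad_to, pad_value=None):
    """Scatter packed tokens back into a padded [batch][pad_to] layout."""
    out = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        seq = packed[start:end]
        out.append(seq + [pad_value] * (pad_to - len(seq)))
    return out


def attention_flops_ratio(kept, total):
    """Attention score FLOPs scale with the square of the sequence length."""
    return (kept / total) ** 2


# Three images pruned to different token counts (toy values, not paper data).
batch = [["a0", "a1", "a2"], ["b0"], ["c0", "c1"]]
packed, cu_seqlens = pack(batch)
restored = unpack(packed, cu_seqlens, pad_to=3)
```

Here `cu_seqlens` comes out as `[0, 3, 4, 6]`, so sequence `i` occupies `packed[cu_seqlens[i]:cu_seqlens[i + 1]]` and no padded positions take part in attention. `attention_flops_ratio(1, 2)` gives `0.25`, the quadratic reduction the abstract refers to: halving the token count quarters the attention FLOPs, a saving that only shows up in wall-clock time once the fixed dispatch cost is smaller than the remaining compute.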