A Faster Generalized Two-Stage Approximate Top-K
Abstract
We consider the Top-K selection problem, which aims to identify the largest K elements in an array. Top-K selection arises in many machine learning algorithms and often becomes a bottleneck on accelerators, which are optimized for dense matrix multiplications. To address this problem, Chern et al. (2022) proposed a fast two-stage approximate Top-K algorithm that: (i) partitions the input array into equal-sized chunks and selects the top-1 element from each partition; and (ii) sorts the resulting smaller subset and returns the top K elements. In this paper, we generalize the first stage so that each partition selects the top K' elements (for 1 ≤ K' ≤ K). Our contributions include: (i) an expression for the expected recall of this generalized algorithm under random partitioning, and a demonstration that choosing K' > 1 with fewer partitions in the first stage more effectively reduces the input size to the second stage while maintaining the same expected recall as the original algorithm; (ii) a bound on the expected recall of the original algorithm as a function of the algorithm parameters that is provably tighter by a factor of 2 than the bound reported by Chern et al. (2022); and (iii) an implementation of our algorithm on Cloud TPUv5e that achieves approximately an order of magnitude speedup over the original algorithm without sacrificing recall.
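The two stages, and the recall quantity that the analysis concerns, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' TPU implementation. The function names (`two_stage_top_k`, `recall`, `expected_recall`, `simulated_recall`) are hypothetical. Stage 1 uses contiguous chunks of an input assumed to be in random order, and values are assumed distinct. `expected_recall` is a hypergeometric expression derived here under those assumptions and may be written differently from the paper's own expression.

```python
import math
import random
from typing import List, Sequence


def two_stage_top_k(x: Sequence[float], k: int, num_buckets: int, k_prime: int) -> List[int]:
    """Approximate top-k indices of x via a generalized two-stage scheme (sketch).

    Stage 1: split x into num_buckets contiguous, equal-sized chunks (an assumption)
    and keep the k_prime largest indices of each chunk; k_prime = 1 is the original scheme.
    Stage 2: sort the surviving candidates and return the k largest.
    """
    n = len(x)
    if n % num_buckets:
        raise ValueError("len(x) must be divisible by num_buckets")
    if not 1 <= k_prime <= k or num_buckets * k_prime < k:
        raise ValueError("need 1 <= k_prime <= k and num_buckets * k_prime >= k")
    size = n // num_buckets
    candidates: List[int] = []
    for b in range(num_buckets):
        chunk = range(b * size, (b + 1) * size)
        # Stage 1: per-chunk top-k_prime (a full sort, used here only for clarity).
        candidates += sorted(chunk, key=lambda i: x[i], reverse=True)[:k_prime]
    # Stage 2: exact top-k over the reduced candidate set.
    return sorted(candidates, key=lambda i: x[i], reverse=True)[:k]


def recall(approx: Sequence[int], x: Sequence[float], k: int) -> float:
    """Fraction of the exact top-k indices that were recovered (assumes distinct values)."""
    exact = set(sorted(range(len(x)), key=lambda i: x[i], reverse=True)[:k])
    return len(exact.intersection(approx)) / k


def expected_recall(n: int, k: int, num_buckets: int, k_prime: int) -> float:
    """Expected recall under a uniformly random partition into equal-sized buckets.

    The r-th largest element survives stage 1 iff fewer than k_prime of the r - 1
    larger elements share its bucket. That count is hypergeometric: its m - 1
    bucket-mates are drawn from the n - 1 other elements, r - 1 of which are larger.
    A top-k element that survives stage 1 is always returned by stage 2, because it
    outranks every candidate outside the true top-k.
    """
    m = n // num_buckets  # bucket size
    denom = math.comb(n - 1, m - 1)
    total = 0.0
    for r in range(1, k + 1):
        # Sum P(X = j) for j = 0 .. k_prime - 1 (and j within the support).
        hits = sum(math.comb(r - 1, j) * math.comb(n - r, m - 1 - j)
                   for j in range(min(k_prime, r, m)))
        total += hits / denom
    return total / k


def simulated_recall(n: int, k: int, num_buckets: int, k_prime: int,
                     trials: int = 1000, seed: int = 0) -> float:
    """Monte Carlo estimate: i.i.d. random values make contiguous chunks a random partition."""
    rng = random.Random(seed)
    acc = 0.0
    for _ in range(trials):
        x = [rng.random() for _ in range(n)]
        acc += recall(two_stage_top_k(x, k, num_buckets, k_prime), x, k)
    return acc / trials


if __name__ == "__main__":
    n, k = 4096, 64
    # Hypothetical configurations: original (k_prime = 1) vs. fewer buckets with k_prime = 2.
    for num_buckets, k_prime in [(512, 1), (128, 2)]:
        print(f"buckets={num_buckets} k'={k_prime} stage-2 input={num_buckets * k_prime} "
              f"E[recall]={expected_recall(n, k, num_buckets, k_prime):.4f}")
```

Comparing `expected_recall` across configurations with the same or similar recall but different stage-2 input sizes (`num_buckets * k_prime`) is one way to explore the trade-off described in contribution (i). The Monte Carlo helper is only a sanity check on the derived expression, not a benchmark.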