Parallel Sampling via Counting

Abstract

We show how to use parallelization to speed up sampling from an arbitrary distribution $\mu$ on a product space $[q]^n$, given oracle access to counting queries: $\mathbb{P}_{X\sim\mu}[X_S=\sigma_S]$ for any $S\subseteq [n]$ and $\sigma_S \in [q]^S$. Our algorithm takes $O(n^{2/3}\cdot \operatorname{polylog}(n,q))$ parallel time, which is, to the best of our knowledge, the first runtime sublinear in $n$ for arbitrary distributions. Our results have implications for sampling in autoregressive models. Our algorithm works directly with an equivalent oracle that answers conditional marginal queries $\mathbb{P}_{X\sim\mu}[X_i=\sigma_i \mid X_S=\sigma_S]$, a role played by a trained neural network in autoregressive models. This suggests that a roughly $n^{1/3}$-factor speedup is possible for sampling in any-order autoregressive models. We complement our positive result by showing a lower bound of $\widetilde{\Omega}(n^{1/3})$ for the runtime of any parallel sampling algorithm making at most $\operatorname{poly}(n)$ queries to the counting oracle, even for $q=2$.
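The two oracles are interchangeable: a conditional marginal is a ratio of two counting queries, and a counting query is a product of conditional marginals by the chain rule. The following is a minimal Python sketch of that equivalence and of the naive sequential sampler, which fixes one coordinate per adaptive round and therefore needs $n$ rounds. This is the baseline that a sublinear parallel runtime improves on, not the paper's algorithm. It assumes a toy distribution stored explicitly; the names `MU`, `make_counting_oracle`, `conditional_marginal`, `count_from_conditionals`, and `sequential_sample` are hypothetical and chosen for illustration.

```python
import random
from fractions import Fraction

# Hypothetical toy distribution on [q]^n with n = 3, q = 2: uniform over
# even-parity bit strings. Real instances are only reachable via an oracle.
N, Q = 3, 2
MU = {x: Fraction(1, 4) for x in [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)]}

def make_counting_oracle(mu):
    """count(partial) = P_{X~mu}[X_S = sigma_S]; partial maps index i -> sigma_i."""
    def count(partial):
        return sum((p for x, p in mu.items()
                    if all(x[i] == a for i, a in partial.items())), Fraction(0))
    return count

def conditional_marginal(count, partial, i, a):
    """P[X_i = a | X_S = sigma_S] from two counting queries (needs P[X_S = sigma_S] > 0)."""
    return count({**partial, i: a}) / count(partial)

def count_from_conditionals(cond, partial):
    """Chain rule in the other direction: P[X_S = sigma_S] as a product of conditionals."""
    prob, prefix = Fraction(1), {}
    for i, a in partial.items():
        prob *= cond(prefix, i, a)
        if prob == 0:
            return prob  # stop before conditioning on a probability-zero event
        prefix[i] = a
    return prob

def sequential_sample(count, n, q, rng):
    """Baseline autoregressive sampler: one coordinate per adaptive round, n rounds total."""
    partial = {}
    for i in range(n):
        # The queries in round i depend on the values fixed in rounds 0..i-1.
        weights = [conditional_marginal(count, partial, i, a) for a in range(q)]
        partial[i] = rng.choices(range(q), weights=weights)[0]
    return tuple(partial[i] for i in range(n))

count = make_counting_oracle(MU)
rng = random.Random(0)
print(sequential_sample(count, N, Q, rng))  # always an even-parity string
```

Each `count` call here scans the whole support, so the sketch only illustrates the query interface. What matters is the dependency structure: round $i$ cannot issue its queries until round $i-1$ has fixed its value, which is why this naive approach takes $n$ parallel rounds.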
