Pruning is Optimal for Learning Sparse Features in High-Dimensions
Abstract
While it is commonly observed in practice that pruning networks to a certain level of sparsity can improve the quality of the features, a theoretical explanation of this phenomenon remains elusive. In this work, we investigate this by demonstrating that a broad class of statistical models can be optimally learned in high dimensions using pruned neural networks trained with gradient descent. We consider learning both single-index and multi-index models of the form y = σ*(V^⊤ x) + ε, where σ* is a degree-p polynomial and V ∈ R^{d×r}, with r ≪ d, is the matrix containing the relevant model directions. We assume that V satisfies a certain q-sparsity condition for matrices and show that pruning neural networks in proportion to the sparsity level of V improves their sample complexity compared to unpruned networks. Furthermore, we establish Correlational Statistical Query (CSQ) lower bounds in this setting that take the sparsity level of V into account. We show that if the sparsity level of V exceeds a certain threshold, training pruned networks with a gradient descent algorithm achieves the sample complexity suggested by the CSQ lower bound. In the same scenario, however, our results imply that basis-independent methods, such as models trained via standard gradient descent initialized with rotationally invariant random weights, can provably achieve only suboptimal sample complexity.
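To make the data model above concrete, the following is a minimal sketch of sampling from a sparse multi-index model. Several choices here are assumptions for illustration only and are not taken from the abstract: standard Gaussian inputs and noise, a particular degree-3 polynomial link, and "sparsity" modelled as V having nonzero entries in only its first s rows (the paper's q-sparsity condition is not defined in the abstract and may differ). All function names are hypothetical.

```python
import random


def make_sparse_directions(d, r, s, rng):
    """Build a d x r matrix (list of rows) that is nonzero only in its first s rows.

    Assumption: row-sparsity stands in for the paper's q-sparsity condition.
    """
    V = [[0.0] * r for _ in range(d)]
    for i in range(s):
        for j in range(r):
            V[i][j] = rng.gauss(0.0, 1.0)
    # Rescale each column to unit Euclidean norm.
    for j in range(r):
        norm = sum(V[i][j] ** 2 for i in range(d)) ** 0.5
        for i in range(d):
            V[i][j] /= norm
    return V


def link(z):
    """Hypothetical degree-3 polynomial link acting on the r projections."""
    return z[0] ** 3 - 3.0 * z[0] + z[0] * z[1]


def response(V, x):
    """Noise-free response: project x onto the columns of V, then apply the link."""
    d, r = len(V), len(V[0])
    z = [sum(V[i][j] * x[i] for i in range(d)) for j in range(r)]
    return link(z)


def sample(V, n, noise_std, rng):
    """Draw n pairs (x, y) with Gaussian x and additive Gaussian noise (both assumed)."""
    d = len(V)
    X, y = [], []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        X.append(x)
        y.append(response(V, x) + rng.gauss(0.0, noise_std))
    return X, y


rng = random.Random(0)
d, r, s = 20, 2, 4  # ambient dimension, number of directions, support size
V = make_sparse_directions(d, r, s, rng)
X, y = sample(V, 100, 0.1, rng)
```

In this sketch the response depends on only s of the d input coordinates. This is the kind of structure that a method tied to the coordinate basis can exploit, whereas a rotationally invariant method treats all bases alike.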