Learning quadratic neural networks in high dimensions: SGD dynamics and scaling laws
Abstract
We study the optimization and sample complexity of gradient-based training of a two-layer neural network with quadratic activation function in the high-dimensional regime, where the data is generated as $f_*(x) \propto \sum_{j=1}^{r} \lambda_j \sigma(\langle \theta_j, x \rangle)$, $x \sim \mathcal{N}(0, I_d)$, $\sigma$ is the 2nd Hermite polynomial, and $\{\theta_j\}_{j=1}^{r} \subset \mathbb{R}^d$ are orthonormal signal directions. We consider the extensive-width regime $r \asymp d^{\beta}$ for $\beta \in [0, 1)$, and assume a power-law decay on the (non-negative) second-layer coefficients, $\lambda_j \asymp j^{-\alpha}$ for $\alpha \ge 0$. We present a sharp analysis of the SGD dynamics in the feature learning regime, for both the population limit and the finite-sample (online) discretization, and derive scaling laws for the prediction risk that highlight the power-law dependencies on the optimization time, sample size, and model width. Our analysis combines a precise characterization of the associated matrix Riccati differential equation with novel matrix monotonicity arguments to establish convergence guarantees for the infinite-dimensional effective dynamics.
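The data model in the abstract can be simulated directly. The NumPy sketch below is illustrative only and rests on assumptions the abstract does not fix: $\sigma(z) = z^2 - 1$ (the unnormalized probabilists' Hermite polynomial; the paper may normalize differently), a proportionality constant of 1, and exactly $r = \lfloor d^{\beta} \rfloor$ and $\lambda_j = j^{-\alpha}$. The student, a width-$m$ quadratic network $f_W(x) = \sum_k (\langle w_k, x \rangle^2 - \|w_k\|^2)$ trained by online SGD (one fresh sample per step) on the squared loss, is a hypothetical parameterization and not necessarily the one analyzed in the paper. All function names are hypothetical.

```python
import numpy as np


def he2(z):
    # Assumed sigma: probabilists' 2nd Hermite polynomial He_2(z) = z^2 - 1 (unnormalized).
    return z**2 - 1.0


def make_teacher(d, beta, alpha, rng):
    # Hypothetical constants: r = floor(d**beta) (at least 1) and lambda_j = j**(-alpha) exactly.
    r = max(1, int(np.floor(d**beta)))
    # Reduced QR of a Gaussian d x r matrix gives orthonormal columns theta_1, ..., theta_r.
    theta, _ = np.linalg.qr(rng.standard_normal((d, r)))
    lam = np.arange(1, r + 1, dtype=float) ** (-alpha)
    return theta, lam


def teacher(x, theta, lam):
    # f_*(x) = sum_j lambda_j He_2(<theta_j, x>), proportionality constant set to 1.
    return he2(x @ theta) @ lam


def teacher_matrix(theta, lam):
    # A = sum_j lambda_j theta_j theta_j^T; since the theta_j are unit vectors, f_*(x) = x^T A x - tr(A).
    return (theta * lam) @ theta.T


def student(x, W):
    # Hypothetical student: sum_k ((w_k . x)^2 - |w_k|^2) = x^T W^T W x - tr(W^T W).
    return ((x @ W.T) ** 2).sum(axis=-1) - (W**2).sum()


def population_risk(W, theta, lam):
    # For x ~ N(0, I_d), E[(x^T M x - tr M)^2] = 2 |M|_F^2,
    # so 0.5 * E[(f_W - f_*)^2] = |W^T W - A|_F^2.
    M = W.T @ W - teacher_matrix(theta, lam)
    return float(np.sum(M**2))


def online_sgd(theta, lam, m, steps, lr, rng):
    # Online SGD on 0.5 * (f_W(x) - f_*(x))^2 with a fresh Gaussian sample at every step.
    d = theta.shape[0]
    W = rng.standard_normal((m, d)) / d  # small random initialization (assumption)
    for _ in range(steps):
        x = rng.standard_normal(d)
        resid = student(x, W) - teacher(x, theta, lam)
        # The gradient of f_W at x with respect to W is 2 (W x x^T - W).
        W -= lr * 2.0 * resid * (np.outer(W @ x, x) - W)
    return W
```

Usage: `theta, lam = make_teacher(64, 0.5, 1.0, np.random.default_rng(0))`, then `W = online_sgd(theta, lam, m=16, steps=20000, lr=2e-4, rng=np.random.default_rng(1))` and `population_risk(W, theta, lam)`. Under this toy parameterization both teacher and student have the form $x^\top M x - \mathrm{tr}(M)$, and the expected SGD step equals the population-risk gradient $4W(W^\top W - A)$. Gradient flow then gives $\dot Q = 4(AQ + QA) - 8Q^2$ for $Q = W^\top W$, a matrix Riccati equation; the abstract does not say whether this coincides exactly with the paper's effective dynamics.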