Finite-Time Analysis of Gradient Descent for Shallow Transformers
Abstract
Understanding why Transformers perform so well remains challenging due to their non-convex optimization landscape. In this work, we analyze a shallow Transformer with m independent heads trained by projected gradient descent in the kernel regime. Our analysis reveals two main findings: (i) the width required for nonasymptotic guarantees scales only logarithmically with the sample size n, and (ii) the optimization error is independent of the sequence length T. This contrasts sharply with recurrent architectures, where the optimization error can grow exponentially with T. The trade-off is memory: to keep the full context, the Transformer's memory requirement grows with the sequence length. We validate our theoretical results numerically in a teacher-student setting and compare Transformers with recurrent architectures on an autoregressive task.
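For readers unfamiliar with projected gradient descent, the update can be illustrated with a minimal sketch. The abstract does not specify the projection set, so the sketch assumes a Euclidean ball of fixed radius around the initialization, a common choice in kernel-regime analyses. The toy quadratic objective and the names `project_to_ball`, `projected_gradient_descent`, `radius`, and `step_size` are hypothetical and not taken from the paper.

```python
import math

def project_to_ball(w, center, radius):
    """Project w onto the Euclidean ball of the given radius around center."""
    diff = [wi - ci for wi, ci in zip(w, center)]
    norm = math.sqrt(sum(d * d for d in diff))
    if norm <= radius:
        return list(w)  # already feasible, nothing to do
    scale = radius / norm
    return [ci + scale * d for ci, d in zip(center, diff)]

def projected_gradient_descent(grad, w0, radius, step_size, num_steps):
    """Gradient step followed by projection back onto the ball around w0.

    Assumption: the feasible set is a ball around the initialization w0;
    the paper's actual projection set may differ.
    """
    w = list(w0)
    for _ in range(num_steps):
        g = grad(w)
        w = [wi - step_size * gi for wi, gi in zip(w, g)]
        w = project_to_ball(w, w0, radius)
    return w

# Toy objective (hypothetical): f(w) = 0.5 * ||w - target||^2,
# whose unconstrained minimizer lies outside the feasible ball.
target = [3.0, 4.0]
grad = lambda w: [wi - ti for wi, ti in zip(w, target)]

w_final = projected_gradient_descent(
    grad, w0=[0.0, 0.0], radius=1.0, step_size=0.5, num_steps=50
)
print(w_final)
```

In this toy setting the iterates settle on the boundary of the ball, at the feasible point closest to `target`. The projection keeps the parameters close to their initial values, which is the sense in which training is said to stay in the kernel regime.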