How Transformers Learn In-Context Recall Tasks? Optimality, Training Dynamics and Generalization

Abstract

We study the approximation capabilities, convergence speeds, and on-convergence behaviors of transformers trained on in-context recall tasks, which require recognizing the positional association between a pair of tokens from in-context examples. Existing theoretical results focus only on the in-context reasoning behavior of transformers after a single gradient descent step. It remains unclear how transformers trained by gradient descent behave at convergence and how fast they converge. In addition, the generalization of transformers in one-step in-context reasoning has not been formally investigated. This work addresses these gaps. We first show that a class of transformers with linear, ReLU, or softmax attention is provably Bayes-optimal for an in-context recall task. When these transformers are trained with gradient descent, we show via a finite-sample analysis that the expected loss converges at a linear rate to the Bayes risk. Moreover, we show that the trained transformers exhibit out-of-distribution (OOD) generalization, i.e., they generalize to samples outside the population distribution. Our theoretical findings are further supported by extensive empirical validation, which also shows that, without proper parameterization, models with larger expressive power surprisingly fail to generalize OOD after being trained by gradient descent.
