Fast Attention Requires Bounded Entries
Abstract
In modern machine learning, inner product attention computation is a fundamental task for training large language models such as Transformer, GPT-1, BERT, GPT-2, GPT-3 and ChatGPT. Formally, in this problem, one is given as input three matrices $Q, K, V \in [-B, B]^{n \times d}$, and the goal is to construct the matrix $\mathrm{Att}(Q,K,V) := \mathrm{diag}(A \mathbf{1}_n)^{-1} A V \in \mathbb{R}^{n \times d}$, where $A = \exp(Q K^\top / d)$ is the "attention matrix", and $\exp$ is applied entry-wise. Straightforward methods for this problem explicitly compute the $n \times n$ attention matrix $A$, and hence require time $\Omega(n^2)$ even when $d = n^{o(1)}$ is small. In this paper, we investigate whether faster algorithms are possible by implicitly making use of the matrix $A$. We present two results, showing that there is a sharp transition at $B = \Theta(\sqrt{\log n})$. If $d = O(\log n)$ and $B = o(\sqrt{\log n})$, there is an $n^{1+o(1)}$ time algorithm to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error. If $d = O(\log n)$ and $B = \Theta(\sqrt{\log n})$, assuming the Strong Exponential Time Hypothesis from fine-grained complexity theory, it is impossible to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error in truly subquadratic time $n^{2 - \Omega(1)}$. This gives a theoretical explanation for the phenomenon observed in practice that attention computation is much more efficient when the input matrices have smaller entries.
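The definition of $\mathrm{Att}(Q,K,V)$ can be made concrete with a small pure-Python sketch. `naive_attention` is the straightforward method: it materializes the full $n \times n$ matrix $A$. `lowrank_attention` and `feature_map` are hypothetical helpers (not from the paper) illustrating one way to use $A$ only implicitly when entries are small: replace $\exp(x)$ by its degree-$g$ Taylor polynomial, which factors as $\tilde{A} = U_1 U_2^\top$ through an explicit feature map, then compute $\tilde{A} V = U_1 (U_2^\top V)$ and $\tilde{A}\mathbf{1}_n = U_1 (U_2^\top \mathbf{1}_n)$ without forming any $n \times n$ matrix. This is an assumption-labeled illustration in the spirit of a low-rank/polynomial approximation, not the authors' algorithm; its feature dimension grows like $d^g$, and it makes no attempt to reach the $n^{1+o(1)}$ bound stated above.

```python
import math
import random
from itertools import product


def naive_attention(Q, K, V):
    """Att(Q,K,V) = diag(A 1_n)^{-1} A V with A = exp(Q K^T / d) entry-wise.

    Builds the full n x n matrix A explicitly, so it costs Theta(n^2 d) time.
    """
    d = len(Q[0])
    A = [[math.exp(sum(qi * ki for qi, ki in zip(q, k)) / d) for k in K] for q in Q]
    out = []
    for row in A:
        row_sum = sum(row)  # diagonal entry of diag(A 1_n)
        out.append([sum(a * v[c] for a, v in zip(row, V)) / row_sum
                    for c in range(len(V[0]))])
    return out


def feature_map(u, degree):
    """Hypothetical map phi with <phi(q), phi(k)> = sum_{t<=degree} (<q,k>/d)^t / t!."""
    d = len(u)
    feats = []
    for t in range(degree + 1):
        scale = 1.0 / math.sqrt(math.factorial(t) * d ** t)
        # Entries of the t-fold tensor power of u, scaled so the inner product
        # reproduces the t-th Taylor term of exp(<q,k>/d).
        for idx in product(range(d), repeat=t):
            feats.append(scale * math.prod(u[i] for i in idx))
    return feats


def lowrank_attention(Q, K, V, degree):
    """Approximate Att(Q,K,V) via A ~ U1 U2^T without forming any n x n matrix."""
    U1 = [feature_map(q, degree) for q in Q]
    U2 = [feature_map(k, degree) for k in K]
    r, dv = len(U1[0]), len(V[0])
    # U2^T V (r x dv) and U2^T 1_n (length r): O(n r dv) work in total.
    W = [[sum(u2[a] * v[c] for u2, v in zip(U2, V)) for c in range(dv)] for a in range(r)]
    w = [sum(u2[a] for u2 in U2) for a in range(r)]
    out = []
    for u1 in U1:
        denom = sum(x * y for x, y in zip(u1, w))  # row sum of the approximate A
        out.append([sum(u1[a] * W[a][c] for a in range(r)) / denom for c in range(dv)])
    return out


random.seed(0)
n, d, B = 50, 3, 0.5


def rand_matrix():
    """Random n x d matrix with entries in [-B, B]."""
    return [[random.uniform(-B, B) for _ in range(d)] for _ in range(n)]


Q, K, V = rand_matrix(), rand_matrix(), rand_matrix()
exact = naive_attention(Q, K, V)
approx = lowrank_attention(Q, K, V, degree=4)
err = max(abs(x - y) for ex, ap in zip(exact, approx) for x, y in zip(ex, ap))
print(f"max additive error: {err:.1e}")
```

Because every entry lies in $[-B, B]$, each exponent satisfies $|\langle q, k\rangle / d| \le B^2$, so the Taylor remainder shrinks like $B^{2(g+1)}/(g+1)!$. In this sketch, smaller entries let a lower degree (and hence a smaller rank) reach a given additive error, which loosely mirrors the abstract's point that bounded entries are what make fast attention possible.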