Gradient Descent for Low-Rank Functions

Abstract

Several recent empirical studies demonstrate that important machine learning tasks, e.g., training deep neural networks, exhibit low-rank structure, where the loss function varies significantly in only a few directions of the input space. In this paper, we leverage such low-rank structure to reduce the high computational cost of canonical gradient-based methods such as gradient descent (GD). Our proposed Low-Rank Gradient Descent (LRGD) algorithm finds an ε-approximate stationary point of a p-dimensional function by first identifying r ≤ p significant directions, and then estimating the true p-dimensional gradient at every iteration by computing directional derivatives only along those r directions. We establish that the "directional oracle complexities" of LRGD for strongly convex and non-convex objective functions are O(r log(1/ε) + rp) and O(r/ε² + rp), respectively. When r ≪ p, these complexities are smaller than the known complexities of O(p log(1/ε)) and O(p/ε²) of GD in the strongly convex and non-convex settings, respectively. Thus, LRGD significantly reduces the computational cost of gradient-based methods for sufficiently low-rank functions. In the course of our analysis, we also formally define and characterize the classes of exact and approximately low-rank functions.
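The two phases described in the abstract can be illustrated with a minimal sketch. This is not the authors' implementation. It assumes an exactly low-rank toy objective, uses central finite differences as a stand-in for the directional-derivative oracle, and uses Gram-Schmidt on a few sampled full gradients as a simple stand-in for the direction-identification step. The names `loss`, `directional_derivative`, `find_directions`, and `lrgd`, as well as the sample points, step size, and iteration count, are all hypothetical choices made for illustration.

```python
import math

def directional_derivative(f, x, u, h=1e-6):
    """Central finite-difference estimate of the derivative of f at x along u
    (an assumed stand-in for one directional-oracle call)."""
    xp = [xi + h * ui for xi, ui in zip(x, u)]
    xm = [xi - h * ui for xi, ui in zip(x, u)]
    return (f(xp) - f(xm)) / (2 * h)

def full_gradient(f, x):
    """Full gradient: p directional derivatives along the coordinate axes."""
    p = len(x)
    grad = []
    for i in range(p):
        e = [0.0] * p
        e[i] = 1.0
        grad.append(directional_derivative(f, x, e))
    return grad

def find_directions(f, points, tol=1e-6):
    """Phase 1 (assumed variant): orthonormalize full gradients taken at a few
    sample points with Gram-Schmidt, keeping only non-negligible residuals."""
    basis = []
    for x in points:
        g = full_gradient(f, x)
        for u in basis:
            c = sum(gi * ui for gi, ui in zip(g, u))
            g = [gi - c * ui for gi, ui in zip(g, u)]
        norm = math.sqrt(sum(gi * gi for gi in g))
        if norm > tol:
            basis.append([gi / norm for gi in g])
    return basis

def lrgd(f, x0, basis, step=0.1, iters=200):
    """Phase 2: each iteration uses only len(basis) directional derivatives
    and steps along the gradient estimate sum_j c_j * u_j."""
    x = list(x0)
    for _ in range(iters):
        coeffs = [directional_derivative(f, x, u) for u in basis]
        for c, u in zip(coeffs, basis):
            x = [xi - step * c * ui for xi, ui in zip(x, u)]
    return x

def loss(x):
    """Toy objective with p = 6 that varies only along two directions (r = 2)."""
    return (x[0] + x[1] - 1.0) ** 2 + (x[2] - 2.0 * x[3]) ** 2

# Hypothetical sample points for phase 1; the third adds no new direction.
points = [
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 0.0, 1.0, 0.0, 0.0],
]
basis = find_directions(loss, points)
x_star = lrgd(loss, [0.5, -0.3, 0.2, 0.1, 0.7, -0.4], basis)
print(len(basis), loss(x_star))
```

In this sketch the one-time cost of phase 1 is one full gradient (p directional derivatives) per sample point, which mirrors the rp term in the stated complexities, while each later iteration costs only r directional derivatives instead of p.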
