
Conventional statistical inference methods are typically developed for models simple enough to admit tractable estimators through carefully designed iterative algorithms. In contrast, modern deep learning models are enormously complex, yet are trained by simple gradient-descent-type algorithms, often without any provable guarantee of convergence to global or local optima.


Can we reconcile classical inference principles with these highly complicated, modern learning paradigms? In this talk, we will present a new inference framework addressing this question by showing that valid statistical inference can be performed along the entire gradient descent trajectory, iteration by iteration, without requiring convexity of the loss landscape or convergence of the algorithm.


To illustrate this concept, we begin with a single-index (one-layer neural network) regression model and demonstrate how gradient descent iterates can be "debiased" at each iteration to yield valid confidence intervals for the underlying signal and consistent estimates of generalization errors. We then extend this paradigm to the much more challenging setting of learning with general multi-layer neural networks, where the loss landscape can be arbitrarily complex. Crucially, the proposed method remains valid without requiring either algorithmic convergence or oracle knowledge of the unknowns, and may therefore inform practical decisions such as early stopping and hyperparameter tuning.
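To fix ideas, the following is a minimal sketch of the single-index setup only: data generated as y = φ(Xθ*) + noise and gradient descent run on the squared loss, tracking the iterate trajectory. The link function φ = tanh, the sample sizes, and the step size are all hypothetical choices for illustration; the debiasing correction described in the talk, which turns each iterate into valid confidence intervals, relies on the dynamical mean-field theory machinery and is not reproduced here.

```python
# Sketch of gradient descent on a single-index regression model
# y = tanh(X @ theta_star) + noise. All specifics (tanh link, n, d,
# learning rate) are illustrative assumptions, not the talk's method.
import numpy as np

rng = np.random.default_rng(0)
n, d = 2000, 50
theta_star = rng.normal(size=d) / np.sqrt(d)  # unit-scale signal
X = rng.normal(size=(n, d))
y = np.tanh(X @ theta_star) + 0.1 * rng.normal(size=n)

def grad(theta):
    # Gradient of the squared loss (1/2n) * ||y - tanh(X @ theta)||^2
    u = X @ theta
    r = np.tanh(u) - y
    return X.T @ (r * (1.0 - np.tanh(u) ** 2)) / n

theta = np.zeros(d)
lr = 0.5
traj = []
for t in range(200):
    theta = theta - lr * grad(theta)
    traj.append(theta.copy())

# Plain (uncorrected) estimation error along the trajectory; the talk's
# framework would instead debias each iterate before forming intervals.
err = [float(np.linalg.norm(th - theta_star)) for th in traj]
```

The trajectory `traj` is the object of interest: the talk's framework performs inference at every iterate along it, rather than only at a converged limit.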


The key technical ingredient underlying this new inference paradigm is a recent entrywise dynamical mean-field theory for a broad class of first-order algorithms developed by the speaker.