Fast and Easy Infinite Neural Networks in Python
Neural Tangents
ICLR 2020 Video | Paper | Quickstart | Install guide | Reference docs | Release notes
Overview
Neural Tangents is a high-level neural network API for specifying complex, hierarchical neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones.
Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture. See References for details and nuances of this correspondence. Also see this listing of papers written by the creators of Neural Tangents which study the infinite width limit of neural networks.
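To make this correspondence concrete, here is a small NumPy sketch that does not use Neural Tangents. It makes the following assumptions:
- a one-hidden-layer ReLU network without biases;
- first-layer weights drawn from N(0, w_std²/d);
- readout weights drawn from N(0, v_std²/width).

The sketch compares the closed-form NNGP kernel of that network (the order-1 arc-cosine kernel) with a Monte Carlo estimate from a single wide random hidden layer. The helper names relu_nngp and relu_nngp_mc are hypothetical and are not part of the library.

```python
import numpy as np


def relu_nngp(x1, x2, w_std=1.0, v_std=1.0):
    """Closed-form NNGP kernel of Dense -> ReLU -> Dense (no biases)."""
    d = x1.shape[-1]
    # Covariances of the first-layer pre-activations u = w . x, w ~ N(0, w_std**2 / d).
    k11 = w_std**2 * np.dot(x1, x1) / d
    k22 = w_std**2 * np.dot(x2, x2) / d
    k12 = w_std**2 * np.dot(x1, x2) / d
    norm = np.sqrt(k11 * k22)
    cos_t = np.clip(k12 / norm, -1.0, 1.0)
    theta = np.arccos(cos_t)
    # E[relu(u) relu(u')] for jointly Gaussian (u, u'): order-1 arc-cosine kernel.
    e_relu = norm * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi)
    return v_std**2 * e_relu


def relu_nngp_mc(x1, x2, width, seed=0, w_std=1.0, v_std=1.0):
    """Monte Carlo estimate of the same kernel from one random hidden layer."""
    rng = np.random.default_rng(seed)
    d = x1.shape[-1]
    W = rng.normal(0.0, w_std / np.sqrt(d), size=(width, d))
    h1 = np.maximum(W @ x1, 0.0)  # hidden post-activations for x1
    h2 = np.maximum(W @ x2, 0.0)  # hidden post-activations for x2
    # Output covariance over readout weights v ~ N(0, v_std**2 / width), given W.
    return v_std**2 * np.mean(h1 * h2)


x1 = np.array([1.0, 0.5, -0.3])
x2 = np.array([0.2, -1.0, 0.8])
exact = relu_nngp(x1, x2)
estimate = relu_nngp_mc(x1, x2, width=200_000)  # approaches `exact` as width grows
```

The Monte Carlo error shrinks roughly like 1/sqrt(width).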
Neural Tangents allows you to construct a neural network model from common building blocks like convolutions, pooling, residual connections, nonlinearities, and more, and obtain not only the finite model, but also the kernel function of the respective GP.
The library is written in Python using JAX and leveraging XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.
Neural Tangents is a work in progress. We happily welcome contributions!
Contents
- Colab Notebooks
- Installation
- 5-Minute intro
- Package description
- Technical gotchas
- Training dynamics of wide but finite networks
- Performance
- Papers
- Citation
- References
Colab Notebooks
An easy way to get started with Neural Tangents is by playing around with the following interactive notebooks in Colaboratory. They demo the major features of Neural Tangents and show how it can be used in research.
- Neural Tangents Cookbook
- Weight Space Linearization
- Function Space Linearization
- Neural Network Phase Diagram
- Performance Benchmark: Simple benchmark for Myrtle kernels used in [16]. Also see Performance.
Installation
To use GPU, first follow JAX's GPU installation instructions. Otherwise, install JAX on CPU by running
pip install jax jaxlib --upgrade
Once JAX is installed, install Neural Tangents by running
pip install neural-tangents
or, to use the bleeding-edge version from GitHub source,
git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .
You can now run the examples (using tensorflow_datasets) and tests by calling:
pip install tensorflow tensorflow-datasets more-itertools --upgrade
python examples/infinite_fcn.py
python examples/weight_space.py
python examples/function_space.py
set -e; for f in tests/*.py; do python $f; done
5-Minute intro
See this Colab for a detailed tutorial. Below is a very quick introduction.
Our library closely follows JAX's API for specifying neural networks, stax. In stax, a network is defined by a pair of functions (init_fn, apply_fn), which initialize the trainable parameters and compute the outputs of the network, respectively. Below is an example of defining a 3-layer network and computing its outputs y given inputs x.
from jax import random
from jax.example_libraries import stax
init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)
key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)
y = apply_fn(params, x) # (10, 1) np.ndarray outputs of the neural network
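The (init_fn, apply_fn) convention is what makes layers composable. The pure-NumPy sketch below shows one way a serial combinator can thread shapes and parameters through a list of such pairs. The functions dense, relu and serial here are hypothetical stand-ins written for illustration. They are not the JAX or Neural Tangents implementations, which, for example, take a JAX PRNG key rather than a NumPy generator.

```python
import numpy as np


def dense(out_dim):
    """Hypothetical dense layer as an (init_fn, apply_fn) pair."""
    def init_fn(rng, input_shape):
        in_dim = input_shape[-1]
        w = rng.normal(size=(in_dim, out_dim)) / np.sqrt(in_dim)
        b = np.zeros(out_dim)
        return input_shape[:-1] + (out_dim,), (w, b)

    def apply_fn(params, x):
        w, b = params
        return x @ w + b

    return init_fn, apply_fn


def relu():
    """Parameter-free layer: the shape passes through and params are empty."""
    def init_fn(rng, input_shape):
        return input_shape, ()

    def apply_fn(params, x):
        return np.maximum(x, 0.0)

    return init_fn, apply_fn


def serial(*layers):
    """Chain layers: each init_fn's output shape becomes the next one's input shape."""
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, input_shape):
        params = []
        for layer_init in init_fns:
            input_shape, layer_params = layer_init(rng, input_shape)
            params.append(layer_params)
        return input_shape, params

    def apply_fn(params, x):
        for layer_apply, layer_params in zip(apply_fns, params):
            x = layer_apply(layer_params, x)
        return x

    return init_fn, apply_fn


init_fn, apply_fn = serial(dense(512), relu(), dense(512), relu(), dense(1))
out_shape, params = init_fn(np.random.default_rng(1), (10, 100))
y = apply_fn(params, np.ones((10, 100)))  # y.shape == (10, 1)
```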
Neural Tangents is designed to serve as a drop-in replacement for stax, extending the (init_fn, apply_fn) tuple to a triple (init_fn, apply_fn, kernel_fn), where kernel_fn is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs x1 and x2.
from jax import random
from neural_tangents import stax
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)
key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))
kernel = kernel_fn(x1, x2, 'nngp')
Note that kernel_fn can compute two covariance matrices, corresponding to the Neural Network Gaussian Process (NNGP) kernel and the Neural Tangent Kernel (NTK), respectively. The NNGP kernel corresponds to the Bayesian infinite neural network [1-5]. The NTK corresponds to the infinite network trained with (continuous) gradient descent [10]. In the above example we compute the NNGP kernel, but we could also compute the NTK or both:
# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) np.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) np.ndarray
# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp # True
both.ntk == ntk # True
# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))
Additionally, if no third argument is specified, kernel_fn will return a Kernel namedtuple that contains additional metadata. This can be useful for composing applications of kernel_fn as follows:
kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)
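For intuition about what the 'ntk' kernel measures, the NumPy sketch below works out the analytic NTK of a one-hidden-layer ReLU network, f(x) = v . relu(W x / sqrt(d)) / sqrt(n). It assumes the NTK parameterization with standard normal W and v and no biases. It then compares this with the empirical NTK of a single wide random instance, i.e. the sum over all parameters of products of output gradients. The helper names are hypothetical; the library's kernel_fn covers general architectures, biases and weight variances.

```python
import numpy as np


def relu_ntk(x1, x2):
    """Analytic NTK of f(x) = v . relu(W x / sqrt(d)) / sqrt(n) as n -> infinity."""
    d = x1.shape[-1]
    k11, k22, k12 = x1 @ x1 / d, x2 @ x2 / d, x1 @ x2 / d
    norm = np.sqrt(k11 * k22)
    cos_t = np.clip(k12 / norm, -1.0, 1.0)
    theta = np.arccos(cos_t)
    nngp = norm * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi)
    step_cov = (np.pi - theta) / (2 * np.pi)  # E[relu'(u) relu'(u')]
    # Readout-layer gradients give the NNGP term; hidden-layer gradients the second term.
    return nngp + k12 * step_cov


def empirical_ntk(x1, x2, width, seed=0):
    """Sum over all parameters p of df(x1)/dp * df(x2)/dp for one random network."""
    rng = np.random.default_rng(seed)
    d = x1.shape[-1]
    W = rng.normal(size=(width, d))  # NTK parameterization: N(0, 1) weights
    v = rng.normal(size=width)
    u1, u2 = W @ x1 / np.sqrt(d), W @ x2 / np.sqrt(d)
    # df/dv_j = relu(u_j) / sqrt(n)
    readout_term = np.sum(np.maximum(u1, 0.0) * np.maximum(u2, 0.0)) / width
    # df/dW_j = v_j * relu'(u_j) * x / sqrt(n * d)
    hidden_term = np.sum(v**2 * (u1 > 0) * (u2 > 0)) * (x1 @ x2 / d) / width
    return readout_term + hidden_term


x1 = np.array([1.0, 0.5, -0.3])
x2 = np.array([0.2, -1.0, 0.8])
analytic = relu_ntk(x1, x2)
empirical = empirical_ntk(x1, x2, width=200_000)  # close to `analytic` at large width
```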
Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:
import neural_tangents as nt
x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1)) # training targets
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
y_train)
y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network
y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)
# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp # True
both.ntk == y_test_ntk # True
# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
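The t = inf predictions above reduce to textbook kernel regression formulas. The NumPy sketch below writes them out for precomputed kernel matrices:
- With NNGP kernels, it returns the GP posterior mean and covariance.
- With NTK matrices, the mean coincides with the infinite-time mean prediction of the gradient-descent ensemble. The covariance of that ensemble follows a different formula.

The sketch assumes zero-mean priors and adds a small diagonal jitter, diag_reg, for numerical stability. gp_posterior and the RBF stand-in kernel are hypothetical helpers for this example only.

```python
import numpy as np


def gp_posterior(k_train_train, k_test_train, k_test_test, y_train, diag_reg=1e-6):
    """Posterior mean and covariance of zero-mean GP regression."""
    n = k_train_train.shape[0]
    # A small diagonal jitter keeps the linear solves well-conditioned.
    k_reg = k_train_train + diag_reg * np.eye(n)
    mean = k_test_train @ np.linalg.solve(k_reg, y_train)
    cov = k_test_test - k_test_train @ np.linalg.solve(k_reg, k_test_train.T)
    return mean, cov


def rbf(a, b):
    """Stand-in kernel for this demo; substitute kernel_fn(a, b, 'nngp') in practice."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq_dists)


rng = np.random.default_rng(0)
x_train, x_test = rng.normal(size=(10, 3)), rng.normal(size=(20, 3))
y_train = rng.uniform(size=(10, 1))
mean, cov = gp_posterior(rbf(x_train, x_train), rbf(x_test, x_train),
                         rbf(x_test, x_test), y_train)
```

Solving linear systems instead of forming an explicit inverse is the usual, numerically safer choice.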
Infinitely WideResnet
We can define a more complex (infinitely) Wide Residual Network [14] using the same nt.stax building blocks:
from neural_tangents import stax
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    # Residual block: sum of a main branch and an identity (or strided conv) shortcut.
    Main = stax.serial(
        stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
    Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2),
                       stax.parallel(Main, Shortcut),
                       stax.FanInSum())
def WideResnetGroup(n, channels, strides=(1, 1)):
    # The first block may change channels / resolution; the rest keep them fixed.
    blocks = []
    blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [WideResnetBlock(channels, (1, 1))]
    return stax.serial(*blocks)
def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(block_size, int(16 * k)),
        WideResnetGroup(block_size, int(32 * k), (2, 2)),
        WideResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.AvgPool((8, 8)),
        stax.Flatten(),
        stax.Dense(num_classes, 1., 0.))
init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
Package description
The neural_tangents (nt) package contains the following modules and functions:
- stax - primitives to construct neural networks like Conv, Relu, serial, parallel etc.
- predict - predictions with infinite networks:
  - predict.gradient_descent_mse - inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (t=None) time. Computed in closed form.
  - predict.gradient_descent - inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver.
  - predict.gradient_descent_mse_ensemble - inference with an infinite ensemble of infinite width networks, either fully Bayesian (get='nngp') or inference with MSE loss using continuous gradient descent (get='ntk'). Finite-time Bayesian inference (e.g. t=1., get='nngp') is interpreted as gradient descent on the top layer only [11], since it converges to exact Gaussian process inference with NNGP (t=None, get='nngp'). Computed in closed form.
  - predict.gp_inference - exact closed form Gaussian process inference using NNGP (get='nngp'), NTK (get='ntk'), or both (get=('nngp', 'ntk')). Equivalent to predict.gradient_descent_mse_ensemble with t=None (infinite training time), but has a slightly different API (accepting a precomputed kernel matrix k_train_train instead of kernel_fn and x_train).
- monte_carlo_kernel_fn - compute a Monte Carlo kernel estimate of any (init_fn, apply_fn), not necessarily specified via nt.stax, enabling the kernel computation of infinite networks without closed-form expressions.
- Tools to investigate training dynamics of wide but finite neural networks, like linearize, taylor_expand, empirical_kernel_fn and more. See Training dynamics of wide but finite networks for details.
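The closed form used by predict.gradient_descent_mse can be sketched in a few lines. The sketch assumes the loss 0.5 * ||f(x_train) - y_train||^2 and continuous gradient descent with learning rate lr; the library may normalize the loss or learning rate differently. Under these assumptions:
- Train outputs obey df/dt = -lr * Theta (f - y), so f_t = y + exp(-lr * Theta * t)(f_0 - y).
- Test outputs move by Theta_test_train Theta^-1 (I - exp(-lr * Theta * t))(y - f_0).

The code uses an eigendecomposition of the train-train NTK in place of a matrix exponential routine, and assumes that matrix is positive definite. mse_dynamics is a hypothetical name.

```python
import numpy as np


def mse_dynamics(ntk_train_train, ntk_test_train, y_train, fx_train_0, fx_test_0, t, lr=1.0):
    """Outputs after time t of continuous gradient descent on 0.5 * ||f - y||**2."""
    # Eigendecomposition of the (assumed positive-definite) train-train NTK.
    evals, evecs = np.linalg.eigh(ntk_train_train)
    decay = np.exp(-lr * evals * t)
    exp_term = (evecs * decay) @ evecs.T  # expm(-lr * Theta * t)
    inv_growth = (evecs * ((1.0 - decay) / evals)) @ evecs.T  # Theta^-1 (I - expm(...))
    residual_0 = y_train - fx_train_0
    fx_train_t = y_train - exp_term @ residual_0
    fx_test_t = fx_test_0 + ntk_test_train @ inv_growth @ residual_0
    return fx_train_t, fx_test_t


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
ntk = A @ A.T + np.eye(5)  # stand-in positive-definite train-train kernel
y = rng.normal(size=(5, 2))
f0 = rng.normal(size=(5, 2))
# Using the train points as "test" points must reproduce the train dynamics.
fx_train_t, fx_test_t = mse_dynamics(ntk, ntk, y, f0, f0, t=0.7)
```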
Technical gotchas
nt.stax vs jax.example_libraries.stax
We remark the following differences between our library and the JAX one.
- All nt.stax layers are instantiated with a function call, i.e. nt.stax.Relu() vs jax.example_libraries.stax.Relu.
- All layers with trainable parameters use the NTK parameterization by default (see [10], Remark 1). However, Dense and Conv layers also support the standard parameterization via a parameterization keyword argument (see [15]).
- nt.stax and jax.example_libraries.stax may have different layers and options available (for example, nt.stax layers support CIRCULAR padding and have LayerNorm, but no BatchNorm).
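The two parameterizations differ only in where the weight scale lives. The NumPy sketch below illustrates this for a single bias-free dense layer with fan-in n:
- In the NTK parameterization, the weights are N(0, 1) and the forward pass multiplies by sigma_w / sqrt(n).
- In the standard parameterization, the weights are drawn from N(0, sigma_w^2 / n) and used directly.

Both produce the same output at initialization, but their parameter gradients differ by a factor of sigma_w / sqrt(n), which changes the training dynamics. Variable names and values are illustrative assumptions.

```python
import numpy as np

n, sigma_w = 512, 1.5  # fan-in and weight standard deviation (illustrative values)
rng = np.random.default_rng(0)
x = rng.normal(size=n)
omega = rng.normal(size=n)  # the same N(0, 1) draws feed both parameterizations

# NTK parameterization: N(0, 1) weights, explicit sigma_w / sqrt(n) in the forward pass.
w_ntk = omega
y_ntk = sigma_w / np.sqrt(n) * (x @ w_ntk)
grad_ntk = sigma_w / np.sqrt(n) * x  # dy / dw_ntk

# Standard parameterization: the scale lives in the initialization variance instead.
w_sp = sigma_w / np.sqrt(n) * omega
y_sp = x @ w_sp
grad_sp = x  # dy / dw_sp

# Same function at initialization, but gradients differ by a factor sigma_w / sqrt(n).
```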
CPU and TPU performance
For CNNs with pooling, our CPU and TPU performance is suboptimal due to low core utilization (10-20%, which looks like an XLA:CPU issue) and excessive padding, respectively. We will look into improving performance, but recommend NVIDIA GPUs in the meantime. See Performance.
Training dynamics of wide but finite networks
The kernel of an infinite network, kernel_fn(x1, x2).ntk, combined with nt.predict.gradient_descent_mse, allows one to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).
Weight space
Continuous gradient descent in an infinite network has been shown in [11] to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
For this, we provide two convenient functions, nt.linearize and nt.taylor_expand. They allow one to linearize any function apply_fn(params, x), or compute an arbitrary-order Taylor expansion of it, around some initial parameters params_0, as in apply_fn_lin = nt.linearize(apply_fn, params_0).
One can use apply_fn_lin(params, x) exactly as one would any other function (including as an input to JAX optimizers). This makes it easy to compare the training trajectory of a neural network with that of its linearization.
Previous theory and experiments have examined the linearization of neural networks from inputs to logits or pre-activations, rather than from inputs to post-activations, which are substantially more nonlinear.
Example:
import jax.numpy as np
import neural_tangents as nt
def apply_fn(params, x):
    # Affine model: linear in its parameters (W, b).
    W, b = params
    return np.dot(x, W) + b
W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = np.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x) # (3, 2) np.ndarray
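Because the apply_fn above is already linear in its parameters, its linearization is exact. For a model that is not, the first-order expansion f_lin(W, x) = f(W_0, x) + J(W_0)(W - W_0) agrees with f only near W_0. The NumPy sketch below writes this Jacobian-vector product out by hand for a hypothetical model tanh(x @ W), to show what linearization computes conceptually. It is not the implementation of nt.linearize, which works on arbitrary apply_fn.

```python
import numpy as np


def apply_fn(W, x):
    return np.tanh(x @ W)


def apply_fn_lin(W, x, W_0):
    """First-order Taylor expansion of apply_fn around W_0."""
    z_0 = x @ W_0
    # Jacobian-vector product: d tanh(z) = (1 - tanh(z)**2) * dz, with dz = x @ (W - W_0).
    return np.tanh(z_0) + (1.0 - np.tanh(z_0) ** 2) * (x @ (W - W_0))


rng = np.random.default_rng(0)
W_0 = rng.normal(size=(2, 2))
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
delta = rng.normal(size=(2, 2))

for eps in (1e-1, 1e-2, 1e-3):
    W = W_0 + eps * delta
    err = np.abs(apply_fn(W, x) - apply_fn_lin(W, x, W_0)).max()
    print(eps, err)  # for a smooth model the gap shrinks roughly like eps**2
```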
Function space
Outputs of a linearized model evolve identically to those of an infinite one [11], but with a different kernel. Specifically, this is the Neural Tangent Kernel [10] evaluated on the specific apply_fn of the finite network, at the specific params_0 that the network is initialized with. For this we provide the nt.empirical_kernel_fn function. It accepts any apply_fn and returns a kernel_fn(x1, x2, get, params) that computes the empirical NTK and/or NNGP kernels (depending on get) at the specific params.
Example:
import jax.random as random
import jax.numpy as np
import neural_tangents as nt
def apply_fn(params, x):
    W, b = params
    return np.dot(x, W) + b
W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)
key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))
kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)
t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
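The empirical NTK is the inner product of the Jacobians of the network outputs with respect to all trainable parameters, evaluated at the given params. For the affine apply_fn above, these Jacobians can be written by hand, which makes the result easy to check: Theta[a, b, k, l] = x1[a] . x2[b] + 1 if k == l, and 0 otherwise. The NumPy sketch below builds this kernel with separate output axes. nt.empirical_kernel_fn may reduce over or normalize the output axis by default, so its values can differ by such a convention. affine_jacobian and empirical_ntk are hypothetical names.

```python
import numpy as np


def affine_jacobian(x, n_out):
    """Jacobian of f(x) = x @ W + b w.r.t. the flattened params (W, b).

    Returns shape (batch, n_out, d * n_out + n_out); it does not depend on W or b.
    """
    batch, d = x.shape
    eye = np.eye(n_out)
    # d f_k / d W_ij = x_i * [k == j] (W flattened row-major); d f_k / d b_j = [k == j].
    dW = np.einsum('ai,kj->akij', x, eye).reshape(batch, n_out, d * n_out)
    db = np.broadcast_to(eye, (batch, n_out, n_out))
    return np.concatenate([dW, db], axis=-1)


def empirical_ntk(x1, x2, n_out):
    j1, j2 = affine_jacobian(x1, n_out), affine_jacobian(x2, n_out)
    # Theta[a, b, k, l] = <d f_k(x1[a]) / d params, d f_l(x2[b]) / d params>.
    return np.einsum('akp,blp->abkl', j1, j2)


rng = np.random.default_rng(1)
x_train, x_test = rng.normal(size=(3, 2)), rng.normal(size=(4, 2))
ntk_test_train = empirical_ntk(x_test, x_train, n_out=2)  # shape (4, 3, 2, 2)
```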
What to Expect
The success or failure of the linear approximation is highly architecture dependent. However, some rules of thumb that we've observed are:
- Convergence as the network size increases.
  - For fully-connected networks, one generally observes very strong agreement by the time the layer width is 512 (RMSE of about 0.05 at the end of training).
  - For convolutional networks, one generally observes reasonable agreement by the time the number of channels is 512.
- Convergence at small learning rates.
With a new model, it is therefore advisable to start with a very large model on a small dataset using a small learning rate.
Performance
In the table below we measure the time to compute a single NTK entry in a 21-layer CNN (3x3 filters, no strides, SAME padding, ReLU) on inputs of shape 3x32x32. Precisely:
from neural_tangents import stax
layers = []
for _ in range(21):
    layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]
CNN with pooling
Top layer is stax.GlobalAvgPool():
_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))
Platform | Precision (bits) | Milliseconds / NTK entry | Max batch size (N x N) |
---|---|---|---|
CPU, >56 cores, >700 Gb RAM | 32 | 112.90 | >= 128 |
CPU, >56 cores, >700 Gb RAM | 64 | 258.55 | 95 (fastest - 72) |
TPU v2 | 32/16 | 3.2550 | 16 |
TPU v3 | 32/16 | 2.3022 | 24 |
NVIDIA P100 | 32 | 5.9433 | 26 |
NVIDIA P100 | 64 | 11.349 | 18 |
NVIDIA V100 | 32 | 2.7001 | 26 |
NVIDIA V100 | 64 | 6.2058 | 18 |
CNN without pooling
Top layer is stax.Flatten():
_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))
Platform | Precision (bits) | Milliseconds / NTK entry | Max batch size (N x N) |
---|---|---|---|
CPU, >56 cores, >700 Gb RAM | 32 | 0.12013 | 2048 <= N < 4096 (fastest - 512) |
CPU, >56 cores, >700 Gb RAM | 64 | 0.3414 | 2048 <= N < 4096 (fastest - 256) |
TPU v2 | 32/16 | 0.0015722 | 512 <= N < 1024 |
TPU v3 | 32/16 | 0.0010647 | 512 <= N < 1024 |
NVIDIA P100 | 32 | 0.015171 | 512 <= N < 1024 |
NVIDIA P100 | 64 | 0.019894 | 512 <= N < 1024 |
NVIDIA V100 | 32 | 0.0046510 | 512 <= N < 1024 |
NVIDIA V100 | 64 | 0.010822 | 512 <= N < 1024 |
Tested using version 0.2.1. All GPU results are per single accelerator. Note that runtime is proportional to the depth of your network. If your performance differs significantly, please file a bug!
Myrtle network
The Colab notebook Performance Benchmark demonstrates how one would construct and benchmark kernels. To demonstrate flexibility, we took the architecture from [16] as an example. With an NVIDIA V100 in 64-bit precision, nt took 316/330/508 GPU-hours on the full 60k CIFAR-10 dataset for the Myrtle-5/7/10 kernels, respectively.
Papers
Neural Tangents has been used in the following papers (newest first):
- On the Equivalence between Neural Network and Support Vector Machine
- An Empirical Study of Neural Kernel Bandits
- Neural Networks as Kernel Learners: The Silent Alignment Effect
- Understanding Deep Learning via Analyzing Dynamics of Gradient Descent
- Neural Scene Representations for View Synthesis
- Neural Tangent Kernel Eigenvalues Accurately Predict Generalization
- Uniform Generalization Bounds for Overparameterized Neural Networks
- Data Summarization via Bilevel Optimization
- Neural Tangent Generalization Attacks
- Dataset Distillation with Infinitely Wide Convolutional Networks
- Neural Contextual Bandits without Regret
- Epistemic Neural Networks
- Uncertainty-aware Cardinality Estimation by Neural Network Gaussian Process
- Scale Mixtures of Neural Network Gaussian Processes
- Provably efficient machine learning for quantum many-body problems
- Wide Mean-Field Variational Bayesian Neural Networks Ignore the Data
- Spectral bias and task-model alignment explain generalization in kernel regression and infinitely wide neural networks
- Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Training and Effective Adaptation
- What can linearized neural networks actually say about generalization?
- Measuring the sensitivity of Gaussian processes to kernel choice
- A Neural Tangent Kernel Perspective of GANs
- On the Power of Shallow Learning
- Learning Curves for SGD on Structured Features
- Out-of-Distribution Generalization in Kernel Regression
- Rapid Feature Evolution Accelerates Learning in Neural Networks
- Scalable and Flexible Deep Bayesian Optimization with Auxiliary Information for Scientific Problems
- Random Features for the Neural Tangent Kernel
- Multi-Level Fine-Tuning: Closing Generalization Gaps in Approximation of Solution Maps under a Limited Budget for Training Data
- Explaining Neural Scaling Laws
- Correlated Weights in Infinite Limits of Deep Convolutional Neural Networks
- Dataset Meta-Learning from Kernel Ridge-Regression
- Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the Neural Tangent Kernel
- Stable ResNet
- Label-Aware Neural Tangent Kernel: Toward Better Generalization and Local Elasticity
- Semi-supervised Batch Active Learning via Bilevel Optimization
- Temperature check: theory and practice for training models with softmax-cross-entropy losses
- Experimental Design for Overparameterized Learning with Application to Single Shot Deep Active Learning
- How Neural Networks Extrapolate: From Feedforward to Graph Neural Networks
- Exploring the Uncertainty Properties of Neural Networks’ Implicit Priors in the Infinite-Width Limit
- Cold Posteriors and Aleatoric Uncertainty
- Asymptotics of Wide Convolutional Neural Networks
- Finite Versus Infinite Neural Networks: an Empirical Study
- Bayesian Deep Ensembles via the Neural Tangent Kernel
- The Surprising Simplicity of the Early-Time Learning Dynamics of Neural Networks
- When Do Neural Networks Outperform Kernel Methods?
- Statistical Mechanics of Generalization in Kernel Regression
- Exact posterior distributions of wide Bayesian neural networks
- Infinite attention: NNGP and NTK for deep attention networks
- Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
- Finding trainable sparse networks through Neural Tangent Transfer
- Coresets via Bilevel Optimization for Continual Learning and Streaming
- On the Neural Tangent Kernel of Deep Networks with Orthogonal Initialization
- The large learning rate phase of deep learning: the catapult mechanism
- Spectrum Dependent Learning Curves in Kernel Regression and Wide Neural Networks
- Taylorized Training: Towards Better Approximation of Neural Network Training at Finite Width
- On the Infinite Width Limit of Neural Networks with a Standard Parameterization
- Disentangling Trainability and Generalization in Deep Learning
- Information in Infinite Ensembles of Infinitely-Wide Neural Networks
- Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel
- Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
- Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes
Please let us know if you make use of the code in a publication, and we'll add it to the list!
Citation
If you use the code in a publication, please cite our ICLR 2020 paper:
@inproceedings{neuraltangents2020,
title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://github.com/google/neural-tangents}
}
References
[1] Priors for Infinite Networks
[2] Exponential expressivity in deep neural networks through transient chaos
[3] Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity
[4] Deep Information Propagation
[5] Deep Neural Networks as Gaussian Processes
[6] Gaussian Process Behaviour in Wide Deep Neural Networks
[7] Dynamical Isometry and a Mean Field Theory of CNNs: How to Train 10,000-Layer Vanilla Convolutional Neural Networks.
[8] Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes
[9] Deep Convolutional Networks as shallow Gaussian Processes
[10] Neural Tangent Kernel: Convergence and Generalization in Neural Networks
[11] Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
[12] Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation
[13] Mean Field Residual Networks: On the Edge of Chaos
[14] Wide Residual Networks
[15] On the Infinite Width Limit of Neural Networks with a Standard Parameterization
[16] Neural Kernels Without Tangents