NNJ Documentation

NNJ = torch.nn + Jacobian.

Both forward and backward passes, for both vectors and metrics in tangent spaces, are implemented for every module. The library is a fast and memory-efficient extension of PyTorch.
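
Concretely, in our own notation (not necessarily the library's): for a module with Jacobian J, the two backward passes compute

   vjp:   v ↦ Jᵀ v       (pullback of a tangent vector v)
   jTmjp: M ↦ Jᵀ M J     (pullback of a metric M)

Both operations appear in the usage example below.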

GitHub: https://github.com/IlMioFrizzantinoAmabile/nnj

Authors: Marco Miani and Frederik Warburg

Installation

Dependency: please install PyTorch first.

The easiest way is to install from PyPI:

$ pip install nnj

Or install from source:

$ git clone https://github.com/IlMioFrizzantinoAmabile/nnj
$ cd nnj
$ pip install -e .
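
Either way, you can sanity-check the installation by importing the package:

$ python -c "import nnj"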

Want to learn more about nnj?

Check out our Introduction to nnj to learn about the theory behind nnj. The document seeks to provide a simple introduction to nnj and to the functions it provides.

Usage example

Declare your neural network as you would normally do with PyTorch, just with an extra j.

import torch
import nnj

# Define your sequential model (layer sizes here are hypothetical)
model = nnj.Sequential(
   nnj.Linear(10, 50),
   nnj.Tanh(),
   nnj.Linear(50, 2),
)

# Standard forward pass (x is a hypothetical batch of inputs)
x = torch.randn(32, 10)
val = model(x)

Compute the gradient (with respect to the weights) of the L2 loss as a backward pass of the residual vector, and perform a gradient step.

# The residual is the derivative of the L2 loss with respect to the nn output
target = torch.randn(32, 2)   # hypothetical target values
residual = val - target
# Backpropagate the residual vector
gradient = model.vjp(
   x,             # input
   val,           # output
   residual,      # residual vector
   wrt="weight"
)
# Average over batch size
gradient = torch.mean(gradient, dim=0)

# Do a gradient step (lr is a hypothetical learning rate)
lr = 1e-3
param = model.get_weight()
param -= lr * gradient
model.set_weight(param)
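
Putting the pieces together, here is a minimal training-loop sketch that reuses only the calls shown above; the data, step count, and learning rate are hypothetical placeholders:

# A minimal sketch: repeat the manual gradient step in a loop.
# x, target, and lr are the hypothetical placeholders defined above.
for step in range(100):
   val = model(x)
   residual = val - target
   gradient = model.vjp(x, val, residual, wrt="weight")
   gradient = torch.mean(gradient, dim=0)   # average over batch size
   param = model.get_weight()
   param -= lr * gradient
   model.set_weight(param)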

Compute the generalized Gauss-Newton matrix (GGN), an approximation of the Hessian, as a backward pass of the Euclidean metric. For the L2 loss the output-space metric is the identity, so the result is Jᵀ J.

jacobianTranspose_jacobian = model._jTmjp(
   x,                  # input
   val,                # output
   None,               # None means identity (i.e. Euclidean metric)
   wrt="weight",       # compute the Jacobian wrt the weights or the input
   to_diag=True,       # compute the diagonal elements only
   diag_backprop=True, # approximate the diagonal elements, which speeds up the computation
)
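
A common use of this diagonal GGN is as a per-parameter curvature estimate, for example to precondition the gradient step above. A minimal sketch, assuming the returned diagonal is batched like the gradient and using a hypothetical damping constant:

# Hypothetical: damped diagonal-GGN preconditioning of the gradient step,
# reusing lr, gradient, and the model accessors from the example above.
ggn_diag = torch.mean(jacobianTranspose_jacobian, dim=0)   # average over batch size (assumed batched)
damping = 1e-3   # hypothetical damping constant for numerical stability
param = model.get_weight()
param -= lr * gradient / (ggn_diag + damping)
model.set_weight(param)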

Why not just use JAX?

Check out our comparison. For example, start by reading our Ok, but why to learn more about it. (WORK IN PROGRESS)