# Source code for nndt.haiku_modules

from typing import *

import haiku as hk
import jax
import jax.numpy as jnp


class DescConv(hk.Module):
    """Fully convolutional network built from a stack of 3D convolutions.

    The network is ``n_layers`` ``hk.Conv3D`` layers applied in sequence, each
    followed by ``activation`` (the activation is applied after the last layer
    as well).

    Layer ``k`` (0-based) has ``2**k * kernels_in_first_layer`` output
    channels. Every layer uses the same ``kernel_shape`` and ``stride`` and
    ``"VALID"`` padding. Weights use a ``VarianceScaling(1.0, "fan_avg",
    "uniform")`` initializer and biases are zero-initialized.

    NOTE(review): ``data_format`` is not set, so ``hk.Conv3D``'s default input
    layout applies (presumably channels-last) -- confirm against callers.

    Args:
        n_layers: number of convolution layers.
        kernels_in_first_layer: output channels of the first layer; each
            following layer doubles it.
        kernel_shape: convolution kernel shape shared by all layers.
        stride: convolution stride shared by all layers.
        activation: elementwise nonlinearity applied after every layer.
        name: optional Haiku module name.
    """

    def __init__(
        self,
        n_layers=4,
        kernels_in_first_layer=32,
        kernel_shape=(2, 2, 2),
        stride=(2, 2, 2),
        activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.n_layers = n_layers
        self.kernels_in_first_layer = kernels_in_first_layer
        self.kernel_shape = kernel_shape
        self.stride = stride
        self.activation = activation
        # Submodules are created in order, so Haiku assigns the same
        # auto-generated names as a plain append loop would.
        self.layers = tuple(
            [
                hk.Conv3D(
                    output_channels=2**depth * kernels_in_first_layer,
                    kernel_shape=kernel_shape,
                    stride=stride,
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                    b_init=jnp.zeros,
                    padding="VALID",
                )
                for depth in range(n_layers)
            ]
        )

    def __call__(self, inputs):
        """Apply every convolution followed by the activation, in order."""
        act = self.activation
        hidden = inputs
        for conv in self.layers:
            hidden = act(conv(hidden))
        return hidden
class LipLinear(hk.Module):
    """Linear layer with a learnable per-layer Lipschitz bound.

    Layer for implementation of the LipMLP from the article:
    Liu, Hsueh-Ti Derek, et al. "Learning Smooth Neural Functions via
    Lipschitz Regularization." arXiv preprint arXiv:2202.08345 (2022).

    The forward pass computes ``activation(x @ W_ + b)``. ``W_`` is ``W``
    rescaled so that, for every output unit, the sum of absolute weights
    feeding it is at most ``softplus(c)``.

    Parameters created (names as passed to ``hk.get_parameter``):
        W: shape ``[input_size, output_size]``.
        b: shape ``[output_size]``, zero-initialized.
        c: scalar; ``softplus(c)`` is the bound returned by
            ``get_lipschitz_loss``.

    Args:
        output_size: number of output features.
        name: optional Haiku module name.
        activation: elementwise nonlinearity applied to the affine output.
    """

    def __init__(
        self,
        output_size,
        name=None,
        activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
    ):
        super().__init__(name=name)
        self.output_size = output_size
        self.activation = activation

    def weight_normalization(self, W, softplus_c):
        """Rescale each row of ``W`` so its absolute sum is at most ``softplus_c``.

        A row whose absolute sum is already ``<= softplus_c`` is left
        unchanged; otherwise it is scaled by ``softplus_c / absrowsum``.
        All-zero rows are also left unchanged.

        Args:
            W: 2-D weight matrix; rows are normalized independently.
            softplus_c: positive scalar bound.

        Returns:
            Array with the same shape as ``W``.
        """
        absrowsum = jnp.sum(jnp.abs(W), axis=1)
        nonzero = absrowsum > 0
        # Double-where: never divide by zero, even in the branch that is
        # discarded. Otherwise c / 0 = inf, and its gradient (0 * inf) poisons
        # the backward pass with NaN. The forward value for all-zero rows is
        # 1.0, the same as min(1, c / 0) = 1 in the unguarded formula.
        safe_rowsum = jnp.where(nonzero, absrowsum, 1.0)
        scale = jnp.where(nonzero, jnp.minimum(1.0, softplus_c / safe_rowsum), 1.0)
        return W * scale[:, None]

    def __call__(self, x):
        """Apply the Lipschitz-normalized affine map followed by ``activation``.

        Args:
            x: input whose last axis has size ``input_size``.

        Returns:
            ``activation(x @ W_ + b)`` with last axis of size ``output_size``.
        """
        input_size, output_size = x.shape[-1], self.output_size
        w_init = hk.initializers.VarianceScaling(2.0, "fan_in", "truncated_normal")
        W = hk.get_parameter(
            "W", shape=[input_size, output_size], dtype=x.dtype, init=w_init
        )
        b = hk.get_parameter("b", shape=[output_size], dtype=x.dtype, init=jnp.zeros)

        def _c_init(shape: Sequence[int], dtype: Any) -> jnp.ndarray:
            # Initialize c to the largest absolute weight sum over output
            # units. Since softplus(c) > c, no row is rescaled at init time.
            return jnp.max(jnp.sum(jnp.abs(W.T), axis=1)).astype(dtype)

        c = hk.get_parameter("c", shape=(), dtype=x.dtype, init=_c_init)
        # Rows of W.T are the weights feeding one output unit each.
        W_ = self.weight_normalization(W.T, jax.nn.softplus(c)).T
        return self.activation(jnp.dot(x, W_) + b)

    def get_lipschitz_loss(self):
        """Return ``softplus(c)``, this layer's Lipschitz bound.

        NOTE(review): here ``c`` is requested with a zeros initializer. If this
        method runs before ``__call__`` during ``init``, ``c`` is presumably
        created as 0 here, giving ``softplus(0) = log 2``. The weight-based
        initializer in ``__call__`` is then not used. Call the layer on inputs
        first within the same transformed function.
        """
        c = hk.get_parameter("c", shape=(), init=jnp.zeros)
        return jax.nn.softplus(c)
class LipMLP(hk.Module):
    """Multilayer perceptron of ``LipLinear`` layers (LipMLP).

    This is an implementation of the LipMLP from the article:
    Liu, Hsueh-Ti Derek, et al. "Learning Smooth Neural Functions via
    Lipschitz Regularization." arXiv preprint arXiv:2202.08345 (2022).

    One ``LipLinear`` layer is created per entry of ``output_sizes``, named
    ``lip_mlp_0``, ``lip_mlp_1``, ... . Every layer except the last uses
    ``activation``; the last layer uses ``activation_output``.

    Args:
        output_sizes: output width of each layer, in order; must be non-empty.
        name: optional Haiku module name.
        activation: nonlinearity for all hidden layers.
        activation_output: nonlinearity for the final layer (identity by
            default).

    Raises:
        ValueError: if ``output_sizes`` is empty.
    """

    def __init__(
        self,
        output_sizes: Iterable[int],
        name: Optional[str] = None,
        activation=jax.nn.tanh,
        activation_output=lambda x: x,
    ):
        super().__init__(name=name)
        # Materialize once, before storing. A generator would otherwise be
        # exhausted by the layer construction, leaving self.output_sizes empty.
        output_sizes = tuple(output_sizes)
        if not output_sizes:
            raise ValueError("LipMLP requires at least one output size.")
        self.output_sizes = output_sizes

        last_index = len(output_sizes) - 1
        self.layers = tuple(
            LipLinear(
                output_size=output_size,
                name="lip_mlp_%d" % index,
                activation=activation_output if index == last_index else activation,
            )
            for index, output_size in enumerate(output_sizes)
        )

    def __call__(self, inputs):
        """Apply all layers in sequence."""
        out = inputs
        for layer in self.layers:
            out = layer(out)
        return out

    def get_lipschitz_loss(self):
        """Return the product of every layer's ``softplus(c)`` bound."""
        out = 1.0
        for layer in self.layers:
            out = out * layer.get_lipschitz_loss()
        return out