Source code for nndt.trainable_task

from abc import abstractmethod
from collections import namedtuple
from typing import Callable, NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
from jax.random import KeyArray

from nndt.haiku_modules import DescConv, LipMLP


class AbstractTrainableTask:
    """Common interface implemented by the trainable tasks of this module.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced when instantiating; the
    stub bodies below simply return ``None`` if they are not overridden.
    """

    @abstractmethod
    def init_data(self) -> namedtuple:
        """Return the sample batch a task feeds to its ``init`` function."""

    @abstractmethod
    def init_and_functions(self, rng_key: KeyArray) -> (namedtuple, namedtuple):
        """Build the task for ``rng_key``.

        Implementations in this module return a pair made of the initial
        parameters and a named tuple of transformed functions.
        """
class SimpleSDF(AbstractTrainableTask):
    """
    This is a trainable task for the problem of shape interpolation for one 3d object.
    This class employs the usual multi-layer perceptron.
    """

    class FUNC(NamedTuple):
        """Functions returned by :meth:`SimpleSDF.init_and_functions`."""

        sdf: Callable  # network value for a single (x, y, z)
        vec_sdf: Callable  # `sdf` vmapped over axis 0 of every input
        sdf_dx: Callable  # gradient of `sdf` w.r.t. x
        sdf_dy: Callable  # gradient of `sdf` w.r.t. y
        sdf_dz: Callable  # gradient of `sdf` w.r.t. z
        vec_sdf_dx: Callable  # `sdf_dx` vmapped over axis 0
        vec_sdf_dy: Callable  # `sdf_dy` vmapped over axis 0
        vec_sdf_dz: Callable  # `sdf_dz` vmapped over axis 0
        vec_main_loss: Callable  # mean squared error between `vec_sdf` and SDF

    class DATA(NamedTuple):
        """One batch of sample points together with their target SDF values."""

        X: jnp.ndarray  # [N]
        Y: jnp.ndarray  # [N]
        Z: jnp.ndarray  # [N]
        SDF: jnp.ndarray  # [N]

        def __add__(self, other):
            """Concatenate two batches field by field along axis 0."""
            return SimpleSDF.DATA(
                X=jnp.concatenate([self.X, other.X], axis=0),
                Y=jnp.concatenate([self.Y, other.Y], axis=0),
                Z=jnp.concatenate([self.Z, other.Z], axis=0),
                SDF=jnp.concatenate([self.SDF, other.SDF], axis=0),
            )

    def __init__(
        self, mlp_layers=(32, 32, 32, 32, 32, 32, 32, 32, 1), batch_size=64 * 64 * 64
    ):
        """Store the hyper-parameters and build a zero-filled sample batch.

        :param mlp_layers: output sizes passed to ``hk.nets.MLP``
        :param batch_size: length of every array in the sample batch
        """
        self.mlp_layers = mlp_layers
        self.batch_size = batch_size
        self._init_data = SimpleSDF.DATA(
            X=jnp.zeros(self.batch_size),
            Y=jnp.zeros(self.batch_size),
            Z=jnp.zeros(self.batch_size),
            SDF=jnp.zeros(self.batch_size),
        )

    def init_data(self) -> namedtuple:
        """Return the zero-filled sample batch built in ``__init__``.

        Without this override the call fell through to the base-class stub
        and returned ``None``, unlike every other task in this module.
        """
        return self._init_data

    def init_and_functions(self, rng_key: KeyArray) -> (namedtuple, namedtuple):
        """Initialise the network and return its transformed functions.

        :param rng_key: random key forwarded to the Haiku ``init`` function
        :return: ``(init_params, functions)`` where ``functions`` is a
            :class:`SimpleSDF.FUNC` of the functions produced by
            ``hk.multi_transform``
        """

        def constructor():
            def f_sdf(x, y, z):
                vec = jnp.hstack([x, y, z])
                net = hk.nets.MLP(output_sizes=self.mlp_layers, activation=jnp.tanh)
                return jnp.squeeze(net(vec))

            vec_f_sdf = hk.vmap(f_sdf, in_axes=(0, 0, 0), split_rng=False)

            def vec_main_loss(X, Y, Z, SDF):
                return jnp.mean((vec_f_sdf(X, Y, Z) - SDF) ** 2)

            grad_x = hk.grad(f_sdf, argnums=0)
            grad_y = hk.grad(f_sdf, argnums=1)
            grad_z = hk.grad(f_sdf, argnums=2)

            vec_grad_x = hk.vmap(grad_x, in_axes=(0, 0, 0), split_rng=False)
            vec_grad_y = hk.vmap(grad_y, in_axes=(0, 0, 0), split_rng=False)
            vec_grad_z = hk.vmap(grad_z, in_axes=(0, 0, 0), split_rng=False)

            # Parameters are created by evaluating the loss on the sample batch.
            def init(X, Y, Z, SDF):
                return vec_main_loss(X, Y, Z, SDF)

            return init, SimpleSDF.FUNC(
                sdf=f_sdf,
                vec_sdf=vec_f_sdf,
                sdf_dx=grad_x,
                sdf_dy=grad_y,
                sdf_dz=grad_z,
                vec_sdf_dx=vec_grad_x,
                vec_sdf_dy=vec_grad_y,
                vec_sdf_dz=vec_grad_z,
                vec_main_loss=vec_main_loss,
            )

        init, function_tuple = hk.multi_transform(constructor)
        functions = SimpleSDF.FUNC(*function_tuple)
        init_params = init(rng_key, *tuple(self._init_data))
        return init_params, functions
class ApproximateSDF(AbstractTrainableTask):
    """
    Trainable task for shape interpolation between several 3d objects.
    The signed distance function is modelled by a plain multi-layer perceptron.
    """

    # The typename string ends in "_DATA" although the tuple holds functions;
    # it is reproduced unchanged.
    FUNC = namedtuple(
        "ApproximateSDF_DATA",
        "sdf vec_sdf "
        "sdf_dx sdf_dy sdf_dz sdf_dt "
        "vec_sdf_dx vec_sdf_dy vec_sdf_dz vec_sdf_dt "
        "vec_main_loss",
    )

    class DATA(NamedTuple):
        """One batch of inputs and target SDF values."""

        X: jnp.ndarray  # [N]
        Y: jnp.ndarray  # [N]
        Z: jnp.ndarray  # [N]
        T: jnp.ndarray  # [N]
        P: jnp.ndarray  # [N, model_number]
        SDF: jnp.ndarray  # [N]

        def __add__(self, other):
            """Join two batches by concatenating every field along axis 0."""
            joined = {
                field: jnp.concatenate(
                    [getattr(self, field), getattr(other, field)], axis=0
                )
                for field in self._fields
            }
            return ApproximateSDF.DATA(**joined)

    def __init__(
        self,
        mlp_layers=(64, 64, 64, 64, 64, 64, 64, 64, 1),
        batch_size=262144,
        model_number=2,
    ):
        """Keep the hyper-parameters and prepare a zero-filled sample batch.

        :param mlp_layers: output sizes passed to ``hk.nets.MLP``
        :param batch_size: number of rows in every array of the sample batch
        :param model_number: width of the ``P`` array in the sample batch
        """
        self.mlp_layers = mlp_layers
        self.batch_size = batch_size
        self.model_number = model_number

        sample = {
            field: jnp.zeros(self.batch_size) for field in ("X", "Y", "Z", "T", "SDF")
        }
        sample["P"] = jnp.zeros((self.batch_size, self.model_number))
        self._init_data = ApproximateSDF.DATA(**sample)

    def init_data(self) -> namedtuple:
        """Return the sample batch created in ``__init__``."""
        return self._init_data

    def init_and_functions(self, rng_key: KeyArray) -> (namedtuple, namedtuple):
        """Create initial parameters and the transformed task functions.

        :param rng_key: random key handed to the Haiku ``init`` function
        :return: ``(init_params, functions)``; ``functions`` is an
            ``ApproximateSDF.FUNC`` built from ``hk.multi_transform``
        """

        def constructor():
            def batched(fn):
                # Map over axis 0 of all five inputs.
                return hk.vmap(fn, in_axes=(0, 0, 0, 0, 0), split_rng=False)

            def f_sdf(x, y, z, t, p):
                features = jnp.hstack([x, y, z, t, p])
                mlp = hk.nets.MLP(output_sizes=self.mlp_layers, activation=jnp.tanh)
                return jnp.squeeze(mlp(features))

            vec_f_sdf = batched(f_sdf)

            def vec_main_loss(X, Y, Z, T, P, SDF):
                residual = vec_f_sdf(X, Y, Z, T, P) - SDF
                return jnp.mean(residual**2)

            # Derivatives w.r.t. x, y, z and t (arguments 0..3), then batched.
            d_dx, d_dy, d_dz, d_dt = [hk.grad(f_sdf, argnums=i) for i in range(4)]
            vec_d_dx, vec_d_dy, vec_d_dz, vec_d_dt = [
                batched(g) for g in (d_dx, d_dy, d_dz, d_dt)
            ]

            def init(X, Y, Z, T, P, SDF):
                return vec_main_loss(X, Y, Z, T, P, SDF)

            exported = ApproximateSDF.FUNC(
                sdf=f_sdf,
                vec_sdf=vec_f_sdf,
                sdf_dx=d_dx,
                sdf_dy=d_dy,
                sdf_dz=d_dz,
                sdf_dt=d_dt,
                vec_sdf_dx=vec_d_dx,
                vec_sdf_dy=vec_d_dy,
                vec_sdf_dz=vec_d_dz,
                vec_sdf_dt=vec_d_dt,
                vec_main_loss=vec_main_loss,
            )
            return init, exported

        init_fn, transformed = hk.multi_transform(constructor)
        params = init_fn(rng_key, *self._init_data)
        return params, ApproximateSDF.FUNC(*transformed)
class ApproximateSDFLipMLP(AbstractTrainableTask):
    """
    Trainable task for shape interpolation between several 3d objects.
    The network is a multi-layer perceptron with Lipschitz regularization (LipMLP).
    """

    # The typename string ends in "_DATA" although the tuple holds functions;
    # it is reproduced unchanged.
    FUNC = namedtuple(
        "ApproximateSDFLipMLP_DATA",
        "sdf vec_sdf "
        "sdf_dx sdf_dy sdf_dz sdf_dt "
        "vec_sdf_dx vec_sdf_dy vec_sdf_dz vec_sdf_dt "
        "vec_main_loss",
    )

    class DATA(NamedTuple):
        """One batch of inputs and target SDF values."""

        X: jnp.ndarray  # [N]
        Y: jnp.ndarray  # [N]
        Z: jnp.ndarray  # [N]
        T: jnp.ndarray  # [N]
        P: jnp.ndarray  # [N, model_number]
        SDF: jnp.ndarray  # [N]

        def __add__(self, other):
            """Join two batches by concatenating every field along axis 0."""
            joined = {
                field: jnp.concatenate(
                    [getattr(self, field), getattr(other, field)], axis=0
                )
                for field in self._fields
            }
            return ApproximateSDFLipMLP.DATA(**joined)

    def __init__(
        self,
        mlp_layers=(64, 64, 64, 64, 64, 64, 64, 64, 1),
        batch_size=262144,
        model_number=2,
        lip_alpha=0.000001,
        negative_beta=0.0,
    ):
        """Keep the hyper-parameters and prepare a zero-filled sample batch.

        :param mlp_layers: output sizes passed to ``LipMLP``
        :param batch_size: number of rows in every array of the sample batch
        :param model_number: width of the ``P`` array in the sample batch
        :param lip_alpha: weight of ``net.get_lipschitz_loss()`` in the loss
        :param negative_beta: weight of the softplus-based term in the loss
        """
        self.mlp_layers = mlp_layers
        self.batch_size = batch_size
        self.model_number = model_number
        self.negative_beta = negative_beta
        self.lip_alpha = lip_alpha

        sample = {
            field: jnp.zeros(self.batch_size) for field in ("X", "Y", "Z", "T", "SDF")
        }
        sample["P"] = jnp.zeros((self.batch_size, self.model_number))
        self._init_data = ApproximateSDFLipMLP.DATA(**sample)

    def init_data(self) -> namedtuple:
        """Return the sample batch created in ``__init__``."""
        return self._init_data

    def init_and_functions(self, rng_key: KeyArray) -> (namedtuple, namedtuple):
        """Create initial parameters and the transformed task functions.

        :param rng_key: random key handed to the Haiku ``init`` function
        :return: ``(init_params, functions)``; ``functions`` is an
            ``ApproximateSDFLipMLP.FUNC`` built from ``hk.multi_transform``
        """

        def constructor():
            # A single network instance is shared by the value, its
            # derivatives and the Lipschitz term of the loss.
            net = LipMLP(
                output_sizes=self.mlp_layers,
            )

            def batched(fn):
                # Map over axis 0 of all five inputs.
                return hk.vmap(fn, in_axes=(0, 0, 0, 0, 0), split_rng=False)

            def f_sdf(x, y, z, t, p):
                features = jnp.hstack([x, y, z, t, p])
                return jnp.squeeze(net(features))

            vec_f_sdf = batched(f_sdf)

            def vec_main_loss(X, Y, Z, T, P, SDF):
                predict = vec_f_sdf(X, Y, Z, T, P)
                mse_term = jnp.mean((predict - SDF) ** 2)
                softplus_term = jnp.mean(
                    (jax.nn.softplus(-predict) - jax.nn.softplus(-SDF)) ** 2
                )
                return (
                    mse_term
                    + self.negative_beta * softplus_term
                    + self.lip_alpha * net.get_lipschitz_loss()
                )

            # Derivatives w.r.t. x, y, z and t (arguments 0..3), then batched.
            d_dx, d_dy, d_dz, d_dt = [hk.grad(f_sdf, argnums=i) for i in range(4)]
            vec_d_dx, vec_d_dy, vec_d_dz, vec_d_dt = [
                batched(g) for g in (d_dx, d_dy, d_dz, d_dt)
            ]

            def init(X, Y, Z, T, P, SDF):
                return vec_main_loss(X, Y, Z, T, P, SDF)

            exported = ApproximateSDFLipMLP.FUNC(
                sdf=f_sdf,
                vec_sdf=vec_f_sdf,
                sdf_dx=d_dx,
                sdf_dy=d_dy,
                sdf_dz=d_dz,
                sdf_dt=d_dt,
                vec_sdf_dx=vec_d_dx,
                vec_sdf_dy=vec_d_dy,
                vec_sdf_dz=vec_d_dz,
                vec_sdf_dt=vec_d_dt,
                vec_main_loss=vec_main_loss,
            )
            return init, exported

        init_fn, transformed = hk.multi_transform(constructor)
        params = init_fn(rng_key, *self._init_data)
        return params, ApproximateSDFLipMLP.FUNC(*transformed)
class ApproximateSDFLipMLP2(AbstractTrainableTask):
    """
    Trainable task for shape interpolation between several 3d objects.
    The network is a multi-layer perceptron with Lipschitz regularization (LipMLP);
    the data term of the loss is averaged with per-sample weights.
    """

    # The typename string ends in "_DATA" although the tuple holds functions;
    # it is reproduced unchanged.
    FUNC = namedtuple(
        "ApproximateSDFLipMLP2_DATA",
        "sdf vec_sdf "
        "sdf_dx sdf_dy sdf_dz sdf_dt "
        "vec_sdf_dx vec_sdf_dy vec_sdf_dz vec_sdf_dt "
        "vec_main_loss",
    )

    class DATA(NamedTuple):
        """One batch of inputs, target SDF values and sample weights."""

        X: jnp.ndarray  # [N]
        Y: jnp.ndarray  # [N]
        Z: jnp.ndarray  # [N]
        T: jnp.ndarray  # [N]
        P: jnp.ndarray  # [N, model_number]
        SDF: jnp.ndarray  # [N]
        WEIGHT: jnp.ndarray  # [N]

        def __add__(self, other):
            """Join two batches by concatenating every field along axis 0."""
            joined = {
                field: jnp.concatenate(
                    [getattr(self, field), getattr(other, field)], axis=0
                )
                for field in self._fields
            }
            return ApproximateSDFLipMLP2.DATA(**joined)

    def __init__(
        self,
        mlp_layers=(64, 64, 64, 64, 64, 64, 64, 64, 1),
        batch_size=262144,
        model_number=2,
        lip_alpha=0.000001,
    ):
        """Keep the hyper-parameters and prepare the sample batch.

        The sample batch is zero-filled except ``WEIGHT``, which is all ones.

        :param mlp_layers: output sizes passed to ``LipMLP``
        :param batch_size: number of rows in every array of the sample batch
        :param model_number: width of the ``P`` array in the sample batch
        :param lip_alpha: weight of ``net.get_lipschitz_loss()`` in the loss
        """
        self.mlp_layers = mlp_layers
        self.batch_size = batch_size
        self.model_number = model_number
        self.lip_alpha = lip_alpha

        sample = {
            field: jnp.zeros(self.batch_size) for field in ("X", "Y", "Z", "T", "SDF")
        }
        sample["P"] = jnp.zeros((self.batch_size, self.model_number))
        sample["WEIGHT"] = jnp.ones(self.batch_size)
        self._init_data = ApproximateSDFLipMLP2.DATA(**sample)

    def init_data(self) -> namedtuple:
        """Return the sample batch created in ``__init__``."""
        return self._init_data

    def init_and_functions(self, rng_key: KeyArray) -> (namedtuple, namedtuple):
        """Create initial parameters and the transformed task functions.

        :param rng_key: random key handed to the Haiku ``init`` function
        :return: ``(init_params, functions)``; ``functions`` is an
            ``ApproximateSDFLipMLP2.FUNC`` built from ``hk.multi_transform``
        """

        def constructor():
            # A single network instance is shared by the value, its
            # derivatives and the Lipschitz term of the loss.
            net = LipMLP(output_sizes=self.mlp_layers, activation=jnp.tanh)

            def batched(fn):
                # Map over axis 0 of all five inputs.
                return hk.vmap(fn, in_axes=(0, 0, 0, 0, 0), split_rng=False)

            def f_sdf(x, y, z, t, p):
                features = jnp.hstack([x, y, z, t, p])
                return jnp.squeeze(net(features))

            vec_f_sdf = batched(f_sdf)

            def vec_main_loss(X, Y, Z, T, P, SDF, WEIGHT):
                predict = vec_f_sdf(X, Y, Z, T, P)
                # Weighted mean of the squared error.
                weighted_mse = jnp.sum(WEIGHT * (predict - SDF) ** 2) / jnp.sum(WEIGHT)
                return weighted_mse + self.lip_alpha * net.get_lipschitz_loss()

            # Derivatives w.r.t. x, y, z and t (arguments 0..3), then batched.
            d_dx, d_dy, d_dz, d_dt = [hk.grad(f_sdf, argnums=i) for i in range(4)]
            vec_d_dx, vec_d_dy, vec_d_dz, vec_d_dt = [
                batched(g) for g in (d_dx, d_dy, d_dz, d_dt)
            ]

            def init(X, Y, Z, T, P, SDF, WEIGHT):
                return vec_main_loss(X, Y, Z, T, P, SDF, WEIGHT)

            exported = ApproximateSDFLipMLP2.FUNC(
                sdf=f_sdf,
                vec_sdf=vec_f_sdf,
                sdf_dx=d_dx,
                sdf_dy=d_dy,
                sdf_dz=d_dz,
                sdf_dt=d_dt,
                vec_sdf_dx=vec_d_dx,
                vec_sdf_dy=vec_d_dy,
                vec_sdf_dz=vec_d_dz,
                vec_sdf_dt=vec_d_dt,
                vec_main_loss=vec_main_loss,
            )
            return init, exported

        init_fn, transformed = hk.multi_transform(constructor)
        params = init_fn(rng_key, *self._init_data)
        return params, ApproximateSDFLipMLP2.FUNC(*transformed)
class SurfaceSegmentation(AbstractTrainableTask):
    """
    This is a trainable task for the problem of supervised surface segmentation.
    This class employs the fully-convolutional neural network.
    """

    FUNC = namedtuple("FUNC", ["nn", "main_loss", "metric_accuracy"])
    DATA = namedtuple("DATA", ["SDF_CUBE", "CLASS"])

    def __init__(
        self,
        spacing=(16, 16, 16),
        conv_kernel=32,
        conv_depth=4,
        num_classes=3,
        batch_size=128,
    ):
        """Store the hyper-parameters and build a zero-filled sample batch.

        :param spacing: three spatial sizes of one SDF tile
        :param conv_kernel: ``kernels_in_first_layer`` of ``DescConv``; the
            two hidden 1x1x1 convolutions use four times this many channels
        :param conv_depth: ``n_layers`` of ``DescConv``
        :param num_classes: number of output channels / one-hot classes
        :param batch_size: number of tiles in the sample batch
        """
        self.spacing = spacing
        self.conv_kernel = conv_kernel
        self.conv_depth = conv_depth
        self.num_classes = num_classes
        self.batch_size = batch_size
        self._init_data = SurfaceSegmentation.DATA(
            SDF_CUBE=jnp.zeros(
                (self.batch_size, spacing[0], spacing[1], spacing[2], 1)
            ),
            CLASS=jnp.zeros(self.batch_size),
        )

    def init_data(self) -> namedtuple:
        """Return the zero-filled sample batch built in ``__init__``."""
        return self._init_data

    def init_and_functions(self, rng_key: KeyArray) -> (namedtuple, namedtuple):
        """Initialise the network and return its transformed functions.

        :param rng_key: random key forwarded to the Haiku ``init`` function
        :return: ``(init_params, functions)`` where ``functions`` is a
            ``SurfaceSegmentation.FUNC`` of the functions produced by
            ``hk.multi_transform``
        """

        def constructor():
            descendants_convolution = DescConv(
                n_layers=self.conv_depth,
                kernels_in_first_layer=self.conv_kernel,
                kernel_shape=(2, 2, 2),
                stride=(2, 2, 2),
                activation=jax.nn.relu,
            )

            def raw_logits(input_):  # NDHWC
                """Run the network up to (and excluding) the softmax."""
                x = input_
                x = descendants_convolution(x)
                x = hk.Conv3D(
                    output_channels=self.conv_kernel * 4,
                    kernel_shape=(1, 1, 1),
                    padding="VALID",
                    name="mlp_0",
                )(x)
                x = jax.nn.relu(x)
                x = hk.Conv3D(
                    output_channels=self.conv_kernel * 4,
                    kernel_shape=(1, 1, 1),
                    padding="VALID",
                    name="mlp_1",
                )(x)
                x = jax.nn.relu(x)
                x = hk.Conv3D(
                    output_channels=self.num_classes,
                    kernel_shape=(1, 1, 1),
                    padding="VALID",
                    name="mlp_output_0",
                )(x)
                return x

            def nn(input_):  # NDHWC
                """Class probabilities: softmax over the last axis, then squeeze."""
                return jnp.squeeze(jax.nn.softmax(raw_logits(input_)))

            def softmax_cross_entropy(logits, labels):
                one_hot = jax.nn.one_hot(labels, self.num_classes)
                return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

            def metric_accuracy(sdf_tiles, labels):
                probabilities = nn(sdf_tiles)
                return jnp.mean(jnp.argmax(probabilities, -1) == labels)

            def main_loss(sdf_tiles, labels):
                # Feed unnormalised logits to the cross-entropy. The previous
                # code passed the softmax output of `nn`, so `log_softmax`
                # was applied on top of an existing softmax.
                logits = jnp.squeeze(raw_logits(sdf_tiles))
                return jnp.mean(softmax_cross_entropy(logits, labels))

            # Parameters are created by evaluating the loss on the sample batch.
            def init(sdf_tiles, labels):
                return main_loss(sdf_tiles, labels)

            return init, SurfaceSegmentation.FUNC(
                nn=nn, main_loss=main_loss, metric_accuracy=metric_accuracy
            )

        init, function_tuple = hk.multi_transform(constructor)
        functions = SurfaceSegmentation.FUNC(*function_tuple)
        init_params = init(rng_key, *tuple(self._init_data))
        return init_params, functions
class Eikonal3D(AbstractTrainableTask):
    """
    This is a solver for the Eikonal equation.
    """

    class FUNC(NamedTuple):
        """Functions returned by :meth:`Eikonal3D.init_and_functions`.

        All entries are the vmapped (batched) versions.
        """

        nn: Callable[[float, float, float], float]  # [N]
        nn_dx: Callable[[float, float, float], float]  # [N]
        nn_dy: Callable[[float, float, float], float]  # [N]
        nn_dz: Callable[[float, float, float], float]  # [N]
        main_loss: Callable[[float, float, float], float]  # [N]

    class DATA(NamedTuple):
        """One batch of sample points."""

        # F(X,Y,Z) = U
        X: jnp.ndarray  # [N]
        Y: jnp.ndarray  # [N]
        Z: jnp.ndarray  # [N]

    def __init__(
        self,
        fun_sdf_domain: Callable[[float, float, float], float],
        fun_sdf_start: Callable[[float, float, float], float],
        mlp_layers=(
            16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
            16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
            1,
        ),
        lambda_grad=0.1,
        lambda_non_negativity=0.1,
        batch_size=100,
    ):
        """Store the problem definition and build a zero-filled sample batch.

        :param fun_sdf_domain: function of (X, Y, Z); points where it is
            negative get conductivity 1, all other points 0
        :param fun_sdf_start: function of (X, Y, Z); points where it is
            negative form the region in which the solution is pulled to 0
        :param mlp_layers: output sizes passed to ``hk.nets.MLP``
        :param lambda_grad: weight of the gradient (Eikonal) term of the loss
        :param lambda_non_negativity: weight of the penalty on negative values
        :param batch_size: length of every array in the sample batch
        """
        self.fun_sdf_domain = fun_sdf_domain
        self.fun_sdf_start = fun_sdf_start
        self.mlp_layers = mlp_layers
        self.batch_size = batch_size
        self._init_data = Eikonal3D.DATA(
            X=jnp.zeros(self.batch_size),
            Y=jnp.zeros(self.batch_size),
            Z=jnp.zeros(self.batch_size),
        )
        self.LAMBDA_GRAD = lambda_grad
        self.LAMBDA_NON_NEGATIVITY = lambda_non_negativity

    def init_data(self) -> namedtuple:
        """Return the zero-filled sample batch built in ``__init__``."""
        return self._init_data

    def init_and_functions(self, rng_key: KeyArray) -> (namedtuple, namedtuple):
        """Initialise the network and return its transformed functions.

        :param rng_key: random key forwarded to the Haiku ``init`` function
        :return: ``(init_params, functions)`` where ``functions`` is an
            :class:`Eikonal3D.FUNC` of the functions produced by
            ``hk.multi_transform``
        """
        mlp_layers = self.mlp_layers
        fun_sdf_domain = self.fun_sdf_domain
        fun_sdf_start = self.fun_sdf_start
        LAMBDA_NON_NEGATIVITY = self.LAMBDA_NON_NEGATIVITY
        LAMBDA_GRAD = self.LAMBDA_GRAD

        def constructor():
            # Build the MLP exactly once. Constructing it inside `NN` gave
            # every call within one transform a differently named module
            # (`mlp`, `mlp_1`, ...), so the loss differentiated networks
            # other than the one that predicts the solution.
            net = hk.nets.MLP(output_sizes=mlp_layers, activation=jnp.tanh)

            def NN(x, y, z):
                vec = jnp.hstack([x, y, z])
                return jnp.squeeze(net(vec))

            vec_NN = hk.vmap(NN, in_axes=(0, 0, 0), split_rng=False)

            grad_x = hk.grad(NN, argnums=0)
            grad_y = hk.grad(NN, argnums=1)
            grad_z = hk.grad(NN, argnums=2)

            vec_grad_x = hk.vmap(grad_x, in_axes=(0, 0, 0), split_rng=False)
            vec_grad_y = hk.vmap(grad_y, in_axes=(0, 0, 0), split_rng=False)
            vec_grad_z = hk.vmap(grad_z, in_axes=(0, 0, 0), split_rng=False)

            def vec_conductivity(X, Y, Z):
                # 1.0 where `fun_sdf_domain` is negative, 0.0 elsewhere.
                return 1.0 * (fun_sdf_domain(X, Y, Z) < 0)

            def vec_eikonal(x, y, z):
                # Squared norm of the gradient of the network output.
                return (
                    vec_grad_x(x, y, z) ** 2
                    + vec_grad_y(x, y, z) ** 2
                    + vec_grad_z(x, y, z) ** 2
                )

            def vec_main_loss(X, Y, Z):
                Omega0 = fun_sdf_start(X, Y, Z) < 0  # start-region mask
                Omega1 = 1.0 - Omega0  # complement of the start region
                pred_u = vec_NN(X, Y, Z)
                loss = (
                    # Inside the start region the solution should be 0.
                    jnp.mean(Omega0 * (0.0 - pred_u) ** 2)
                    # Elsewhere the squared gradient norm should match the
                    # conductivity indicator.
                    + LAMBDA_GRAD
                    * jnp.mean(
                        Omega1
                        * (vec_eikonal(X, Y, Z) - vec_conductivity(X, Y, Z)) ** 2
                    )
                    # Penalise negative predictions.
                    + LAMBDA_NON_NEGATIVITY * jnp.mean(jax.nn.relu(-pred_u) ** 2)
                )
                return loss

            # Parameters are created by evaluating the loss on the sample batch.
            def init(X, Y, Z):
                return vec_main_loss(X, Y, Z)

            return init, Eikonal3D.FUNC(
                nn=vec_NN,
                nn_dx=vec_grad_x,
                nn_dy=vec_grad_y,
                nn_dz=vec_grad_z,
                main_loss=vec_main_loss,
            )

        init, function_tuple = hk.multi_transform(constructor)
        functions = Eikonal3D.FUNC(*function_tuple)
        init_params = init(rng_key, *tuple(self._init_data))
        return init_params, functions