# Source code for nndt.space2.method_set_train_task

import pickle
from typing import Union

import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm

import nndt
from nndt import SimpleSDF
from nndt.space2.abstracts import AbstractBBoxNode, node_method
from nndt.space2.filesource import FileSource
from nndt.space2.implicit_representation import ImpRepr
from nndt.space2.method_set import MethodSetNode
from nndt.space2.transformation import AbstractTransformation


class TrainTaskSetNode(MethodSetNode):
    """Method set named ``"train_task"`` that fits a ``SimpleSDF`` model to
    signed-distance values sampled from the parent node.

    Args:
        object_3d: node whose ``bbox`` is stored in the saved checkpoint.
        sdt: signed-distance source; must be an ``ImpRepr`` or an object whose
            ``loader_type`` is ``"sdt"`` (checked with ``assert``).
        transform: transformation whose int/float/str attributes are stored in
            the saved checkpoint.
        parent: parent node in the space tree.
    """

    def __init__(
        self,
        object_3d: AbstractBBoxNode,
        sdt: Union[FileSource, ImpRepr],
        transform: AbstractTransformation,
        parent: AbstractBBoxNode = None,
    ):
        super(TrainTaskSetNode, self).__init__("train_task", parent=parent)
        self.object_3d = object_3d
        assert isinstance(sdt, ImpRepr) or sdt.loader_type == "sdt"
        self.sdt = sdt
        self.transform = transform

    def load_batch(self, spacing):
        """Sample the parent's grid and evaluate the parent's SDT on it.

        Args:
            spacing: grid resolution forwarded to
                ``self.parent.sampling.sampling_grid``.

        Returns:
            ``SimpleSDF.DATA`` holding the flattened ``X``, ``Y``, ``Z``
            coordinate columns and the matching squeezed ``SDF`` values.
        """
        # TODO! This place is dangerous. Sampling initializes after this class.
        # This call does not specialize in a specific sampling node or SDT!
        # I hope I will find a proper way to write this in the future.
        # NOTE(review): reads self.parent.sdt, not self.sdt stored in __init__;
        # presumably self.sdt may be a raw FileSource without surface_xyz2sdt.
        xyz = self.parent.sampling.sampling_grid(spacing=spacing)
        xyz_flat = xyz.reshape((-1, 3))
        sdf_flat = jnp.squeeze(self.parent.sdt.surface_xyz2sdt(xyz_flat))
        xyz_flat = jnp.array(xyz_flat)
        data = SimpleSDF.DATA(
            X=xyz_flat[:, 0], Y=xyz_flat[:, 1], Z=xyz_flat[:, 2], SDF=sdf_flat
        )
        return data

    @node_method("train_task_sdt2sdf(filename, **kwargs)")
    def train_task_sdt2sdf(
        self,
        filename,
        spacing=(64, 64, 64),
        width=32,
        depth=8,
        learning_rate=0.006,
        epochs=10001,
    ):
        """Train a ``SimpleSDF`` MLP on the parent's SDT and checkpoint it.

        A single batch is produced by ``load_batch(spacing)``. The model is
        then optimised with Adam on that same batch for ``epochs`` steps.
        Every time the loss strictly improves, ``filename`` is overwritten
        with a pickle containing the keys ``version``, ``repr``, ``bbox``,
        ``trainable_task``, ``history_loss`` and ``params``.

        Args:
            filename: path of the pickle checkpoint to (over)write.
            spacing: 3-element grid resolution; its product is the batch size.
            width: number of units in each hidden layer.
            depth: number of hidden layers (a final 1-unit layer is appended).
            learning_rate: Adam learning rate.
            epochs: number of optimisation steps.

        Raises:
            NotImplementedError: if the parent does not yet provide
                ``sampling.sampling_grid`` or ``sdt.surface_xyz2sdt``.
        """
        if not (
            hasattr(self.parent, "sampling")
            and hasattr(self.parent.sampling, "sampling_grid")
        ):
            raise NotImplementedError(
                "This error is really bad. Initialization order was broken!"
            )
        if not (
            hasattr(self.parent, "sdt")
            and hasattr(self.parent.sdt, "surface_xyz2sdt")
        ):
            raise NotImplementedError(
                "This error is really bad. Initialization order was broken!"
            )

        kwargs = {
            "mlp_layers": tuple([width] * depth + [1]),
            "batch_size": spacing[0] * spacing[1] * spacing[2],
        }
        from nndt.trainable_task import SimpleSDF

        task = SimpleSDF(**kwargs)
        rng = jax.random.PRNGKey(42)
        params, F = task.init_and_functions(rng)
        opt = optax.adam(learning_rate)
        opt_state = opt.init(params)

        D1 = self.load_batch(spacing)

        @jax.jit
        def train_step(params, rng, opt_state, D1):
            loss, grads = jax.value_and_grad(F.vec_main_loss)(
                params, rng, *tuple(D1)
            )
            updates, opt_state = opt.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
            # rng is returned unchanged: every step reuses the same key.
            return loss, params, rng, opt_state

        # Start from +inf so the first finite loss is always checkpointed,
        # no matter how large it is (a finite sentinel could suppress saving).
        min_loss = float("inf")
        loss_history = []
        # Context manager closes the progress bar even if a step raises.
        with tqdm(range(epochs)) as pbar:
            for _ in pbar:
                loss, params, rng, opt_state = train_step(
                    params, rng, opt_state, D1
                )
                loss_value = float(loss)
                loss_history.append(loss_value)
                if loss_value < min_loss:
                    with open(filename, "wb") as fl:
                        pickle.dump(
                            {
                                "version": nndt.__version__,
                                # NOTE(review): this is a *set* of (key, value)
                                # tuples, not a dict; kept as-is because
                                # existing checkpoint readers may rely on it.
                                "repr": {
                                    (k, v)
                                    for k, v in self.transform.__dict__.items()
                                    if isinstance(v, (int, float, str))
                                },
                                "bbox": self.object_3d.bbox,
                                "trainable_task": kwargs,
                                "history_loss": loss_history,
                                "params": params,
                            },
                            fl,
                        )
                    min_loss = loss_value
                # Updated after the comparison so the bar shows the current best.
                pbar.set_description(f"min_loss = {min_loss:.06f}")