Source code for nndt.space2.implicit_representation

import jax
from colorama import Fore

from nndt.primitive_sdf import AbstractSDF
from nndt.space2.abstracts import AbstractBBoxNode, IterAccessMixin, node_method
from nndt.trainable_task import SimpleSDF


class ImpRepr(AbstractBBoxNode, IterAccessMixin):
    """Node of type ``"IR"`` that wraps an ``AbstractSDF`` object.

    The node's bounding box is taken from ``abstract_sdf.bbox``. The accessor
    methods below hand back attributes of the wrapped object unchanged.
    """

    def __init__(self, name: str, abstract_sdf: AbstractSDF, parent=None):
        """Build the node around ``abstract_sdf``.

        :param name: node name, forwarded to the base-class constructor.
        :param abstract_sdf: object stored on the node; its ``bbox`` attribute
            is passed to the base class as the node's bounding box.
        :param parent: optional parent, forwarded to the base-class constructor.
        """
        super().__init__(
            name,
            parent=parent,
            bbox=abstract_sdf.bbox,
            _print_color=Fore.MAGENTA,
            _nodetype="IR",
        )
        # TODO: This is a place for improvement.
        # The same object is also kept under the name ``_loader`` for
        # compatibility with SDTMethodSetNode.
        self.abstract_sdf = self._loader = abstract_sdf

    @node_method("purefun_sdf()")
    def purefun_sdf(self):
        """Return the ``fun`` attribute of the wrapped SDF object."""
        return self.abstract_sdf.fun

    @node_method("purefun_vec_sdf()")
    def purefun_vec_sdf(self):
        """Return the ``vec_fun`` attribute of the wrapped SDF object."""
        return self.abstract_sdf.vec_fun

    @node_method("purefun_vec_sdf_dx()")
    def purefun_vec_sdf_dx(self):
        """Return the ``vec_fun_dx`` attribute of the wrapped SDF object."""
        return self.abstract_sdf.vec_fun_dx

    @node_method("purefun_vec_sdf_dy()")
    def purefun_vec_sdf_dy(self):
        """Return the ``vec_fun_dy`` attribute of the wrapped SDF object."""
        return self.abstract_sdf.vec_fun_dy

    @node_method("purefun_vec_sdf_dz()")
    def purefun_vec_sdf_dz(self):
        """Return the ``vec_fun_dz`` attribute of the wrapped SDF object."""
        return self.abstract_sdf.vec_fun_dz
[docs]class IR1SDF(AbstractSDF): def __init__(self, func: SimpleSDF.FUNC, params, bbox): self.func = func self.params = params self._bbox = bbox key = jax.random.PRNGKey(42) self._fun = lambda x, y, z: self.func.sdf(params, key, x, y, z) self._vec_fun = lambda x, y, z: self.func.vec_sdf(params, key, x, y, z) self._vec_fun_x = lambda x, y, z: self.func.vec_sdf_dx(params, key, x, y, z) self._vec_fun_y = lambda x, y, z: self.func.vec_sdf_dy(params, key, x, y, z) self._vec_fun_z = lambda x, y, z: self.func.vec_sdf_dz(params, key, x, y, z) @property def bbox(self) -> ((float, float, float), (float, float, float)): return self._bbox def _get_fun(self): return self._fun