# Source code for nndt.space2.space_preloader

from anytree import PostOrderIter, PreOrderIter
from tqdm import tqdm

from nndt.space2.abstracts import DICT_NODETYPE_PRIORITY, AbstractBBoxNode
from nndt.space2.filesource import FileSource
from nndt.space2.group import Group
from nndt.space2.implicit_representation import IR1SDF, ImpRepr
from nndt.space2.method_set import (
    ColorMethodSetNode,
    MeshObjMethodSetNode,
    SamplingMethodSetNode,
    SDTMethodSetNode,
)
from nndt.space2.method_set_train_task import TrainTaskSetNode
from nndt.space2.object3D import Object3D
from nndt.space2.space import Space
from nndt.space2.transformation import AbstractTransformation
from nndt.space2.utils import pad_bbox, update_bbox


def _update_bbox_bottom_to_up(node):
    for child in node.children:
        if isinstance(child, AbstractBBoxNode):
            node.bbox = update_bbox(node.bbox, child.bbox)


class DefaultPreloader:
    """Initialise every node of a ``Space`` tree in place.

    ``preload`` runs in stages:

    1. attach and run a loader for every ``FileSource``;
    2. build the transformation and method-set nodes of every ``Object3D``;
    3. merge child bounding boxes into every ``Group`` and into the ``Space``;
    4. sort the children of every node by type priority, then by name.

    Parameters
    ----------
    mode : str
        Which transformation to create for objects that have an ``"sdt"`` or
        ``"implicit_ir1"`` source: ``"identity"``, ``"shift_and_scale"`` or
        ``"to_cube"``. Any other value raises ``NotImplementedError`` when the
        first such object is initialised. (The ``ps``/``ns`` prefixes used
        below presumably mean physical / normalized space.)
    scale : number
        Passed as ``scale_ps2ns`` to ``ShiftAndScaleTransform``; only used
        when ``mode == "shift_and_scale"``.
    keep_in_memory : bool
        When False, loader data is unloaded right after each file source has
        been read, and again after each object has been initialised.
    ps_padding : tuple
        Padding passed to ``pad_bbox`` for every file-source bbox.
    ns_padding : tuple
        Padding passed to ``pad_bbox`` for an object's bbox after it has been
        merged with its transformation's bbox.
    """

    def __init__(
        self,
        mode="identity",
        scale=50,
        keep_in_memory=True,
        ps_padding=(0.0, 0.0, 0.0),
        ns_padding=(0.0, 0.0, 0.0),
    ):
        self.mode = mode
        self.scale = scale
        self.keep_in_memory = keep_in_memory
        self.ps_padding = ps_padding
        self.ns_padding = ns_padding

    def preload(self, space: Space, verbose=True):
        """Initialise all nodes of ``space`` in place.

        Parameters
        ----------
        space : Space
            Root of the tree to initialise.
        verbose : bool
            Wrap the Object3D stage in a ``tqdm`` progress bar.

        Raises
        ------
        NotImplementedError
            For an unknown loader type or an unsupported ``mode``.
        """
        # Stage 1. Initialization of FileSources
        for node in PostOrderIter(space):
            if isinstance(node, FileSource):
                self._init_FileSource(node)

        # Stage 2. Initialization of Object3D
        iter_tmp = PostOrderIter(space)
        if verbose:
            iter_tmp = tqdm(iter_tmp)
        for node in iter_tmp:
            if isinstance(node, Object3D):
                self._init_Object3D(node)
                # Unload data from memory if it is necessary. _init_Object3D
                # only hands this object's own children to the method-set
                # constructors, so unloading this subtree is enough; rescanning
                # the whole space after every object made this stage quadratic.
                if not self.keep_in_memory:
                    self._unload_file_sources(node)

        # Stage 3. Initialization of Group
        for node in PostOrderIter(space):
            if isinstance(node, Group):
                self._init_Group(node)

        self._init_Space(space)
        self._keep_alphabetical_order_of_nodes(space)

    @staticmethod
    def _unload_file_sources(root):
        """Call ``unload_data()`` on the loader of every FileSource under ``root``."""
        for node in PreOrderIter(root):
            if isinstance(node, FileSource):
                node._loader.unload_data()

    def _keep_alphabetical_order_of_nodes(self, space: Space):
        """Sort the children of every node by type priority, then by name.

        Children whose ``_nodetype`` maps to a higher value in
        ``DICT_NODETYPE_PRIORITY`` come first; ties are ordered by ``name``.
        """
        for node in PreOrderIter(space):
            # Sorts anytree's internal child list in place through the
            # name-mangled NodeMixin attribute. Reassigning ``node.children``
            # would instead go through anytree's setter, which detaches and
            # re-attaches every child.
            node._NodeMixin__children_or_empty.sort(
                key=lambda d: (100 - DICT_NODETYPE_PRIORITY[d._nodetype], d.name),
            )

    def _add_sampling_node(self, node: AbstractBBoxNode):
        """Attach a ``SamplingMethodSetNode`` child to ``node``."""
        SamplingMethodSetNode(parent=node)

    def _init_Space(self, node: Space):
        """Merge child bboxes into the space, add sampling, then call ``node.init()``."""
        _update_bbox_bottom_to_up(node)
        self._add_sampling_node(node)
        node.init()

    def _init_Group(self, node: Group):
        """Merge child bboxes into the group and attach a sampling node."""
        _update_bbox_bottom_to_up(node)
        self._add_sampling_node(node)

    def _process_ir1_source(self, node: Object3D):
        """Attach an ``ImpRepr`` built from the first ``implicit_ir1`` source.

        Returns
        -------
        ImpRepr or None
            The created implicit representation, or ``None`` when ``node`` has
            no FileSource child with ``loader_type == "implicit_ir1"``.
        """
        ir1_array_list = [
            source
            for source in node.children
            if isinstance(source, FileSource) and source.loader_type == "implicit_ir1"
        ]
        if not ir1_array_list:
            return None

        loader = ir1_array_list[0]._loader
        params = loader.params
        bbox = loader.bbox
        functions = loader.functions
        ir1_sdf = IR1SDF(functions, params, bbox)
        return ImpRepr("ir1", ir1_sdf, parent=node)

    def _process_sdt_source(self, node: Object3D):
        """Create the transformation of ``node`` from its first sdt-like source.

        Only FileSource children with ``loader_type`` ``"sdt"`` or
        ``"implicit_ir1"`` are considered. When one exists, a transformation
        chosen by ``self.mode`` is attached to ``node``. Then ``node.bbox`` is
        extended with the transformation's bbox and padded by
        ``self.ns_padding``.

        Returns
        -------
        AbstractTransformation or None
            The created transformation, or ``None`` when there is no such source.

        Raises
        ------
        NotImplementedError
            If ``self.mode`` is not one of the supported modes.
        """
        sdt_array_list = [
            source
            for source in node.children
            if isinstance(source, FileSource)
            and source.loader_type in ("sdt", "implicit_ir1")
        ]

        transform = None
        if sdt_array_list:
            from nndt.space2.transformation import (
                IdentityTransform,
                ShiftAndScaleTransform,
                ToNormalCubeTransform,
            )

            ps_bbox = sdt_array_list[0].bbox
            if self.mode == "identity":
                transform = IdentityTransform(ps_bbox=ps_bbox, parent=node)
            elif self.mode == "shift_and_scale":
                # Midpoint of the source bbox along each of the three axes.
                ps_center = (
                    (ps_bbox[0][0] + ps_bbox[1][0]) / 2.0,
                    (ps_bbox[0][1] + ps_bbox[1][1]) / 2.0,
                    (ps_bbox[0][2] + ps_bbox[1][2]) / 2.0,
                )
                transform = ShiftAndScaleTransform(
                    ps_bbox=ps_bbox,
                    ps_center=ps_center,
                    ns_center=(0.0, 0.0, 0.0),
                    scale_ps2ns=self.scale,
                    parent=node,
                )
            elif self.mode == "to_cube":
                transform = ToNormalCubeTransform(ps_bbox=ps_bbox, parent=node)
            else:
                raise NotImplementedError(
                    f"{self.mode} is not supported for initialization"
                )

            node.bbox = update_bbox(node.bbox, transform.bbox)
            node.bbox = pad_bbox(node.bbox, self.ns_padding)

        return transform

    def _process_sdt_source2(self, node: Object3D, transform):
        """Attach SDT and train-task method sets when a transformation exists.

        Uses the first child that is either an ``"sdt"`` FileSource or an
        ``ImpRepr``; nothing is attached when there is no such child or when
        ``transform`` is ``None``.
        """
        sdt_array_list = [
            source
            for source in node.children
            if (isinstance(source, FileSource) and source.loader_type == "sdt")
            or isinstance(source, ImpRepr)
        ]
        if sdt_array_list and transform is not None:
            sdt = sdt_array_list[0]
            SDTMethodSetNode(node, sdt, transform, parent=node)
            TrainTaskSetNode(node, sdt, transform, parent=node)

    def _process_mesh_obj_source(
        self, node: Object3D, transform: AbstractTransformation
    ):
        """Attach mesh (and, if available, color) method sets.

        Uses the first ``"mesh_obj"`` FileSource child. A
        ``ColorMethodSetNode`` is added only when that loader's ``rgba`` is not
        ``None``. Nothing is attached when ``transform`` is ``None``.
        """
        mesh_obj_array_list = [
            source
            for source in node.children
            if isinstance(source, FileSource) and source.loader_type == "mesh_obj"
        ]
        if mesh_obj_array_list and transform is not None:
            mesh = mesh_obj_array_list[0]
            MeshObjMethodSetNode(node, mesh, transform, parent=node)
            if mesh._loader.rgba is not None:
                ColorMethodSetNode(node, mesh, transform, parent=node)

    def _init_Object3D(self, node: Object3D):
        """Build the implicit representation, transformation and method sets of ``node``."""
        # Must run first: it may attach the ImpRepr child that
        # _process_sdt_source2 picks up. Its return value is not needed here.
        self._process_ir1_source(node)
        transform = self._process_sdt_source(node)
        self._process_sdt_source2(node, transform)
        self._process_mesh_obj_source(node, transform)
        self._add_sampling_node(node)

    def _init_FileSource(self, node: FileSource):
        """Create and run the loader of ``node`` and compute its padded bbox.

        The loader class is looked up in ``DICT_LOADERTYPE_CLASS`` by
        ``node.loader_type``. The data is loaded, the loader's bbox is padded
        with ``self.ps_padding``, and a sampling node is attached. The data is
        unloaded again when ``keep_in_memory`` is False.

        Raises
        ------
        NotImplementedError
            If ``node.loader_type`` has no registered loader.
        """
        from nndt.space2.loader import DICT_LOADERTYPE_CLASS

        if node.loader_type not in DICT_LOADERTYPE_CLASS:
            raise NotImplementedError(f"{node.loader_type} is unknown loader")

        node._loader = DICT_LOADERTYPE_CLASS[node.loader_type](filepath=node.filepath)
        node._loader.load_data()
        node.bbox = node._loader.calc_bbox()
        node.bbox = pad_bbox(node.bbox, self.ps_padding)
        self._add_sampling_node(node)

        if not self.keep_in_memory:
            node._loader.unload_data()