import pickle
import warnings
from typing import Optional
import jax
import jax.numpy as jnp
import vtk
from packaging import version
from pykdtree.kdtree import KDTree
from vtkmodules.util.numpy_support import vtk_to_numpy
import nndt
from nndt.space2.abstracts import AbstractLoader
from nndt.trainable_task import SimpleSDF
[docs]class EmptyLoader(AbstractLoader):
"""
Dummy loader, that does nothing.
Args:
filepath (str): path to the file
"""
def __init__(self, filepath: str):
self.filepath = filepath
self.is_load = False
[docs] def load_data(self):
self.is_load = True
[docs] def unload_data(self):
self.is_load = False
[docs] def is_load(self) -> bool:
return self.is_load
[docs]class TXTLoader(AbstractLoader):
"""
Load txt file
Args:
filepath (str): path to the file
"""
def __init__(self, filepath: str):
self.filepath = filepath
self.is_load = False
self._text = None
@property
def text(self):
"""Return text from the file.
Returns:
str: File content.
"""
if not self.is_load:
self.load_data()
return self._text
[docs] def load_data(self):
with open(self.filepath, "r") as fl:
self._text = fl.read()
self.is_load = True
[docs] def unload_data(self):
self._text = None
self.is_load = False
[docs] def is_load(self) -> bool:
return self.is_load
def _load_colors_from_obj(filepath):
red = []
green = []
blue = []
alpha = []
with open(filepath, "r") as fl:
for line in fl:
if "v" in line:
tokens = line.split(" ")
if ("v" == tokens[0]) and (len(tokens) >= 7):
red.append(float(tokens[4].replace(",", ".")))
green.append(float(tokens[5].replace(",", ".")))
blue.append(float(tokens[6].replace(",", ".")))
alpha.append(1.0)
red = jnp.array(red)
green = jnp.array(green)
blue = jnp.array(blue)
alpha = jnp.array(alpha)
return jnp.column_stack([red, green, blue, alpha])
def _load_colors_from_ply(filepath):
red = []
green = []
blue = []
alpha = []
is_read_mode = False
with open(filepath, "r") as fl:
for line in fl:
if "end_header" in line:
is_read_mode = True
if is_read_mode:
tokens = line.split(" ")
if len(tokens) >= 10:
red.append(float(tokens[6].replace(",", ".")))
green.append(float(tokens[7].replace(",", ".")))
blue.append(float(tokens[8].replace(",", ".")))
alpha.append(float(tokens[9].replace(",", ".")))
red = jnp.array(red) / 255
green = jnp.array(green) / 255
blue = jnp.array(blue) / 255
alpha = jnp.array(alpha) / 255
return jnp.concatenate([red, green, blue, alpha], axis=1)
[docs]class MeshObjLoader(AbstractLoader):
"""
Load .obj file with mesh.
Args:
filepath (str): path to the file
"""
def __init__(self, filepath: str):
self.filepath = filepath
self.is_load = False
self._mesh = None
self._points = None
self._kdtree = None
self._rgba = None
[docs] def calc_bbox(self) -> ((float, float, float), (float, float, float)):
"""Return the boundary box size of a 3D object.
Returns:
(tuple), (tuple): boundary box: (Xmin, Xmax, Ymin), (Ymax, Zmin, Zmax)
"""
Xmin, Xmax, Ymin, Ymax, Zmin, Zmax = self.mesh.GetBounds()
return (Xmin, Ymin, Zmin), (Xmax, Ymax, Zmax)
@property
def mesh(self) -> vtk.vtkPolyData:
"""Return the mesh of a 3D object.
Returns:
vtk.vtkPolyData: Mesh.
"""
if not self.is_load:
self.load_data()
return self._mesh
@property
def points(self) -> jnp.ndarray:
"""Return points of a 3D object.
Returns:
jnp.ndarray: Points of the object.
"""
if not self.is_load:
self.load_data()
return self._points
@property
def kdtree(self) -> KDTree:
"""Return KDTree of a 3D object.
Returns:
KDTree: KDTree.
"""
if not self.is_load:
self.load_data()
return self._kdtree
@property
def rgba(self) -> Optional[jnp.ndarray]:
"""Return colors of all mesh vertex
Returns:
Optional[jnp.ndarray]: RGBA color
"""
if not self.is_load:
self.load_data()
return self._rgba
[docs] def load_data(self):
"""Load data from the file"""
reader = vtk.vtkOBJReader()
reader.SetFileName(self.filepath)
reader.Update()
self._mesh = reader.GetOutput()
onp_points = vtk_to_numpy(self._mesh.GetPoints().GetData())
self._points = jnp.array(onp_points)
self._kdtree = KDTree(onp_points)
try:
self._rgba = _load_colors_from_obj(self.filepath)
except Exception:
warnings.warn("Colors cannot be loaded from mesh.")
self._rgba = None
self.is_load = True
[docs] def unload_data(self):
"""Clear data"""
self._mesh = None
self.is_load = False
[docs] def is_load(self) -> bool:
"""Return is file loaded.
Returns:
bool: File load status.
"""
return self.is_load
[docs]class SDTLoader(AbstractLoader):
"""
Load signed distance tensor file.
Args:
filepath (str): path to the file
"""
def __init__(self, filepath: str):
self.filepath = filepath
self.is_load = False
self._sdt = None
self._sdt_threshold_level = 0.0
[docs] def calc_bbox(self) -> ((float, float, float), (float, float, float)):
"""Return the boundary box size of the object.
Returns:
(tuple), (tuple): boundary box: (Xmin, Xmax, Ymin), (Ymax, Zmin, Zmax)
"""
mask_arr = self.sdt <= self._sdt_threshold_level
Xmin = float(jnp.argmax(jnp.any(mask_arr, axis=(1, 2))))
Ymin = float(jnp.argmax(jnp.any(mask_arr, axis=(0, 2))))
Zmin = float(jnp.argmax(jnp.any(mask_arr, axis=(0, 1))))
Xmax = float(
self.sdt.shape[0] - jnp.argmax(jnp.any(mask_arr, axis=(1, 2))[::-1])
)
Ymax = float(
self.sdt.shape[1] - jnp.argmax(jnp.any(mask_arr, axis=(0, 2))[::-1])
)
Zmax = float(
self.sdt.shape[2] - jnp.argmax(jnp.any(mask_arr, axis=(0, 1))[::-1])
)
return (Xmin, Ymin, Zmin), (Xmax, Ymax, Zmax)
@property
def sdt(self):
"""Return full STD data.
Returns:
jnp.ndarray: STD
"""
if not self.is_load:
self.load_data()
return self._sdt
[docs] def load_data(self):
"""Load data from the file"""
self._sdt = jnp.load(self.filepath)
self.is_load = True
[docs] def request(self, ps_xyz: jnp.ndarray) -> jnp.ndarray:
"""
Calculate the distance from points to the surface
:param ps_xyz: points in the normalized space
:return: distances in SDT form
"""
assert ps_xyz.ndim >= 1
assert ps_xyz.shape[-1] == 3
if ps_xyz.ndim == 1:
p_array_ = ps_xyz[jnp.newaxis, :]
else:
p_array_ = ps_xyz
p_array_ = p_array_.reshape((-1, 3))
req_x = p_array_[:, 0]
req_y = p_array_[:, 1]
req_z = p_array_[:, 2]
x = jnp.rint(jnp.clip(req_x, 0, self.sdt.shape[0] - 1)).astype(int)
y = jnp.rint(jnp.clip(req_y, 0, self.sdt.shape[1] - 1)).astype(int)
z = jnp.rint(jnp.clip(req_z, 0, self.sdt.shape[2] - 1)).astype(int)
adv_x = req_x - x
adv_y = req_y - y
adv_z = req_z - z
result = self.sdt[x, y, z]
result = result + jnp.sqrt(adv_x**2 + adv_y**2 + adv_z**2)
ret_shape = list(ps_xyz.shape)
ret_shape[-1] = 1
result = result.reshape(ret_shape)
return result
[docs] def unload_data(self):
self._sdt = None
self.is_load = False
[docs] def is_load(self) -> bool:
return self.is_load
[docs]class IR1Loader(AbstractLoader):
def __init__(self, filepath: str):
self.filepath = filepath
self.is_load = False
self.json_ = None
self.functions_ = None
self.params_ = None
self.bbox_ = ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0))
@property
def json(self):
if not self.is_load:
self.load_data()
return self.json_
@property
def functions(self):
if not self.is_load:
self.load_data()
return self.functions_
@property
def params(self):
if not self.is_load:
self.load_data()
return self.params_
@property
def bbox(self):
if not self.is_load:
self.load_data()
return self.bbox_
[docs] def load_data(self):
with open(self.filepath, "rb") as input_file:
self.json_ = pickle.load(input_file)
version_ = self.json_["version"]
trainable_task_ = self.json_["trainable_task"]
repr_ = self.json_["repr"]
history_loss_ = self.json_["history_loss"]
params_ = self.json_["params"]
bbox_ = self.json_["bbox"]
if version.parse(nndt.__version__) < version.parse(version_):
warnings.warn(
"Loaded neural network was created on earlier version of NNDT!"
)
task = SimpleSDF(**trainable_task_)
rng = jax.random.PRNGKey(42)
_, self.F = task.init_and_functions(rng)
self.functions_ = self.F
self.params_ = params_
self.bbox_ = bbox_
self.is_load = True
[docs] def calc_bbox(self) -> ((float, float, float), (float, float, float)):
return self.bbox_
[docs] def unload_data(self):
self.json_ = None
self.functions_ = None
self.params_ = None
self.bbox_ = None
[docs] def is_load(self) -> bool:
"""Return is file loaded.
Returns:
bool: File load status.
"""
return self.is_load
DICT_LOADERTYPE_CLASS = {
"txt": TXTLoader,
"sdt": SDTLoader,
"mesh_obj": MeshObjLoader,
"implicit_ir1": IR1Loader,
"undefined": EmptyLoader,
}
DICT_CLASS_LOADERTYPE = {(v, k) for k, v in DICT_LOADERTYPE_CLASS.items()}