from abc import abstractmethod
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
def fun2vec_and_grad(prim):
    """
    Build batched versions of a scalar 3-D function and of its partial derivatives.

    :param prim: scalar function ``prim(x, y, z)``; it is differentiated with
        ``jax.grad``, so it must return a scalar.
    :return: 4-tuple ``(f, df/dx, df/dy, df/dz)`` in which each entry is
        ``prim`` or one of its partials wrapped with ``jax.vmap`` (mapped over
        axis 0 of every argument).
    """
    # argnums selects which positional argument (0 -> x, 1 -> y, 2 -> z)
    # jax.grad differentiates with respect to.
    partials = tuple(jax.grad(prim, argnums=axis) for axis in range(3))
    return tuple(jax.vmap(fn) for fn in (prim,) + partials)
class AbstractSDF:
    """
    Base class for implicitly defined 3-D geometry.

    Subclasses implement ``_get_fun`` (returning a scalar function
    ``f(x, y, z)``) and the ``bbox`` property. The constructor calls
    ``_get_fun`` once and caches a ``jax.vmap``-batched version of the
    function together with its three batched partial derivatives.
    """

    # NOTE(review): this class does not derive from ``abc.ABC`` (it has no
    # ABCMeta metaclass), so the ``@abstractmethod`` markers below are
    # documentation only; they are not enforced when a subclass is created.

    def __init__(self):
        self._fun = self._get_fun()
        # Batched function and its batched partials df/dx, df/dy, df/dz.
        (
            self._vec_fun,
            self._vec_fun_x,
            self._vec_fun_y,
            self._vec_fun_z,
        ) = fun2vec_and_grad(self._fun)

    @abstractmethod
    def _get_fun(self):
        """
        Return the scalar implicit function ``f(x, y, z)`` of this shape.

        Called once from ``__init__``, so subclasses must set every attribute
        the function depends on before calling ``super().__init__()``. The
        result is differentiated with ``jax.grad`` and must return a scalar.
        """
        pass

    @property
    @abstractmethod
    def bbox(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
        """
        Return the minimal axis-aligned bounding box around the implicitly defined object.

        :return: ``(X_min, Y_min, Z_min), (X_max, Y_max, Z_max)``; this base
            implementation returns a degenerate box at the origin.
        """
        return (0.0, 0.0, 0.0), (0.0, 0.0, 0.0)

    @property
    def fun(self) -> Callable:
        """
        Get the implicit function in scalar form.

        :return: ``f(x, y, z) -> value`` for scalar coordinates
        """
        return self._fun

    @property
    def vec_fun(self) -> Callable:
        """
        Get the implicit function in batched form (vectorized along axis 0).

        :return: ``f(vec_x, vec_y, vec_z) -> vec_value``
        """
        return self._vec_fun

    @property
    def vec_fun_dx(self) -> Callable:
        """
        Get the partial derivative df/dx in batched form (vectorized along axis 0).

        :return: ``df/dx(vec_x, vec_y, vec_z)``
        """
        return self._vec_fun_x

    @property
    def vec_fun_dy(self) -> Callable:
        """
        Get the partial derivative df/dy in batched form (vectorized along axis 0).

        :return: ``df/dy(vec_x, vec_y, vec_z)``
        """
        return self._vec_fun_y

    @property
    def vec_fun_dz(self) -> Callable:
        """
        Get the partial derivative df/dz in batched form (vectorized along axis 0).

        :return: ``df/dz(vec_x, vec_y, vec_z)``
        """
        return self._vec_fun_z

    def request(self, ps_xyz: jnp.ndarray) -> jnp.ndarray:
        """
        Evaluate the implicit function at a batch of points in physical space.

        :param ps_xyz: array-like of shape ``(..., 3)`` holding ``(x, y, z)``
            in its last axis. Anything accepted by ``jnp.asarray`` (JAX or
            NumPy arrays, nested lists) is allowed.
        :return: array of shape ``(..., 1)`` with the function value at each
            point; the leading dimensions of ``ps_xyz`` are preserved.
        :raises AssertionError: if the last dimension of ``ps_xyz`` is not 3.
        """
        ps_xyz = jnp.asarray(ps_xyz)
        assert ps_xyz.shape[-1] == 3, (
            f"ps_xyz must have a trailing dimension of size 3, got shape {ps_xyz.shape}"
        )
        # vec_fun is vmapped along axis 0, so feed it one flat vector per axis.
        x, y, z = (ps_xyz[..., axis].reshape(-1) for axis in range(3))
        dist = self._vec_fun(x, y, z)
        # Restore the leading dimensions and append a singleton trailing axis.
        return dist.reshape(ps_xyz.shape[:-1] + (1,))
class SphereSDF(AbstractSDF):
    """
    Sphere geometric primitive.

    The implicit function is
    ``(x - cx)**2 + (y - cy)**2 + (z - cz)**2 - radius**2``: negative inside
    the sphere, zero on its surface and positive outside.

    NOTE(review): despite the class name this is not the Euclidean signed
    distance (which would be ``sqrt(...) - radius``); away from the surface
    the values grow quadratically. Confirm callers rely only on the sign /
    zero level set before treating the output as a distance.
    """

    def __init__(self, center=(0.0, 0.0, 0.0), radius=1.0):
        """
        Create a sphere primitive.

        :param center: ``(cx, cy, cz)`` center of the sphere; only the first
            three entries are read.
        :param radius: radius of the sphere; must be strictly positive.
        :raises AssertionError: if ``radius`` is not positive.
        """
        assert radius > 0.0
        self.center = center
        self.radius = radius
        # Set the attributes first: the base constructor calls _get_fun(),
        # which reads center and radius.
        super().__init__()

    @property
    def bbox(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
        """
        Axis-aligned bounding box of the sphere.

        :return: ``(cx - r, cy - r, cz - r), (cx + r, cy + r, cz + r)``
        """
        min_ = (
            (self.center[0] - self.radius),
            (self.center[1] - self.radius),
            (self.center[2] - self.radius),
        )
        max_ = (
            (self.center[0] + self.radius),
            (self.center[1] + self.radius),
            (self.center[2] + self.radius),
        )
        return min_, max_

    def _get_fun(self):
        """
        Build the scalar implicit function of the sphere.

        :return: ``prim(x, y, z) = |p - center|**2 - radius**2``
        """
        center = self.center
        radius = self.radius

        def prim(x: float, y: float, z: float):
            sdf = (
                (x - center[0]) ** 2
                + (y - center[1]) ** 2
                + (z - center[2]) ** 2
                - radius**2
            )
            return sdf

        return prim