Source code for nndt.vizualize

import os
import pickle
import time
from typing import *

import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import numpy as onp
from mpl_toolkits.axes_grid1 import AxesGrid

from nndt.space2.utils import fix_file_extension


class IteratorWithTimeMeasurements:
    """Iterator over epoch indices that records and prints the epoch number and
    the time spent since the start of iteration.

    Yields ``0, 1, ..., epochs - 1`` (exactly ``len(self)`` values). On every
    step it records ``{"_epoch": i}`` and ``{"_time": seconds_since_start}`` in
    ``basic_viz``. When ``basic_viz.is_print_on_epoch(i)`` is true, it prints
    the epoch, the elapsed time, and the latest value of every
    non-underscored record.
    """

    def __init__(self, basic_viz, epochs):
        """
        :param basic_viz: object providing ``record(dict)``,
            ``is_print_on_epoch(int)`` and a ``_records`` dict of lists
            (e.g. ``BasicVizualization``)
        :param epochs: number of epoch indices to yield
        """
        self.basic_viz = basic_viz
        self.time_start = time.time()
        self.time_previous = self.time_start
        self.epochs = epochs
        self.counter = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Check for exhaustion *before* recording anything. Previously the
        # check came last and used ``>``, so the loop ran epochs + 1 times
        # (disagreeing with __len__). The terminating call also appended a
        # spurious extra _epoch/_time entry.
        if self.counter >= self.epochs:
            raise StopIteration()

        self.basic_viz.record({"_epoch": self.counter})
        self.basic_viz.record({"_time": time.time() - self.time_start})
        if self.basic_viz.is_print_on_epoch(self.counter):
            print(self._format_progress())

        self.counter += 1
        return self.counter - 1

    def _format_progress(self):
        """Build the progress line from the latest value of each record."""
        records = self.basic_viz._records
        head = f"[E:{records['_epoch'][-1]},T:{records['_time'][-1]:.01f}] "
        # Underscored keys are bookkeeping; only user metrics are listed.
        metrics = "".join(
            f"{k}={v[-1]}, " for k, v in records.items() if not k.startswith("_")
        )
        return head + metrics + "\n"

    def __len__(self):
        return self.epochs
def save_sdt_as_obj(
    array: Union[jnp.ndarray, onp.ndarray], path: str, level: float = 0.0
):
    """
    Run marching cubes over an SDT and save the resulting mesh to an .obj file

    Parameters
    ----------
    array : ndarray
        Signed distance tensor (SDT); must be 3-dimensional
    path : string
        Output file path; it is passed through
        ``fix_file_extension(path, ".obj")`` before writing
    level : float
        Isosurface level (defaults to 0.).
    """
    assert array.ndim == 3, f"SDT must be a 3D array, got ndim={array.ndim}"
    # Marching cubes helpers operate on host NumPy data, so convert jax arrays.
    array_ = onp.array(array)

    from nndt.space2.utils import array_to_vert_and_faces, save_verts_and_faces_to_obj

    verts, faces = array_to_vert_and_faces(array_, level=level)
    save_verts_and_faces_to_obj(fix_file_extension(path, ".obj"), verts, faces)
[docs]def save_3D_slices( array: Union[onp.ndarray, jnp.ndarray], path: str = None, slice_num: int = 5, include_boundary=True, figsize=None, levels=(0.0,), level_colors=("white",), **kwargs, ): """ Generates a panel of images with slices of the 3D array. This is a helper function for studying 3D tensors. :param path: path to the image for write :param array: studied array :param slice_num: number of slices over array axis :param include_boundary: If True, the image will include boundaries of the array with indexes 0 and len(array)-1 :param figsize: the size of the image. If None, the size will be calculated according to the number of panels. :param levels: Isoline values. This param is ignored if RGB/RGBA image is passed. :param level_colors: Isoline colors. This param is ignored if RGB/RGBA image is passed. :param kwargs: parameter set that is passed to the `.imshow()` method :return: none """ panel_size = slice_num if include_boundary else slice_num - 2 assert panel_size > 0 assert slice_num > 0 and panel_size > 0 assert array.ndim == 3 or (array.ndim == 4 and (array.shape[-1] in (1, 3, 4))) assert len(levels) == len(level_colors) is_color = array.ndim == 4 and (array.shape[-1] in (3, 4)) if figsize is None: figsize = (3 * panel_size, 3 * 3) fig = plt.figure(figsize=figsize) if is_color: grid = AxesGrid( fig, 111, nrows_ncols=(3, panel_size), axes_pad=0.05, label_mode="L", ) else: grid = AxesGrid( fig, 111, nrows_ncols=(3, panel_size), axes_pad=0.05, cbar_mode="single", cbar_location="right", cbar_pad=0.1, label_mode="L", ) slices_x = np.linspace(0, array.shape[0] - 1, slice_num).astype(int) slices_y = np.linspace(0, array.shape[1] - 1, slice_num).astype(int) slices_z = np.linspace(0, array.shape[2] - 1, slice_num).astype(int) if not include_boundary: slices_x = slices_x[1:-1] slices_y = slices_y[1:-1] slices_z = slices_z[1:-1] array_ = array[..., np.newaxis] if (array.ndim == 3) else array for ind_ax, ax in zip(slices_x, grid[:panel_size]): im = ax.imshow( 
array_[ind_ax, :, :, 0:3].squeeze(), vmin=float(jnp.nanmin(array_)), vmax=float(jnp.nanmax(array_)), **kwargs, ) if not is_color and len(levels) > 0 and len(level_colors) > 0: _cs2 = ax.contour( array_[ind_ax, :, :, 0:3].squeeze(), levels=levels, origin="lower", colors=level_colors, ) for ind_ax, ax in zip(slices_y, grid[panel_size : panel_size * 2]): im = ax.imshow( array_[:, ind_ax, :, 0:3].squeeze(), vmin=float(jnp.nanmin(array_)), vmax=float(jnp.nanmax(array_)), **kwargs, ) if not is_color and len(levels) > 0 and len(level_colors) > 0: _cs2 = ax.contour( array_[:, ind_ax, :, 0:3].squeeze(), levels=levels, origin="lower", colors=level_colors, ) for ind_ax, ax in zip(slices_z, grid[2 * panel_size :]): im = ax.imshow( array_[:, :, ind_ax, 0:3].squeeze(), vmin=float(jnp.nanmin(array_)), vmax=float(jnp.nanmax(array_)), **kwargs, ) if not is_color and len(levels) > 0 and len(level_colors) > 0: _cs2 = ax.contour( array_[:, :, ind_ax, 0:3].squeeze(), levels=levels, origin="lower", colors=level_colors, ) if not is_color: cbar = ax.cax.colorbar(im) cbar = grid.cbar_axes[0].colorbar(im) if not is_color and len(levels) > 0 and len(level_colors) > 0: cbar.add_lines(_cs2) if path is not None: fig.savefig(path) else: plt.show()
class BasicVizualization:
    """
    Simple MLOps class for storing the train history and visualization of intermediate results
    """

    def __init__(
        self, folder: str, experiment_name: Optional[str] = None, print_on_each_epoch=20
    ):
        """
        Simple MLOps class for storing the train history and visualization of intermediate results

        :param folder: folder for store results (created if missing)
        :param experiment_name: name for an experiment; defaults to ``folder``
        :param print_on_each_epoch: print intermediate results on every N-th
            epoch; 0 or None disables printing
        """
        self.folder = folder
        self.experiment_name = experiment_name if experiment_name is not None else folder
        os.makedirs(self.folder, exist_ok=True)
        self.print_on_each_epoch = print_on_each_epoch
        # Maps each record name to the list of values recorded so far;
        # "_epoch"/"_time" are appended by IteratorWithTimeMeasurements.
        self._records = {"_epoch": [], "_time": []}

    def _path(self, name, extension):
        """Return the path ``<folder>/<name><extension>``."""
        return os.path.join(self.folder, f"{name}{extension}")

    def iter(self, epoch_num):
        """Return iterators for the main train cycle

        Parameters
        ----------
        epoch_num : int
            number of epoch

        Returns
        -------
        instance of IteratorWithTimeMeasurements
        """
        return IteratorWithTimeMeasurements(self, epoch_num)

    def record(self, dict):
        """Append each value of ``dict`` to the history list stored under its key.

        The parameter name shadows the builtin but is kept for backward
        compatibility with keyword callers.
        """
        for key, value in dict.items():
            self._records.setdefault(key, []).append(value)

    def is_print_on_epoch(self, epoch):
        """Check if this is the right epoch to print results

        Parameters
        ----------
        epoch : int
            epoch number

        Returns
        -------
        bool
            Should we print on this step? Always False when
            ``print_on_each_epoch`` is 0 or None (previously a
            ZeroDivisionError/TypeError).
        """
        if not self.print_on_each_epoch:
            return False
        return (epoch % self.print_on_each_epoch) == 0

    def draw_loss(self, name, history):
        """Save the training history plot (log-scale y axis) in .jpg

        Parameters
        ----------
        name : string
            File name (without extension); also used in the plot title
        history : sequence of float
            Loss values over epochs
        """
        # Figure number 1 is reused as a scratch figure by the plotting helpers.
        plt.close(1)
        plt.figure(1)
        plt.semilogy(history)
        plt.title(f"{self.experiment_name}_{name}")
        plt.grid()
        plt.savefig(self._path(name, ".jpg"))

    def save_state(self, name, state):
        """Save the neural network state into a .pkl file

        Parameters
        ----------
        name : string
            File name (without extension)
        state : object
            Any picklable object to save
        """
        # The context manager closes (and flushes) the file; the original left
        # the handle open until garbage collection.
        with open(self._path(name, ".pkl"), "wb") as fh:
            pickle.dump(state, fh)

    def save_txt(self, name, summary):
        """Save string data to .txt file

        Parameters
        ----------
        name : string
            File name (without extension)
        summary : string
            The text to save
        """
        # Explicit UTF-8 so the output does not depend on the platform locale.
        with open(self._path(name, ".txt"), "w", encoding="utf-8") as fl:
            fl.write(summary)

    def sdt_to_obj(
        self, filename: str, array: Union[jnp.ndarray, onp.ndarray], level: float = 0.0
    ):
        """Run marching cubes over SDT and save results to file

        Parameters
        ----------
        filename : string
            File name (without extension)
        array : ndarray
            Signed distance tensor (SDT)
        level : float
            Isosurface level (defaults to 0.).
        """
        save_sdt_as_obj(array, self._path(filename, ".obj"), level)

    def save_mesh(self, name, save_method, dict_):
        """Save mesh to .vtp file with data

        Parameters
        ----------
        name : string
            File name (without extension)
        save_method : SaveMesh
            SaveMesh instance from NNDT space (v0.0.1 or v0.0.2); called as
            ``save_method(path, dict_)``
        dict_ : dict
            name -> value data passed to ``save_method``
        """
        save_method(self._path(name, ".vtp"), dict_)

    def save_3D_array(self, name, array, section_img=True):
        """Save the 3D array to a .npy file and its central sections as images

        Parameters
        ----------
        name : string
            File name (without extension)
        array : array
            3D array to save
        section_img : bool
            If true, saves the three central plane sections of the array as
            ``<name>_0.jpg``, ``<name>_1.jpg`` and ``<name>_2.jpg``
            (defaults to True)
        """
        assert array.ndim == 3
        jnp.save(self._path(name, ".npy"), array)
        if not section_img:
            return
        for axis in range(3):
            # Select the middle slice along `axis`, keeping the other two axes.
            index = [slice(None)] * 3
            index[axis] = array.shape[axis] // 2
            plt.close(1)
            plt.figure(1)
            plt.title(f"{self.experiment_name}_{name}_{axis}")
            plt.imshow(array[tuple(index)])
            plt.colorbar()
            plt.savefig(self._path(f"{name}_{axis}", ".jpg"))