Source code for castalign.base

import os
from .compat import CURRENT_FILE_FORMAT_VERSION, apply_legacy_class_remappings, get_legacy_eval_namespace
import concurrent.futures
import numpy as np
import scipy
from .ndarray_shifted import ndarray_shifted
from .utils import blit, image_is_label, invert_function_numerical, rotation_matrix
import threadpoolctl
# Optional GPU acceleration: CuPy is used when it can be imported and at least
# one CUDA device is reported. Any failure (package missing, driver missing,
# no device) leaves the module in CPU-only mode instead of raising.
try:
    import cupy as cp
    from cupyx.scipy import ndimage as cupyx_ndimage
    try:
        # Importing CuPy can succeed without a usable device, so probe the
        # CUDA runtime explicitly.
        GPU_AVAILABLE = cp.cuda.runtime.getDeviceCount() > 0
    except Exception:
        GPU_AVAILABLE = False
except Exception:
    # CuPy/cupyx unavailable: keep the names defined so guarded expressions
    # such as ``GPU_AVAILABLE and isinstance(x, cp.ndarray)`` remain valid.
    cp = None
    cupyx_ndimage = None
    GPU_AVAILABLE = False

# TODO:
# - implement posttransforms, allowing the unfitted transform to be on the left hand side


def _gpu_chunksize(y_len, x_len, default_chunksize=4_000_000):
    """Choose how many output voxels to resample per GPU chunk.

    The result is a whole number of ``y_len * x_len`` planes, sized so that
    the chunk's working arrays fit in roughly half of the currently free GPU
    memory.

    Parameters
    ----------
    y_len, x_len : int
        Output extent along y and x; one z-plane holds ``y_len * x_len`` voxels.
    default_chunksize : int, optional
        Returned unchanged when no GPU is available or the memory query fails.

    Returns
    -------
    int
        Voxel budget per chunk (at least 1), or ``default_chunksize``.
    """
    if not GPU_AVAILABLE:
        return default_chunksize
    try:
        free_bytes = cp.cuda.runtime.memGetInfo()[0]
        # Rough per-voxel cost: coordinates, mapped coordinates and the
        # sampled output value.
        per_voxel_bytes = 28
        usable_bytes = max(int(free_bytes * 0.50), 1)
        voxels_allowed = max(usable_bytes // per_voxel_bytes, 1)
        plane_voxels = max(y_len * x_len, 1)
        planes_per_chunk = max(voxels_allowed // plane_voxels, 1)
        return max(planes_per_chunk * plane_voxels, 1)
    except Exception:
        # Best effort: any failure querying the device falls back to the default.
        return default_chunksize


class Transform:
    """Base class for all transforms.

    This is the core transform interface used across CASTalign. A transform
    maps 3D volume coordinates from one space to another, and can also be
    applied to full volumetric images through resampling.

    Conceptually, transforms in CASTalign are usually either:

    * point-based (see ``PointTransform``), or
    * parameterized transforms that do not use correspondence points.

    ``DEFAULT_PARAMETERS`` serves two purposes:

    * it provides default values, and
    * it defines the complete set of constructor keyword parameters accepted
      by that transform class.

    Parameter values are stored in ``self.params`` and can be edited in the
    GUI. These are for user-settable controls (for example shifts, scales,
    toggles), not values that are fit from selected points.

    Notes
    -----
    To implement a new subclass:

    * Implement ``_transform(points)`` for forward mapping.
    * Implement ``invert()`` for reverse mapping.
    * Implement ``_fit()`` if parameters must be derived from point data.
      ``_fit()`` is called during initialization when present.
    * Optionally implement ``_transform_gpu(points)`` (CuPy arrays) to enable
      the GPU paths of ``transform`` and ``transform_image``.
    * If there is no analytic inverse, subclass
      ``PointTransformNoAnalyticInverse``.

    Any transform must be exactly reconstructible from ``__repr__`` output.
    """

    NAME = ""
    SHORTCUT_KEY = ""
    SORT_WEIGHT = 0  # How to sort when listing all transforms
    DEFAULT_PARAMETERS = {}
    GUI_DRAG_PARAMETERS = [None, None, None]

    # ``__eq__`` compares reprs and ``params`` is mutable, so instances are
    # unhashable. This is also Python's implicit behaviour when ``__eq__`` is
    # defined without ``__hash__``; it is spelled out here for clarity.
    __hash__ = None

    def __init__(self, **kwargs):
        """Store parameters (falling back to class defaults) and fit if needed.

        Parameters
        ----------
        **kwargs
            Values for keys of ``DEFAULT_PARAMETERS``. Missing keys take the
            class default.

        Raises
        ------
        AssertionError
            If a keyword is not a key of ``DEFAULT_PARAMETERS``.
        """
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; the exception type callers see is unchanged.
        for k in kwargs:
            if k not in self.DEFAULT_PARAMETERS:
                raise AssertionError(
                    f"Keyword argument {k} is not valid for the transform {type(self)}, "
                    f"instead try {list(self.DEFAULT_PARAMETERS.keys())}"
                )
        # Initialise parameters to either passed values or defaults.
        self.params = {k: kwargs.get(k, default) for k, default in self.DEFAULT_PARAMETERS.items()}
        # If this transform needs to be fit, then fit it.
        if hasattr(self, "_fit"):
            self._fit()

    def __repr__(self):
        args = ", ".join(f"{k}={v}" for k, v in self.params.items())
        return f"{self.__class__.__name__}({args})"

    def __eq__(self, other):
        # Two transforms are equal when they reconstruct to the same repr.
        return repr(self) == repr(other)

    def __add__(self, other):
        # Composition via ``compose_transforms``. ``transform_image`` uses
        # ``shift + self`` to apply ``shift`` first (presumably left-to-right
        # application order; see ``compose_transforms``).
        return compose_transforms(self, other)

    def __call__(self, data, *args, **kwargs):
        """Apply this transform to points or an image.

        This convenience method lets you use a transform like a function. If
        ``data`` looks like point coordinates, it routes to ``transform``;
        otherwise it treats ``data`` as image data and routes to
        ``transform_image``.

        Parameters
        ----------
        data : array-like
            Either:

            * 3D point coordinates in ``(z, y, x)`` form (``(3,)`` or
              ``(N, 3)``), or
            * image-like array data (``ndarray_shifted`` inputs keep their
              origin).
        *args, **kwargs
            Forwarded to ``transform`` or ``transform_image`` depending on
            input type.

        Returns
        -------
        numpy.ndarray or ndarray_shifted
            Transformed points or transformed image data.
        """
        # Pass arrays through untouched: np.asarray would drop the origin of
        # an ndarray_shifted and cannot convert CuPy arrays implicitly.
        keep_as_is = isinstance(data, (np.ndarray, ndarray_shifted)) or (
            GPU_AVAILABLE and isinstance(data, cp.ndarray)
        )
        adata = data if keep_as_is else np.asarray(data)
        shape = tuple(adata.shape)
        if (len(shape) == 2 and shape[1] == 3) or shape == (3,):
            return self.transform(adata, *args, **kwargs)
        return self.transform_image(adata, *args, **kwargs)

    def save(self, filename):
        """Save this transform to disk as a reconstructible text string.

        Parameters
        ----------
        filename : str or path-like
            Output file path.

        Notes
        -----
        The saved content is ``repr(self)`` and is intended to be read back
        with :meth:`load`.
        """
        with open(filename, "w") as f:
            f.write(repr(self))

    @staticmethod
    def load(filename, version=None):
        """Load a transform from a file written by :meth:`save`.

        Parameters
        ----------
        filename : str or path-like
            Input file path containing a transform repr string.
        version : int or None, optional
            File format version for this transform text. If ``None``, uses the
            current format version. Set to ``1`` for older files that may use
            legacy class names.

        Returns
        -------
        Transform
            Reconstructed transform object.

        Warnings
        --------
        The file content is executed with ``eval`` in this module's namespace.
        Only load files from trusted sources.
        """
        with open(filename, "r") as f:
            text = f.read()
        if version is None:
            version = CURRENT_FILE_FORMAT_VERSION
        eval_namespace = dict(globals())
        if version < CURRENT_FILE_FORMAT_VERSION:
            print(
                f"Loading legacy transform format version {version}. "
                f"It will be saved as version {CURRENT_FILE_FORMAT_VERSION} when you save it."
            )
            text = apply_legacy_class_remappings(text)
            eval_namespace.update(get_legacy_eval_namespace(eval_namespace))
        # SECURITY: eval of file content -- arbitrary code execution if the
        # file is untrusted.
        return eval(text, eval_namespace, None)

    def transform(self, points):
        """Transform 3D point coordinates from source space to target space.

        Use this when you want to move landmarks, annotations, or other
        coordinate sets into the transformed image space.

        Parameters
        ----------
        points : array-like
            Point coordinates in ``(z, y, x)`` order. Accepts either a single
            point ``(3,)`` or many points ``(N, 3)``. CuPy input is returned
            as CuPy output.

        Returns
        -------
        numpy.ndarray
            Transformed coordinates, with the same single-point vs multi-point
            structure as the input.

        Raises
        ------
        AssertionError
            If non-empty input does not have 3 columns.
        """
        in_cupy_format = GPU_AVAILABLE and isinstance(points, cp.ndarray)
        use_cupy_internally = GPU_AVAILABLE and self._has_gpu_transform()

        def _to_caller(arr):
            # Hand results back in the array library the caller used.
            if in_cupy_format:
                return cp.asarray(arr)
            if use_cupy_internally:
                return cp.asnumpy(arr)
            return arr

        if use_cupy_internally:
            points = cp.asarray(points)
        elif in_cupy_format:
            points = cp.asnumpy(points)
        else:
            points = np.asarray(points)
        is_1d = points.ndim == 1
        if is_1d:
            points = points[None]
        if 0 in points.shape:
            # No points to transform.
            out = points
        else:
            if points.shape[1] != 3:
                raise AssertionError("Input points must be in volume space")
            out = self._transform_gpu(points) if use_cupy_internally else self._transform(points)
        out = _to_caller(out)
        return out[0] if is_1d else out

    def _transform(self, points):
        """Map points from source space to target space.

        Parameters
        ----------
        points : numpy.ndarray
            Array with shape ``(N, 3)``.

        Returns
        -------
        numpy.ndarray
            Transformed points with shape ``(N, 3)``.
        """
        raise NotImplementedError("Please subclass and replace")

    def _transform_gpu(self, points):
        """GPU variant of :meth:`_transform` using CuPy arrays.

        Subclasses can override this to enable GPU-accelerated point mapping.

        Parameters
        ----------
        points : cupy.ndarray
            Array with shape ``(N, 3)`` on GPU.

        Returns
        -------
        cupy.ndarray
            Transformed points with shape ``(N, 3)``.
        """
        raise NotImplementedError("GPU path not implemented for this transform")

    def _has_gpu_transform(self):
        # True only when a GPU exists and the subclass overrides _transform_gpu.
        return GPU_AVAILABLE and (self.__class__._transform_gpu is not Transform._transform_gpu)

    def inverse_transform(self, points):
        """Map points from target space back to source space.

        Parameters
        ----------
        points : numpy.ndarray
            Array with shape ``(N, 3)`` or a single point ``(3,)``.

        Returns
        -------
        numpy.ndarray
            Inverse-transformed points with the same leading shape as input.

        Notes
        -----
        Override this for a faster implementation; default is
        ``self.invert().transform(points)``.
        """
        return self.invert().transform(points)

    def invert(self):
        """Return the inverse transform.

        Use this when you need to map coordinates or images in the opposite
        direction (target space back to source space).

        Returns
        -------
        Transform
            A transform representing the inverse mapping.

        Notes
        -----
        Subclasses must implement this.
        """
        raise NotImplementedError("Please subclass and replace")

    def origin_and_maxpos(self, img, output_size=None, force_size=True):
        """Compute output-space bounds for a transformed image.

        This is used internally to inspect or control the output coordinate
        box before calling ``transform_image``. It is especially useful when
        you need consistent output extents across multiple transformed volumes.

        Parameters
        ----------
        img : numpy.ndarray or ndarray_shifted
            Input 3D image.
        output_size : None, sequence, or sequence of 2-tuples, optional
            Output bounds.

            * ``None``: tight bounds from transformed corners.
            * ``(z, y, x)``: explicit upper bounds from origin 0.
            * ``((zmin, zmax), (ymin, ymax), (xmin, xmax))``: explicit bounds.

            ``None`` (either a whole axis entry or a single min/max) is an
            open bound: it takes the tight computed bound for that side.
        force_size : bool, optional
            If ``True``, use explicit bounds exactly. If ``False``, treat them
            as limits and allow smaller computed bounds.

        Returns
        -------
        tuple of numpy.ndarray
            ``(origin, maxpos)`` each with shape ``(3,)``.

        Raises
        ------
        ValueError
            If ``output_size`` is neither ``None`` nor a sequence.

        Examples
        --------
        ::

            >>> # Use automatic tight bounds.
            >>> origin, maxpos = t.origin_and_maxpos(img)

            >>> # Force a specific output size from origin 0.
            >>> origin, maxpos = t.origin_and_maxpos(img, output_size=(80, 256, 256))

            >>> # Force explicit coordinate bounds.
            >>> origin, maxpos = t.origin_and_maxpos(
            ...     img,
            ...     output_size=((10, 90), (20, 220), (30, 230)),
            ...     force_size=True,
            ... )
        """
        input_bounds = img.shape
        origin_offset = img.origin if isinstance(img, ndarray_shifted) else [0, 0, 0]
        corners = np.asarray(
            [[a, b, c] for a in (0, input_bounds[0]) for b in (0, input_bounds[1]) for c in (0, input_bounds[2])]
        ) + np.asarray(origin_offset)
        # Transform the corners once; this can be expensive for nonlinear transforms.
        mapped = self.transform(corners)
        if isinstance(self, PointTransform):
            # For nonlinear transforms a box is not enough. This is also not
            # enough, but it is better than nothing.
            mapped = np.concatenate([mapped, self.points_end])
        origin = np.min(mapped, axis=0).astype("float32")
        maxpos = np.max(mapped, axis=0).astype("float32")

        if output_size is None:
            return origin, maxpos
        sequence_types = (list, tuple, np.ndarray)
        if not isinstance(output_size, sequence_types):
            raise ValueError(f"Invalid value of `output_size` passed: {output_size}")

        # Parse explicit bounds; None on either side becomes an infinite (open) bound.
        lows, highs = [], []
        for r in output_size:
            if isinstance(r, sequence_types):
                low, high = r[0], r[1]
            elif r is None:
                low, high = None, None
            else:
                low, high = 0, r
            lows.append(-np.inf if low is None else low)
            highs.append(np.inf if high is None else high)
        origin_ = np.asarray(lows, dtype="float32")
        maxpos_ = np.asarray(highs, dtype="float32")

        if force_size:
            # Finite bounds are used exactly; open bounds fall back to the
            # computed ones (otherwise the output shape would be infinite).
            origin = np.where(np.isfinite(origin_), origin_, origin).astype("float32")
            maxpos = np.where(np.isfinite(maxpos_), maxpos_, maxpos).astype("float32")
        else:
            origin = np.max([origin_, origin], axis=0)
            maxpos = np.min([maxpos_, maxpos], axis=0)
        return origin, maxpos

    @staticmethod
    def _iter_z_chunks(zcoords, ycoords, xcoords, chunksize=10_000_000):
        """Yield z-plane chunks of output coordinates to limit peak memory usage.

        Parameters
        ----------
        zcoords, ycoords, xcoords : numpy.ndarray
            Output coordinate vectors.
        chunksize : int, optional
            Approximate voxel budget per chunk (at least one z-plane is used).

        Yields
        ------
        tuple
            ``(coords, chunk_shape, inds)``: ``coords`` is ``(M, 3)`` float32
            in ``(z, y, x)`` order, ``chunk_shape`` the 3D shape of the chunk
            and ``inds`` the slice tuple into the output array.
        """
        plane_size = len(ycoords) * len(xcoords)
        # max(..., 1) avoids ZeroDivisionError when the y or x extent is empty.
        n_z_per_chunk = max(chunksize // max(plane_size, 1), 1)
        single_plane = np.asarray(
            [np.repeat(ycoords, len(xcoords)), np.tile(xcoords, len(ycoords))], dtype="float32"
        )
        zcoords = np.asarray(zcoords, dtype="float32")
        # Iterate over every chunk; no arbitrary cap on the number of chunks.
        for zfrom in range(0, len(zcoords), n_z_per_chunk):
            zto = min(zfrom + n_z_per_chunk, len(zcoords))
            inds = (slice(zfrom, zto), slice(None), slice(None))
            chunk_shape = (zto - zfrom, len(ycoords), len(xcoords))
            coords = np.concatenate(
                [np.repeat(zcoords[zfrom:zto], single_plane.shape[1])[None, :], np.tile(single_plane, zto - zfrom)],
                axis=0,
            ).T
            yield coords, chunk_shape, inds

    def _transform_image_gpu(self, img, origin, output, order):
        """Resample ``img`` into ``output`` on the GPU in z-plane chunks.

        Uses the inverse transform's ``_transform_gpu`` and
        ``cupyx.scipy.ndimage.map_coordinates``. Any exception propagates to
        the caller, which falls back to the CPU path.
        """
        inverse_transform_gpu = self.invert()._transform_gpu
        nz, ny, nx = output.shape
        chunksize = _gpu_chunksize(ny, nx)
        img_gpu = cp.asarray(img)
        origin_gpu = cp.asarray(origin, dtype=cp.float32)
        plane = max(ny * nx, 1)
        n_z_per_chunk = max(chunksize // plane, 1)
        y_gpu = cp.arange(ny, dtype=cp.float32)[None, :, None]
        x_gpu = cp.arange(nx, dtype=cp.float32)[None, None, :]
        for zfrom in range(0, nz, n_z_per_chunk):
            zto = min(zfrom + n_z_per_chunk, nz)
            chunk_shape = (zto - zfrom, ny, nx)
            z_gpu = cp.arange(zfrom, zto, dtype=cp.float32)[:, None, None]
            grid_gpu = cp.stack(
                [
                    cp.broadcast_to(z_gpu, chunk_shape),
                    cp.broadcast_to(y_gpu, chunk_shape),
                    cp.broadcast_to(x_gpu, chunk_shape),
                ],
                axis=-1,
            ).reshape(-1, 3)
            mapped_gpu = inverse_transform_gpu(grid_gpu + origin_gpu)
            disp_gpu = mapped_gpu.reshape(*chunk_shape, 3).transpose(3, 0, 1, 2)
            block_gpu = cupyx_ndimage.map_coordinates(img_gpu, disp_gpu, prefilter=False, order=order)
            output[zfrom:zto] = cp.asnumpy(block_gpu)

    def transform_image(self, img, output_size=None, labels=None, force_size=True):
        """Transform an image.

        ``output_size`` controls the output coordinate box. If it is ``None``,
        the transformed image uses a tight bounding box based on transformed
        corners (and point targets for point-based transforms). If it is
        explicit, it can be either ``(z, y, x)`` max bounds from origin 0, or
        full per-axis bounds ``((zmin, zmax), (ymin, ymax), (xmin, xmax))``.
        With ``force_size=True`` these bounds are used exactly; with
        ``force_size=False`` they are treated as limits and the method may
        return a smaller box when possible.

        Parameters
        ----------
        img : numpy.ndarray or ndarray_shifted
            Input image. 2D images are promoted to 3D single-slice volumes.
        output_size : None, sequence, or sequence of 2-tuples, optional
            Output bounds, same formats as ``origin_and_maxpos``.
        labels : bool or None, optional
            Label mode.

            * ``True``: nearest-neighbor interpolation.
            * ``False``: linear interpolation.
            * ``None``: auto-detect with ``image_is_label``.
        force_size : bool, optional
            If ``False``, output size may be smaller than ``output_size`` if
            the output would include empty space.

        Returns
        -------
        ndarray_shifted or numpy.ndarray
            Transformed image in output coordinates.

        Notes
        -----
        Generic implementation for non-rigid transforms. Subclasses can
        override with faster special cases. When a GPU is available and the
        transform implements ``_transform_gpu``, a GPU path is tried first;
        if it fails, a ``RuntimeWarning`` is emitted (unless the inverse
        simply lacks a GPU kernel) and the CPU path is used.

        Examples
        --------
        ::

            >>> # Transform with automatic tight output bounds.
            >>> out = t.transform_image(img)

            >>> # Transform a label volume with nearest-neighbor interpolation.
            >>> out = t.transform_image(labels_img, labels=True)

            >>> # Transform with explicit output bounds.
            >>> out = t.transform_image(
            ...     img,
            ...     output_size=((10, 90), (20, 220), (30, 230)),
            ... )
        """
        # An ndarray_shifted with a non-zero origin is handled by prepending a
        # translation, so the code below only ever sees unshifted arrays.
        if isinstance(img, ndarray_shifted) and np.any(img.origin != np.asarray([0, 0, 0])):
            shift = TranslateParametric(z=img.origin[0], y=img.origin[1], x=img.origin[2])
            return (shift + self).transform_image(
                np.asarray(img), output_size=output_size, labels=labels, force_size=force_size
            )

        # Housekeeping
        if labels is None:
            labels = image_is_label(img)
        if img.ndim == 2:
            img = img[None]
        origin, maxpos = self.origin_and_maxpos(img, output_size=output_size, force_size=force_size)
        shape = (maxpos - origin).astype(int)
        img = np.ascontiguousarray(img, dtype=np.float32 if not labels else img.dtype)
        origin = origin.astype(np.float32)

        # This is a hack to get around thickness=1 images disappearing in the
        # map_coordinates function: duplicate any length-1 axis to length 2.
        if img.shape[0] == 1:
            img = np.concatenate([img, img])
        if img.shape[1] == 1:
            img = img * np.ones((1, 2, 1), dtype=img.dtype)
        if img.shape[2] == 1:
            img = img * np.ones((1, 1, 2), dtype=img.dtype)

        zcoords = np.arange(0, shape[0], dtype="float32")
        ycoords = np.arange(0, shape[1], dtype="float32")
        xcoords = np.arange(0, shape[2], dtype="float32")

        # Every output voxel coordinate is inverse-transformed into input
        # space, giving a matrix of pointers from the destination image to the
        # source image; map_coordinates then samples the input at those points.
        output = np.zeros(
            (len(zcoords), len(ycoords), len(xcoords)), dtype=(img.dtype if labels else "float32")
        )
        order = 0 if labels else 1  # nearest-neighbour for labels, linear otherwise

        if GPU_AVAILABLE and self._has_gpu_transform():
            try:
                self._transform_image_gpu(img, origin, output, order)
                return ndarray_shifted(output, origin=origin, only_if_necessary=True)
            except NotImplementedError:
                # The inverse transform has no GPU kernel; use the CPU path.
                pass
            except Exception as err:
                import warnings

                warnings.warn(
                    f"GPU resampling failed ({err!r}); falling back to the CPU implementation.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        def _process_chunk(args):
            grid, chunk_shape, inds = args
            mapped = self.inverse_transform(grid + origin)
            disp = mapped.reshape(*chunk_shape, 3).transpose(3, 0, 1, 2)
            block = scipy.ndimage.map_coordinates(img, disp, prefilter=False, order=order)
            return inds, block

        chunks = self._iter_z_chunks(zcoords, ycoords, xcoords, chunksize=4_000_000)
        # Parallelise across chunks here, so disable BLAS/OpenMP parallelism inside each thread.
        with threadpoolctl.threadpool_limits(limits=1):
            with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as pool:
                for inds, block in pool.map(_process_chunk, chunks):
                    output[inds] = block
        # NOTE: Added -origin from origin due to TranslateParametric + Rescale
        # on a ndarray_shifted but not sure if this is the right spot.
        return ndarray_shifted(output, origin=origin, only_if_necessary=True)

    @staticmethod
    def pretransform(*args, **kwargs):
        """Return the fixed pre-transform applied before this transform.

        Returns
        -------
        Transform
            Pre-transform to apply first. The default is ``Identity()``.

        Notes
        -----
        Most transform classes should keep this default. Override it when a
        class represents a composed transform where only part of the chain is
        meant to be fit from data (for example, when an earlier component is
        fixed and only the final component is fit).
        """
        return Identity()
class PointTransform(Transform):
    """Base class for transforms fit from point correspondences.

    This class stores matched 3D points and is the base for transforms that
    are learned from those matches.

    Notes
    -----
    Subclasses can rely on:

    * ``self.points_start``: source coordinates.
    * ``self.points_end``: target coordinates.

    Subclasses are expected to implement the standard transform behavior:

    * forward mapping through ``_transform``,
    * inverse mapping through ``invert``,
    * and optional fitting logic in ``_fit``.
    """

    def __init__(self, points_start=None, points_end=None, **kwargs):
        """Initialize a point-based transform from matched coordinates.

        Parameters
        ----------
        points_start : array-like, optional
            Source-space points, shaped ``(N, 3)`` in ``(z, y, x)`` order.
        points_end : array-like, optional
            Target-space points matched one-to-one with ``points_start``,
            shaped ``(N, 3)``.
        **kwargs
            Additional transform parameters accepted by this class (defined by
            ``DEFAULT_PARAMETERS`` on the subclass).

        Raises
        ------
        AssertionError
            If ``points_start`` and ``points_end`` do not have the same shape.
        """
        points_start = np.asarray(points_start)
        points_end = np.asarray(points_end)
        # Explicit raise (not ``assert``) so validation survives ``python -O``;
        # the documented AssertionError type is preserved.
        if points_start.shape != points_end.shape:
            raise AssertionError("Points start and end must be the same size")
        # Points must be stored before Transform.__init__, which calls _fit().
        self.points_start = points_start
        self.points_end = points_end
        super().__init__(**kwargs)

    @classmethod
    def from_transform(cls, transform, *args, **kwargs):
        """Construct a point transform from another transform's points.

        Parameters
        ----------
        transform : Transform
            Existing transform providing ``points_start`` and ``points_end``.
        *args, **kwargs
            Additional constructor args for ``cls``.

        Returns
        -------
        PointTransform
            New instance initialized with the same correspondences.
        """
        # Positional args are listed first, matching how Python binds them anyway.
        return cls(*args, points_start=transform.points_start, points_end=transform.points_end, **kwargs)

    def __repr__(self):
        parts = [
            f"points_start={self.points_start.tolist()}",
            f"points_end={self.points_end.tolist()}",
        ]
        parts.extend(f"{k}={v}" for k, v in self.params.items())
        return f"{self.__class__.__name__}({', '.join(parts)})"
class AffineTransform:
    """Mixin implementing affine transform behavior.

    Provides shared logic for transforms of the form
    ``output = input @ matrix - shift`` (points are row vectors).

    Notes
    -----
    Subclasses should use ``_fit`` to set:

    * ``self.matrix`` with shape ``(3, 3)``
    * ``self.shift`` with shape ``(3,)``

    This mixin is designed for multiple inheritance with ``Transform`` or
    ``PointTransform``-derived classes.
    """

    def _transform(self, points):
        return points @ self.matrix - self.shift

    def _transform_gpu(self, points):
        # Parameters are copied to the device on every call (float32).
        matrix_dev = cp.asarray(self.matrix, dtype=cp.float32)
        shift_dev = cp.asarray(self.shift, dtype=cp.float32)
        return points @ matrix_dev - shift_dev

    def transform_image(self, image, output_size=None, labels=None, force_size=True):
        # Fast path: a pure translation of an unshifted image needs no
        # resampling; the image is just re-labelled with a new origin.
        if np.all(self.matrix == np.eye(3)):
            unshifted = (not isinstance(image, ndarray_shifted)) or np.all(image.origin == [0, 0, 0])
            if unshifted and output_size is None:
                return ndarray_shifted(image, origin=-self.shift, only_if_necessary=True)
        # TODO: an explicit output_size could also be served without
        # resampling by blitting the image into a new array.
        return super().transform_image(image, output_size=output_size, labels=labels, force_size=force_size)

    def invert(self):
        """Invert the transform.

        Note: This will return incorrect results for some non-affine
        transforms. Currently it just swaps the order of the points, so it
        only applies to point-based subclasses; parametric subclasses
        override it.
        """
        return self.__class__(points_start=self.points_end, points_end=self.points_start, **self.params)
# TODO improve optimisation with the jacobian
class PointTransformNoAnalyticInverse(PointTransform):
    """Base class for point transforms without an analytic inverse.

    This class uses numerical inversion when needed, for transforms that are
    still bijective but do not have a closed-form inverse.

    Notes
    -----
    Subclasses should:

    * include ``{"invert": False}`` in ``DEFAULT_PARAMETERS``
    * implement ``_transform(points, points_start, points_end)``

    By default, this class treats ``_transform`` as the inverse-direction map,
    which is typically the faster direction for image resampling workflows.
    Passing ``invert=False`` flips that behavior.
    """

    def __init__(self, *args, **kwargs):
        # Numerically inverted points, keyed by the input point tuple. Created
        # before super().__init__ because that may call _fit().
        self._inv_transform_cache = {}
        super().__init__(*args, **kwargs)

    def transform(self, points):
        """Map ``points`` (``(3,)`` or ``(N, 3)``) through this transform.

        With ``invert=True`` this calls ``_transform`` directly (fast). With
        ``invert=False`` each point is inverted numerically (slow), results
        are cached per point, and more than 1000 points raises ``ValueError``.
        """
        points = np.asarray(points)
        is_1d = points.ndim == 1
        if is_1d:
            points = points[None]
        assert points.shape[1] == 3, "Input points must be in volume space"
        if self.params["invert"]:
            print("Fast inverse transform of", points.shape)
            res = self._transform(points, points_start=self.points_start, points_end=self.points_end)
        else:
            print("Slow transform of", points.shape)
            if points.shape[0] > 1000:
                raise ValueError("Too many points to invert, this will take forever.")
            res = np.asarray([self._invert_point(p) for p in points])
        return res[0] if is_1d else res

    def _invert_point(self, p):
        """Numerically invert a single point, memoising the result."""
        key = tuple(p)
        cached = self._inv_transform_cache.get(key)
        if cached is None:
            cached = invert_function_numerical(
                lambda x: self._transform(x, self.points_end, self.points_start), p
            )
            # Previously the cache was read but never written, so it never hit.
            self._inv_transform_cache[key] = cached
        return cached

    def invert(self):
        # Keep all subclass parameters; only the direction flag flips.
        params = dict(self.params)
        params["invert"] = not self.params["invert"]
        return self.__class__(points_start=self.points_end, points_end=self.points_start, **params)
########## Transforms ##########
class Identity(AffineTransform, Transform):
    """Identity transform that leaves coordinates unchanged."""

    NAME = "No transform"

    def _fit(self):
        # Matrix/shift are kept so generic affine code paths still work.
        self.matrix = np.eye(3)
        self.shift = np.zeros(3)

    def _transform(self, points):
        return points

    def invert(self):
        return self.__class__()

    def transform_image(self, image, output_size=None, labels=None, force_size=True):
        """Apply the identity transform.

        Without explicit bounds the input is returned unchanged (no copy);
        with ``output_size`` the generic resampling path is used.

        Parameters
        ----------
        image : numpy.ndarray or ndarray_shifted
            Input image data.
        output_size : optional
            Forwarded when the generic path is required.
        labels : bool or None, optional
            Forwarded when the generic path is required.
        force_size : bool, optional
            Forwarded when the generic path is required.

        Returns
        -------
        numpy.ndarray or ndarray_shifted
            Original image when possible, otherwise generic resampled output.
        """
        if output_size is None:
            return image
        # TODO This doesn't work for different output_size values
        return super().transform_image(image, output_size=output_size, labels=labels, force_size=force_size)
class Translate(AffineTransform, PointTransform):
    """Translation-only transform fit from corresponding points.

    The shift is the mean displacement ``points_start - points_end``, so the
    fitted map is ``p -> p - shift`` with an identity matrix.
    """

    NAME = "Translate"
    SORT_WEIGHT = -100

    def _fit(self):
        self.matrix = np.eye(3)
        displacement = self.points_start - self.points_end
        self.shift = displacement.mean(axis=0)
class TranslateParametric(AffineTransform, Transform):
    """Translation-only transform with fixed z/y/x offsets.

    A point ``p`` maps to ``p + (z, y, x)``.

    Parameters
    ----------
    z, y, x : float, optional
        Translation offsets in voxel coordinates.
    """

    NAME = "Translate"
    DEFAULT_PARAMETERS = {"z": 0.0, "y": 0.0, "x": 0.0}
    GUI_DRAG_PARAMETERS = ["z", "y", "x"]

    def _fit(self):
        self.matrix = np.eye(3)
        # shift is subtracted in _transform, hence the negation.
        self.shift = -np.asarray([self.params[axis] for axis in ("z", "y", "x")])

    def invert(self):
        p = self.params
        return self.__class__(x=-p["x"], y=-p["y"], z=-p["z"])
class Rigid(AffineTransform, PointTransform):
    """Rigid transform fit from points using rotation and translation.

    Uses the Kabsch/SVD solution for row-vector points
    (``p -> p @ matrix - shift``), constrained to a proper rotation
    (``det(matrix) == +1``) so the result never contains a reflection.
    """

    NAME = "Rigid"
    SHORTCUT_KEY = "R"
    SORT_WEIGHT = -99

    def _fit(self):
        demeaned_start = self.points_start - np.mean(self.points_start, axis=0)
        demeaned_end = self.points_end - np.mean(self.points_end, axis=0)
        # numpy returns the third factor already transposed (Vh).
        U, S, Vt = np.linalg.svd(demeaned_start.T @ demeaned_end)
        # Without this correction, noisy or near-planar point sets can yield
        # det(U @ Vt) == -1, i.e. a mirror image rather than a rotation.
        # Flip the axis of the smallest singular value in that case.
        d = np.sign(np.linalg.det(U @ Vt))
        self.matrix = U @ np.diag([1.0, 1.0, d]) @ Vt
        self.shift = np.mean(self.points_start @ self.matrix - self.points_end, axis=0)
class RigidParametric(AffineTransform, Transform):
    """Rigid transform defined by fixed translation and rotation parameters.

    Parameters
    ----------
    z, y, x : float, optional
        Translation offsets in voxel coordinates.
    zrotate, yrotate, xrotate : float, optional
        Clockwise rotation angles (degrees) about each axis.
    invert : bool, optional
        If ``True``, uses the inverse (transposed) rotation.
    """

    NAME = "Rigid"
    SHORTCUT_KEY = "r"
    SORT_WEIGHT = -99
    DEFAULT_PARAMETERS = {"z": 0.0, "y": 0.0, "x": 0.0, "zrotate": 0.0, "yrotate": 0.0, "xrotate": 0.0, "invert": False}
    GUI_DRAG_PARAMETERS = ["z", "y", "x"]

    def _fit(self):
        p = self.params
        rotation = rotation_matrix(p["zrotate"], p["yrotate"], p["xrotate"])
        self.matrix = rotation.T if p["invert"] else rotation
        self.shift = np.asarray([-p["z"], -p["y"], -p["x"]])

    def invert(self):
        p = self.params
        # Inverse of p @ M + t is q @ M.T - t @ M.T (M is orthogonal).
        offset = np.asarray([p["z"], p["y"], p["x"]]) @ self.matrix.T
        return self.__class__(
            zrotate=p["zrotate"],
            yrotate=p["yrotate"],
            xrotate=p["xrotate"],
            z=-offset[0],
            y=-offset[1],
            x=-offset[2],
            invert=(not p["invert"]),
        )
class Affine(AffineTransform, PointTransform):
    """Full affine transform fit from corresponding 3D points.

    The linear part is fit by least squares (normal equations) with an
    intercept column. With ``invert=True`` the fit runs from end to start and
    the matrix is inverted afterwards. The shift is always the mean residual
    mapping ``points_start`` onto ``points_end``.
    """

    NAME = "Affine"
    SHORTCUT_KEY = "A"
    SORT_WEIGHT = -97
    DEFAULT_PARAMETERS = {"invert": False}

    def _fit(self):
        inverted = self.params["invert"]
        src, dst = (self.points_end, self.points_start) if inverted else (self.points_start, self.points_end)
        design = np.hstack([np.ones((src.shape[0], 1)), src])
        coefs = np.linalg.inv(design.T @ design) @ design.T @ dst
        linear = coefs[1:]  # drop the intercept row
        self.matrix = np.linalg.inv(linear) if inverted else linear
        self.shift = np.mean(self.points_start @ self.matrix - self.points_end, axis=0)

    def invert(self):
        return self.__class__(
            points_start=self.points_end,
            points_end=self.points_start,
            invert=(not self.params["invert"]),
        )
class AffineParametric(AffineTransform, Transform):
    """Affine transform defined by fixed translation/rotation/scale/shear.

    The matrix is ``rotation @ scale @ shear``; points map as
    ``p @ matrix + (z, y, x)``.

    Parameters
    ----------
    z, y, x : float, optional
        Translation offsets in voxel coordinates.
    zrotate, yrotate, xrotate : float, optional
        Clockwise rotation angles (degrees) about each axis.
    zscale, yscale, xscale : float, optional
        Axis-wise scale factors.
    yzshear, xzshear, xyshear : float, optional
        Shear coefficients.
    invert : bool, optional
        If ``True``, inverts the combined affine matrix.
    """

    NAME = "Affine"
    SHORTCUT_KEY = "a"
    SORT_WEIGHT = -94
    DEFAULT_PARAMETERS = {"z": 0.0, "y": 0.0, "x": 0.0, "zrotate": 0.0, "yrotate": 0.0, "xrotate": 0.0, "zscale": 1.0, "yscale": 1.0, "xscale": 1.0, "yzshear": 0.0, "xzshear": 0.0, "xyshear": 0.0, "invert": False}
    GUI_DRAG_PARAMETERS = ["z", "y", "x"]

    # Parameters copied unchanged by invert(); translation and invert change.
    _SHAPE_KEYS = ("zrotate", "yrotate", "xrotate", "zscale", "yscale", "xscale", "yzshear", "xzshear", "xyshear")

    def _fit(self):
        p = self.params
        rotation = rotation_matrix(p["zrotate"], p["yrotate"], p["xrotate"])
        scale = np.asarray([[p["zscale"], 0, 0], [0, p["yscale"], 0], [0, 0, p["xscale"]]])
        shear = np.asarray([[1, 0, 0], [p["yzshear"], 1, 0], [p["xzshear"], p["xyshear"], 1]])
        combined = rotation @ scale @ shear
        self.matrix = np.linalg.inv(combined) if p["invert"] else combined
        self.shift = np.asarray([-p["z"], -p["y"], -p["x"]])

    def invert(self):
        p = self.params
        # Inverse of p @ M + t is q @ inv(M) - t @ inv(M).
        offset = np.asarray([p["z"], p["y"], p["x"]]) @ np.linalg.inv(self.matrix)
        shape_params = {k: p[k] for k in self._SHAPE_KEYS}
        return self.__class__(
            z=-offset[0],
            y=-offset[1],
            x=-offset[2],
            invert=(not p["invert"]),
            **shape_params,
        )
class MatrixParametric(AffineTransform, Transform):
    """Affine transform specified directly by matrix coefficients.

    Points map as ``p @ A + (z, y, x)`` where ``A`` is the 3x3 matrix below.

    Parameters
    ----------
    a11, a12, a13, a21, a22, a23, a31, a32, a33 : float, optional
        Entries of the 3x3 affine matrix in row-major order.
    z, y, x : float, optional
        Translation offsets in voxel coordinates.
    """

    NAME = "Transformation matrix"
    DEFAULT_PARAMETERS = {"a11": 1, "a12": 0, "a13": 0, "a21": 0, "a22": 1, "a23": 0, "a31": 0, "a32": 0, "a33": 1, "x": 0, "y": 0, "z": 0}
    GUI_DRAG_PARAMETERS = ["z", "y", "x"]
    SHORTCUT_KEY = "m"
    SORT_WEIGHT = -90

    def _fit(self):
        p = lambda num: self.params[f"a{num}"]
        self.matrix = np.asarray([[p(11), p(12), p(13)], [p(21), p(22), p(23)], [p(31), p(32), p(33)]])
        self.shift = np.asarray([-self.params["z"], -self.params["y"], -self.params["x"]])

    def invert(self):
        """Return the exact inverse affine map.

        The inverse of ``p @ A + t`` is ``q @ inv(A) - t @ inv(A)``. The
        previous implementation used ``A.T``, which is only the inverse for
        orthogonal matrices and was wrong for any scale or shear.

        Raises
        ------
        numpy.linalg.LinAlgError
            If the matrix is singular.
        """
        inverse = np.linalg.inv(self.matrix)
        offset = np.asarray([self.params["z"], self.params["y"], self.params["x"]]) @ inverse
        entries = {f"a{i + 1}{j + 1}": float(inverse[i, j]) for i in range(3) for j in range(3)}
        return self.__class__(z=-float(offset[0]), y=-float(offset[1]), x=-float(offset[2]), **entries)
class LaminarAffine(AffineTransform, PointTransform):
    """Laminar affine fit for section-like data.

    Splits fitting into high-variance laminar components plus a low-variance
    normal-depth component, to reduce skew in thin volumes. The laminar basis
    comes from the SVD of the cross-covariance of the demeaned matched points.
    """

    NAME = "Laminar affine"
    SHORTCUT_KEY = "P"
    SORT_WEIGHT = -96
    DEFAULT_PARAMETERS = {"invert": False}

    def _fit(self):
        inverted = self.params["invert"]
        src, dst = (self.points_end, self.points_start) if inverted else (self.points_start, self.points_end)

        # PCA-style basis from the cross-covariance of the demeaned points.
        src_c = src - np.mean(src, axis=0)
        dst_c = dst - np.mean(dst, axis=0)
        U, S, V = np.linalg.svd(src_c.T @ dst_c)
        assert np.all(np.sort(S)[::-1] == S), "SVD is not sorted correctly"

        def _regress(x, y):
            # Least squares with an intercept column (row 0 of the result).
            design = np.hstack([np.ones((x.shape[0], 1)), x])
            return np.linalg.inv(design.T @ design) @ design.T @ y

        # 2D regression in the two high-variance (in-plane) directions.
        plane_coefs = _regress(src_c @ U[:, :2], dst_c @ V.T[:, :2])
        # 1D regression along the low-variance (depth) direction.
        depth_coefs = _regress(src_c @ U[:, [2]], dst_c @ V.T[:, [2]])

        # Combine the two in the SVD basis.
        core = scipy.linalg.block_diag(plane_coefs[1:], 0) + np.diag([0, 0, depth_coefs[1, 0]])
        self.matrix = U @ core @ V
        if inverted:
            self.matrix = np.linalg.inv(self.matrix)
        self.shift = np.mean(self.points_start @ self.matrix - self.points_end, axis=0)

    def invert(self):
        return self.__class__(
            points_start=self.points_end,
            points_end=self.points_start,
            invert=(not self.params["invert"]),
        )
class FlipParametric(AffineTransform, Transform):
    """Axis-flip transform controlled by boolean flip parameters.

    Parameters
    ----------
    z, y, x : bool, optional
        Whether to flip along each axis.
    zthickness, ythickness, xthickness : float or int, optional
        Axis extents used to compute the post-flip shift; only the extents of
        flipped axes contribute.
    """

    NAME = "Flip"
    DEFAULT_PARAMETERS = {"z": False, "y": False, "x": False,
                          "zthickness": 0, "ythickness": 0, "xthickness": 0}

    def _fit(self):
        axes = ("z", "y", "x")
        signs = [-1 if self.params[ax] else 1 for ax in axes]
        offsets = [self.params[ax + "thickness"] * int(self.params[ax]) for ax in axes]
        self.matrix = np.diag(signs)
        self.shift = -np.asarray(offsets)

    def invert(self):
        # Applying the same flip twice restores the original coordinates.
        return self
class RescaleParametric(AffineTransform, Transform):
    """Axis-wise rescaling transform with fixed scale parameters.

    Parameters
    ----------
    z, y, x : float, optional
        Scale factors for each axis; the inverse uses their reciprocals.
    """

    NAME = "Rescale"
    DEFAULT_PARAMETERS = {"z": 1.0, "y": 1.0, "x": 1.0}
    SHORTCUT_KEY = "z"

    def _fit(self):
        scales = [self.params[ax] for ax in ("z", "y", "x")]
        self.matrix = np.diag(scales)
        self.shift = np.asarray((0, 0, 0))

    def invert(self):
        reciprocals = {ax: 1 / self.params[ax] for ax in ("z", "y", "x")}
        return self.__class__(**reciprocals)
class Triangulation(PointTransform):
    """Nonlinear 3D deformation using piecewise-affine triangulation.

    This transform warps a 3D volume by building a Delaunay triangulation over
    control points, then applying a local affine map per tetrahedron.

    Parameters
    ----------
    invert : bool, optional
        Selects which point set is triangulated. With ``invert=True`` (the
        default) the triangulation is built on the end points and point
        containment is tested tetrahedron by tetrahedron in start space.
        Otherwise the start points are triangulated and ``find_simplex`` is
        used directly.

    Notes
    -----
    The implementation supports both directions of mapping. For one direction
    it can use ``find_simplex`` directly; for the other it manually checks
    tetrahedron containment, because SciPy does not provide the exact
    arbitrary-triangulation path needed for this workflow.
    """

    NAME = "3D triangulation"
    SHORTCUT_KEY = "V"
    SORT_WEIGHT = 100
    # Start with inverted because inverted is slower for points and faster for images
    DEFAULT_PARAMETERS = {"invert": True}

    def _fit(self):
        """Compute ``all_points_start``/``all_points_end``, including hull pseudo points."""
        # To avoid out-of-bounds queries we add pseudo points: take the convex
        # hull, center it, scale it by SCALE_FACTOR, and shift it back. To
        # avoid degenerate triangulations for near-coplanar point clouds, the
        # points are first jittered by a small deterministic amount. Each
        # pseudo point is displaced by the same vector as the hull point it
        # was derived from.
        SCALE_FACTOR = 1000
        if self.params["invert"]:
            before, after = self.points_end, self.points_start
        else:
            before, after = self.points_start, self.points_end
        _rns = np.random.RandomState(0).randn(*before.shape) * .0001  # Break symmetry
        before = before + _rns
        t = scipy.spatial.Delaunay(before)  # Triangulation
        assert np.all(t.points == before), "Coplannar points"
        hull_points_inds = np.unique(t.convex_hull.flatten())
        hull_points_vecs = after[hull_points_inds] - before[hull_points_inds]
        hull_mean_shift = np.mean(before[hull_points_inds], axis=0)
        scaled_hull = SCALE_FACTOR * (before[hull_points_inds] - hull_mean_shift) + hull_mean_shift
        if self.params["invert"]:
            self.pseudopoints_end = scaled_hull
            self.pseudopoints_start = scaled_hull + hull_points_vecs
        else:
            self.pseudopoints_start = scaled_hull
            self.pseudopoints_end = scaled_hull + hull_points_vecs
        self.all_points_start = np.concatenate([self.points_start, self.pseudopoints_start])
        self.all_points_end = np.concatenate([self.points_end, self.pseudopoints_end])

    def _transform(self, points):
        """Map ``points`` with the affine map of the tetrahedron containing each point.

        Raises
        ------
        AssertionError
            If some point is not contained in any tetrahedron.
        """
        points = np.asarray(points)
        start = self.all_points_start
        end = self.all_points_end
        tri_points = end if self.params["invert"] else start
        _rns = np.random.RandomState(0).randn(*tri_points.shape) * .0001  # Break symmetry
        delaunay = scipy.spatial.Delaunay(tri_points + _rns)
        # The jitter can be negative, so compare magnitudes.
        assert np.max(np.abs(delaunay.points - tri_points)) < .1, "Wrong order of Delaunay triangulation, are some points coplannar?"
        newpoints = np.zeros_like(points) * np.nan

        def map_simplex(mask, simp):
            # Solve for the affine map (in homogeneous coordinates) taking the
            # tetrahedron's start vertices to its end vertices, and apply it
            # to the selected points.
            coefs_rhs = np.concatenate([start[simp], np.ones(len(simp))[:, None]], axis=1)
            params = np.linalg.inv(coefs_rhs) @ end[simp]
            homogeneous = np.concatenate(
                [points[mask], np.ones(np.count_nonzero(mask))[:, None]], axis=1)
            newpoints[mask] = homogeneous @ params

        if self.params['invert']:
            # The triangulation lives in end space, so containment has to be
            # tested tetrahedron by tetrahedron in start space.
            for simp in delaunay.simplices:
                insimp = scipy.spatial.Delaunay(start[simp]).find_simplex(points) >= 0
                if np.any(insimp):
                    map_simplex(insimp, simp)
            assert not np.any(np.isnan(newpoints)), "Point was outside of simplex or invalid input points"
        else:
            # For the non-inverted case, we can use the original triangulation and improve performance
            insimp = delaunay.find_simplex(points)
            assert np.all(insimp >= 0), "Points outside domain, increase scale factor in code"
            # Only visit simplices that actually contain query points.
            for i in np.unique(insimp):
                map_simplex(insimp == i, delaunay.simplices[i])
            assert not np.any(np.isnan(newpoints)), "Not sure why this should ever happen?"
        return newpoints

    def invert(self):
        return self.__class__(invert=(not self.params["invert"]),
                              points_start=self.points_end,
                              points_end=self.points_start)
class LaminarTriangulation(PointTransform):
    """Nonlinear laminar triangulation in 3D.

    This is generally the recommended nonlinear transform for mostly flat
    section-like 3D data (broad in two dimensions, thinner in the third).

    If ``normal_z``, ``normal_y``, and ``normal_x`` are all zero, the normal
    is estimated automatically from the input points (the direction of least
    variance). Otherwise the given vector is normalised and used.

    Notes
    -----
    The transform triangulates points after projection into the fitted
    laminar plane, then computes local 3D affine maps using those projected
    triangles plus a normal-direction anchor so depth is handled consistently.

    As with ``Triangulation``, one mapping direction can use ``find_simplex``
    directly, while the other uses manual triangle containment checks due to
    SciPy API limitations for this specific inverse workflow.
    """

    NAME = "Laminar triangulation"
    SHORTCUT_KEY = "N"
    SORT_WEIGHT = 99
    # Start with inverted because inverted is slower for points and faster for images
    DEFAULT_PARAMETERS = {"invert": True, "normal_z": 0.0, "normal_y": 0.0, "normal_x": 0.0}

    def _fit(self):
        """Estimate the plane basis and add hull pseudo points."""
        # To avoid out-of-bounds queries we add pseudo points: take the convex
        # hull (computed in the plane basis), center it, scale it by
        # SCALE_FACTOR, and shift it back. Each pseudo point is displaced by
        # the same vector as the hull point it was derived from.
        SCALE_FACTOR = 100
        if self.params["invert"]:
            before, after = self.points_end, self.points_start
        else:
            before, after = self.points_start, self.points_end

        # Automatically find the normal if it wasn't manually specified
        if self.params['normal_z'] == 0 and self.params['normal_y'] == 0 and self.params['normal_x'] == 0:
            demeaned_before = before - np.mean(before, axis=0)
            U, S, V = np.linalg.svd(demeaned_before.T @ demeaned_before)
            assert np.all(np.sort(S)[::-1] == S), "SVD is not sorted correctly"
            self.normal = U[:, 2]
        else:
            # dtype=float: an all-integer normal (e.g. normal_z=1) would make
            # the in-place division below fail with a casting error.
            self.normal = np.asarray([self.params["normal_z"],
                                      self.params["normal_y"],
                                      self.params["normal_x"]], dtype=float)
            self.normal /= np.sqrt(np.sum(np.square(self.normal)))

        # Pick a helper axis that is not (anti-)parallel to the normal, then
        # Gram-Schmidt it into the plane. Comparing |n_z| ensures a normal
        # pointing along -z also switches to the y axis instead of collapsing
        # to a zero vector.
        if abs(self.normal[0]) < .99:
            vec1 = np.asarray([1., 0, 0])
        else:
            vec1 = np.asarray([0, 1., 0])
        vec1 -= (vec1 @ self.normal) * self.normal
        vec1 /= np.sqrt(np.sum(np.square(vec1)))
        # Columns of B span the laminar plane; suffix "B" marks coordinates in this basis.
        self.B = np.asarray([vec1, np.cross(self.normal, vec1)]).T

        beforeB = before @ self.B
        t = scipy.spatial.Delaunay(beforeB)  # Triangulation
        assert np.all(t.points == beforeB), "Coplannar points"
        hull_points_inds = np.unique(t.convex_hull.flatten())
        hull_points_vecs = after[hull_points_inds] - before[hull_points_inds]
        hull_mean_shift = np.mean(before[hull_points_inds], axis=0)
        scaled_hull = SCALE_FACTOR * (before[hull_points_inds] - hull_mean_shift) + hull_mean_shift
        if self.params["invert"]:
            self.pseudopoints_end = scaled_hull
            self.pseudopoints_start = scaled_hull + hull_points_vecs
        else:
            self.pseudopoints_start = scaled_hull
            self.pseudopoints_end = scaled_hull + hull_points_vecs
        self.all_points_start = np.concatenate([self.points_start, self.pseudopoints_start])
        self.all_points_end = np.concatenate([self.points_end, self.pseudopoints_end])

    def _transform(self, points):
        """Map ``points`` via per-triangle 3D affine maps built in the plane basis.

        Raises
        ------
        AssertionError
            If some point is not contained in any triangle.
        """
        # Project all points into the basis defined in _fit ("B" suffix),
        # triangulate in that 2D space, and compute the new points in 3D by
        # holding the interpolation constant in the direction normal to it.
        points = np.asarray(points)
        pointsB = points @ self.B
        start = self.all_points_start
        end = self.all_points_end
        startB = start @ self.B
        tri_pointsB = (end if self.params["invert"] else start) @ self.B
        delaunay = scipy.spatial.Delaunay(tri_pointsB)
        assert np.all(delaunay.points == tri_pointsB), "Coplannar points"
        newpoints = np.zeros_like(points) * np.nan

        def map_simplex(mask, simp):
            # Add an extra vertex offset along the normal so the regression
            # yields a full 3D-to-3D affine map, then apply it.
            anchor_start = np.concatenate([start[simp], start[[simp[0]]] + self.normal], axis=0)
            anchor_end = np.concatenate([end[simp], end[[simp[0]]] + self.normal], axis=0)
            coefs_rhs = np.concatenate([anchor_start, np.ones(len(simp) + 1)[:, None]], axis=1)
            params = np.linalg.inv(coefs_rhs) @ anchor_end
            homogeneous = np.concatenate(
                [points[mask], np.ones(np.count_nonzero(mask))[:, None]], axis=1)
            newpoints[mask] = homogeneous @ params

        if self.params['invert']:
            if points.shape[0] > 1000:
                print("Warning, using slow transform to transform many points, try inverting")
            for simp in delaunay.simplices:
                insimp = scipy.spatial.Delaunay(startB[simp]).find_simplex(pointsB) >= 0
                if np.any(insimp):
                    map_simplex(insimp, simp)
            assert not np.any(np.isnan(newpoints)), "Point was outside of simplex or invalid input points"
        else:
            # For the non-inverted case, we can use the original triangulation and improve performance
            insimp = delaunay.find_simplex(pointsB)
            assert np.all(insimp >= 0), "Points outside domain, increase scale factor in code"
            # Only visit triangles that actually contain query points.
            for i in np.unique(insimp):
                map_simplex(insimp == i, delaunay.simplices[i])
            assert not np.any(np.isnan(newpoints)), "Not sure why this should ever happen?"
        return newpoints

    def invert(self):
        return self.__class__(invert=(not self.params["invert"]),
                              points_start=self.points_end,
                              points_end=self.points_start,
                              normal_z=self.params["normal_z"],
                              normal_y=self.params["normal_y"],
                              normal_x=self.params["normal_x"])
########## Composing transforms ##########
def compose_transforms(a, b):
    """Compose two transforms into one transform chain.

    This is the implementation behind ``a + b``. It handles composing
    transform instances, and also the mixed case where ``a`` is an
    already-fit transform instance and ``b`` is a transform class.

    Parameters
    ----------
    a : Transform
        Left-hand transform. Must be an instantiated transform.
    b : Transform or type
        Right-hand transform, either as an instance or as a transform class.

    Returns
    -------
    Transform or type
        * If either side is an ``Identity`` instance, the other side is
          returned unchanged.
        * If both are instances, an instantiated composed transform.
        * If ``b`` is a class, a composed transform class whose constructor
          accepts ``b``'s parameters (and points, for point transforms).

    Raises
    ------
    NotImplementedError
        If ``a`` is not a transform instance.

    Notes
    -----
    For affine + affine composition, the composed class combines matrix and
    shift directly (``a.matrix @ b.matrix`` and
    ``a.shift @ b.matrix + b.shift``). Otherwise point mapping is composed:
    ``a`` is applied first, then ``b``.
    """
    def point_kwargs(points_start, points_end):
        # Forward the correspondence points only when both were supplied.
        if points_start is None or points_end is None:
            return {}
        return {"points_start": points_start, "points_end": points_end}

    # Identity on either side is simplified away.
    if isinstance(a, Identity):
        return b
    if isinstance(b, Identity):
        return a

    # Two instances: build the composed class from b's type and re-create b
    # inside it with the same parameters (and points, if any).
    if isinstance(a, Transform) and isinstance(b, Transform):
        if isinstance(b, PointTransform):
            return compose_transforms(a, b.__class__)(points_start=b.points_start,
                                                      points_end=b.points_end,
                                                      **b.params)
        return compose_transforms(a, b.__class__)(**b.params)

    if isinstance(a, Transform) and not isinstance(b, Transform):
        inherit = PointTransform if issubclass(b, PointTransform) else Transform
        if isinstance(a, AffineTransform) and issubclass(b, AffineTransform):
            class ComposedPartialAffine(AffineTransform, inherit):
                DEFAULT_PARAMETERS = b.DEFAULT_PARAMETERS
                GUI_DRAG_PARAMETERS = b.GUI_DRAG_PARAMETERS

                def __new__(cls, *args, **kwargs):
                    # Composing with the Identity class yields ``a`` itself.
                    # ``a`` is not an instance of this class, so Python does
                    # not call ``__init__`` on it.
                    if b is Identity:
                        return a
                    return super(ComposedPartialAffine, cls).__new__(cls)

                def __init__(self, points_start=None, points_end=None, *args, **kwargs):
                    extra_args = point_kwargs(points_start, points_end)
                    self.b_type = b
                    # self.b must exist before the base __init__ runs _fit().
                    self.b = b(*args, **kwargs, **extra_args)
                    super().__init__(*args, **kwargs, **extra_args)

                def _fit(self):
                    self.matrix = a.matrix @ self.b.matrix
                    self.shift = a.shift @ self.b.matrix + self.b.shift

                def __repr__(self):
                    return repr(a) + " + " + repr(self.b)

                @staticmethod
                def pretransform(*args, **kwargs):
                    return a

                def invert(self):
                    return self.b.invert() + a.invert()

            return ComposedPartialAffine

        class ComposedPartial(inherit):
            DEFAULT_PARAMETERS = b.DEFAULT_PARAMETERS
            GUI_DRAG_PARAMETERS = b.GUI_DRAG_PARAMETERS

            def __new__(cls, *args, **kwargs):
                # Composing with the Identity class yields ``a`` itself.
                if b is Identity:
                    return a
                return super(ComposedPartial, cls).__new__(cls)

            def __init__(self, points_start=None, points_end=None, *args, **kwargs):
                extra_args = point_kwargs(points_start, points_end)
                self.b_type = b
                self.b = b(*args, **kwargs, **extra_args)
                # Use b's resolved parameters (defaults filled in) for this object.
                super().__init__(**extra_args, **self.b.params)

            def _transform(self, points):
                return self.b.transform(a.transform(points))

            def inverse_transform(self, points):
                return a.inverse_transform(self.b.inverse_transform(points))

            def __repr__(self):
                return repr(a) + " + " + repr(self.b)

            def invert(self):
                return self.b.invert() + a.invert()

            @staticmethod
            def pretransform(*args, **kwargs):
                return a

        return ComposedPartial

    raise NotImplementedError("Invalid composition")
########## Deprecated ##########
class Flip(AffineTransform, Transform):  # Deprecated
    """Deprecated parametric axis-flip transform.

    Unlike ``FlipParametric``, the shift here uses ``thickness - 1``
    (clamped at zero) for each flipped axis and is not negated.
    """

    NAME = "Flip"
    DEFAULT_PARAMETERS = {"z": False, "y": False, "x": False,
                          "zthickness": 0, "ythickness": 0, "xthickness": 0}

    def _fit(self):
        axes = ("z", "y", "x")
        signs = []
        offsets = []
        for ax in axes:
            signs.append(-1 if self.params[ax] else 1)
            offsets.append(max(0, self.params[ax + "thickness"] - 1) * int(self.params[ax]))
        self.matrix = np.diag(signs)
        self.shift = np.asarray(offsets)

    def invert(self):
        return self
# This doesn't work very well
class DistanceWeightedAverageGaussian(PointTransformNoAnalyticInverse):  # Deprecated
    """Deprecated nonlinear displacement-field transform with Gaussian weighting.

    Each control point contributes its displacement ``end - start``, weighted
    by an isotropic Gaussian centred at its start position with covariance
    ``extent * I``. The weighted sum is normalised by the total weight, with
    a tiny epsilon for numerical stability.
    """

    DEFAULT_PARAMETERS = {"extent": 1, "invert": False}

    def _transform(self, points, points_start, points_end):
        points = np.asarray(points, dtype="float")
        baseline = np.zeros_like(points[:, 0])
        pos = np.zeros_like(points)
        cov = np.eye(3) * self.params["extent"]
        for i in range(len(points_start)):
            # Evaluate the Gaussian once and reuse it for the normalising
            # weight and all three displacement components.
            weight = scipy.stats.multivariate_normal(points_start[i], cov).pdf(points)
            baseline += weight
            for j in range(3):
                pos[:, j] += weight * (points_end[i][j] - points_start[i][j])
        epsilon = 1e-100  # For numerical stability
        pos += np.mean(points_end - points_start, axis=0, keepdims=True) * epsilon
        pos /= (baseline[:, None] + epsilon)
        return points + pos