Module deformation_inversion_layer.interpolator.algorithm

Core interpolation algorithms

Expand source code
"""Core interpolation algorithms"""

from typing import List, Optional, Tuple

from torch import Tensor
from torch import device as torch_device
from torch import dtype as torch_dtype
from torch import linspace, meshgrid, stack, tensor
from torch.jit import script
from torch.nn.functional import grid_sample


@script
def _move_channels_last(tensor_to_modify: Tensor, num_channel_dims: int = 1) -> Tensor:
    if tensor_to_modify.ndim == num_channel_dims:
        return tensor_to_modify
    return tensor_to_modify.permute(
        [0]
        + list(range(num_channel_dims + 1, tensor_to_modify.ndim))
        + list(range(1, num_channel_dims + 1))
    )


@script
def _index_by_channel_dims(n_total_dims: int, channel_dim_index: int, n_channel_dims: int) -> int:
    if n_total_dims < n_channel_dims:
        raise RuntimeError("Number of channel dimensions do not match")
    if n_total_dims == n_channel_dims:
        return channel_dim_index
    return channel_dim_index + 1


@script
def _num_spatial_dims(n_total_dims: int, n_channel_dims: int) -> int:
    if n_total_dims < n_channel_dims:
        raise RuntimeError("Number of channel dimensions do not match")
    if n_total_dims <= n_channel_dims + 1:
        return 0
    return n_total_dims - n_channel_dims - 1


@script
def _convert_voxel_to_normalized_coordinates(
    coordinates: Tensor, volume_shape: Optional[List[int]] = None
) -> Tensor:
    """Convert voxel coordinates to the [-1, 1] range of grid_sample (align_corners=True)

    Voxel coordinate 0 maps to -1 and voxel coordinate size - 1 maps to 1 along
    each dimension.

    Args:
        coordinates: Voxel coordinates with the coordinate components along the
            channel dimension, e.g. shape (batch_size, n_dims, *spatial_shape),
            or shape (n_dims,) without batch and spatial dimensions
        volume_shape: Shape of the volume the coordinates refer to (length n_dims).
            If None, the last n_dims dimensions of coordinates are used.

    Returns:
        Normalized coordinates with the same shape as coordinates
    """
    channel_dim = _index_by_channel_dims(coordinates.ndim, channel_dim_index=0, n_channel_dims=1)
    n_spatial_dims = _num_spatial_dims(n_total_dims=coordinates.ndim, n_channel_dims=1)
    n_dims = coordinates.size(channel_dim)
    inferred_volume_shape = coordinates.shape[-n_dims:] if volume_shape is None else volume_shape
    # Shape (n_dims, 1, ..., 1) so that the per-dimension values broadcast
    # against the coordinate components.
    add_spatial_dims_view = (-1,) + n_spatial_dims * (1,)
    volume_shape_tensor = tensor(
        inferred_volume_shape, dtype=coordinates.dtype, device=coordinates.device
    ).view(add_spatial_dims_view)
    # A dimension of size 1 would make the denominator zero and produce inf/nan
    # coordinates (and nan gradients); 0 / 0 is nan even at the valid voxel 0.
    # With align_corners=True grid_sample maps every normalized coordinate to
    # index 0 along a size-1 dimension, so clamping the denominator to 1 only
    # keeps the values finite. Dimensions of size >= 2 are unaffected.
    voxel_extent = (volume_shape_tensor - 1).clamp(min=1)
    coordinate_grid_start = tensor(-1.0, dtype=coordinates.dtype, device=coordinates.device).view(
        add_spatial_dims_view
    )
    coordinate_grid_end = tensor(1.0, dtype=coordinates.dtype, device=coordinates.device).view(
        add_spatial_dims_view
    )
    output = (
        coordinates / voxel_extent * (coordinate_grid_end - coordinate_grid_start)
        + coordinate_grid_start
    )
    return output


@script
def _broadcast_batch_size(tensor_1: Tensor, tensor_2: Tensor) -> Tuple[Tensor, Tensor]:
    batch_size = max(tensor_1.size(0), tensor_2.size(0))
    if tensor_1.size(0) == 1 and batch_size != 1:
        tensor_1 = tensor_1[0].expand((batch_size,) + tensor_1.shape[1:])
    elif tensor_2.size(0) == 1 and batch_size != 1:
        tensor_2 = tensor_2[0].expand((batch_size,) + tensor_2.shape[1:])
    elif tensor_1.size(0) != tensor_2.size(0) and batch_size != 1:
        raise ValueError("Can not broadcast batch size")
    return tensor_1, tensor_2


@script
def _match_grid_shape_to_dims(grid: Tensor) -> Tensor:
    batch_size = grid.size(0)
    n_dims = grid.size(1)
    grid_shape = grid.shape[2:]
    dim_matched_grid_shape = (
        (1,) * max(0, n_dims - grid.ndim + 1) + grid_shape[: n_dims - 1] + (-1,)
    )
    return grid.view(
        (
            batch_size,
            n_dims,
        )
        + dim_matched_grid_shape
    )


def interpolate(
    volume: Tensor, grid: Tensor, mode: str = "bilinear", padding_mode: str = "border"
) -> Tensor:
    """Interpolate in voxel coordinates

    Args:
        volume: Interpolated volume with shape
            (batch_size, [channel_1, ..., channel_n, ]dim_1, ..., dim_{n_dims})
        grid: Grid defining interpolation locations with shape (batch_size, n_dims, *target_shape).
            A single location with shape (n_dims,) is also accepted.
        mode: Interpolation mode, passed on to torch.nn.functional.grid_sample
        padding_mode: Padding mode defining extrapolation behaviour, passed on to
            torch.nn.functional.grid_sample

    Returns:
        Volume interpolated at grid locations with shape
            (batch_size, channel_1, ..., channel_n, *target_shape)
    """
    if grid.ndim == 1:
        # Single location without a batch dimension
        grid = grid[None]
    n_dims = grid.size(1)
    channel_shape = volume.shape[1:-n_dims]
    volume_shape = volume.shape[-n_dims:]
    target_shape = grid.shape[2:]
    dim_matched_grid = _match_grid_shape_to_dims(grid)
    normalized_grid = _convert_voxel_to_normalized_coordinates(dim_matched_grid, list(volume_shape))
    # Collapse all channel dimensions into one. reshape (instead of view) also
    # accepts non-contiguous volumes, e.g. outputs of permute, for which view
    # raises when several channel dimensions have to be merged.
    simplified_volume = volume.reshape((volume.size(0), -1) + volume_shape)
    # grid_sample reads the first grid component as an index into the LAST spatial
    # dimension of its input, whereas the grid lists coordinates in the order
    # dim_1, ..., dim_{n_dims}; reversing the spatial dimensions aligns the two.
    permuted_volume = simplified_volume.permute(
        [0, 1] + list(range(simplified_volume.ndim - 1, 2 - 1, -1))
    )
    permuted_grid = _move_channels_last(normalized_grid, 1)
    permuted_volume, permuted_grid = _broadcast_batch_size(permuted_volume, permuted_grid)
    return grid_sample(
        input=permuted_volume,
        grid=permuted_grid,
        align_corners=True,
        mode=mode,
        padding_mode=padding_mode,
    ).reshape((-1,) + channel_shape + target_shape)


def generate_voxel_coordinate_grid(
    shape: List[int], device: torch_device, dtype: Optional[torch_dtype] = None
) -> Tensor:
    """Build a dense grid holding the voxel coordinates of every point of a volume

    Component i of the grid holds the index along dimension i, ranging
    from 0 to shape[i] - 1 ("ij" indexing).

    Args:
        shape: Shape of the grid
        device: Device on which the grid is created
        dtype: Data type of the grid; the torch default dtype is used when None

    Returns:
        Voxel coordinate grid with shape (1, len(shape), dim_1, ..., dim_{len(shape)})
    """
    sizes = [int(dim_size) for dim_size in shape]
    axes = [
        linspace(start=0, end=size - 1, steps=size, device=device, dtype=dtype)
        for size in sizes
    ]
    per_dimension_grids = meshgrid(axes, indexing="ij")
    return stack(per_dimension_grids, dim=0).unsqueeze(0)

Functions

def generate_voxel_coordinate_grid(shape: List[int], device: torch.device, dtype: Optional[torch.dtype] = None) -> torch.Tensor

Generate voxel coordinate grid

Args

shape
Shape of the grid
device
Device of the grid
dtype
Data type of the grid

Returns

Voxel coordinate grid with shape (1, len(shape), dim_1, …, dim_{len(shape)})

Expand source code
def generate_voxel_coordinate_grid(
    shape: List[int], device: torch_device, dtype: Optional[torch_dtype] = None
) -> Tensor:
    """Build a dense grid holding the voxel coordinates of every point of a volume

    Component i of the grid holds the index along dimension i, ranging
    from 0 to shape[i] - 1 ("ij" indexing).

    Args:
        shape: Shape of the grid
        device: Device on which the grid is created
        dtype: Data type of the grid; the torch default dtype is used when None

    Returns:
        Voxel coordinate grid with shape (1, len(shape), dim_1, ..., dim_{len(shape)})
    """
    sizes = [int(dim_size) for dim_size in shape]
    axes = [
        linspace(start=0, end=size - 1, steps=size, device=device, dtype=dtype)
        for size in sizes
    ]
    per_dimension_grids = meshgrid(axes, indexing="ij")
    return stack(per_dimension_grids, dim=0).unsqueeze(0)
def interpolate(volume: torch.Tensor, grid: torch.Tensor, mode: str = 'bilinear', padding_mode: str = 'border') -> torch.Tensor

Interpolate in voxel coordinates

Args

volume
Interpolated volume with shape (batch_size, [channel_1, …, channel_n, ]dim_1, …, dim_{n_dims})
grid
Grid defining interpolation locations with shape (batch_size, n_dims, *target_shape)
mode
Interpolation mode
padding_mode
Padding mode defining extrapolation behaviour

Returns

Volume interpolated at grid locations with shape (batch_size, channel_1, …, channel_n, *target_shape)

Expand source code
def interpolate(
    volume: Tensor, grid: Tensor, mode: str = "bilinear", padding_mode: str = "border"
) -> Tensor:
    """Interpolate in voxel coordinates

    Args:
        volume: Interpolated volume with shape
            (batch_size, [channel_1, ..., channel_n, ]dim_1, ..., dim_{n_dims})
        grid: Grid defining interpolation locations with shape (batch_size, n_dims, *target_shape).
            A single location with shape (n_dims,) is also accepted.
        mode: Interpolation mode, passed on to torch.nn.functional.grid_sample
        padding_mode: Padding mode defining extrapolation behaviour, passed on to
            torch.nn.functional.grid_sample

    Returns:
        Volume interpolated at grid locations with shape
            (batch_size, channel_1, ..., channel_n, *target_shape)
    """
    if grid.ndim == 1:
        # Single location without a batch dimension
        grid = grid[None]
    n_dims = grid.size(1)
    channel_shape = volume.shape[1:-n_dims]
    volume_shape = volume.shape[-n_dims:]
    target_shape = grid.shape[2:]
    dim_matched_grid = _match_grid_shape_to_dims(grid)
    normalized_grid = _convert_voxel_to_normalized_coordinates(dim_matched_grid, list(volume_shape))
    # Collapse all channel dimensions into one. reshape (instead of view) also
    # accepts non-contiguous volumes, e.g. outputs of permute, for which view
    # raises when several channel dimensions have to be merged.
    simplified_volume = volume.reshape((volume.size(0), -1) + volume_shape)
    # grid_sample reads the first grid component as an index into the LAST spatial
    # dimension of its input, whereas the grid lists coordinates in the order
    # dim_1, ..., dim_{n_dims}; reversing the spatial dimensions aligns the two.
    permuted_volume = simplified_volume.permute(
        [0, 1] + list(range(simplified_volume.ndim - 1, 2 - 1, -1))
    )
    permuted_grid = _move_channels_last(normalized_grid, 1)
    permuted_volume, permuted_grid = _broadcast_batch_size(permuted_volume, permuted_grid)
    return grid_sample(
        input=permuted_volume,
        grid=permuted_grid,
        align_corners=True,
        mode=mode,
        padding_mode=padding_mode,
    ).reshape((-1,) + channel_shape + target_shape)