Module deformation_inversion_layer.fixed_point_invert_deformation

Fixed point deformation inversion function

Expand source code
"""Fixed point deformation inversion function"""

from functools import partial
from typing import Optional

from torch import Tensor
from torch import dtype as torch_dtype
from torch import enable_grad, zeros_like
from torch.autograd import grad
from torch.autograd.function import Function, FunctionCtx, once_differentiable

from .fixed_point_iteration import (
    AndersonSolver,
    MaxElementWiseAbsStopCriterion,
    RelativeL2ErrorStopCriterion,
)
from .interface import FixedPointSolver, Interpolator
from .interpolator import LinearInterpolator
from .interpolator.algorithm import generate_voxel_coordinate_grid


class DeformationInversionArguments:
    """Configuration bundle for fixed point deformation inversion

    Every argument is optional; a ``None`` value selects the default described below.

    Args:
        interpolator: Interpolator used for sampling the input displacement field,
            a ``LinearInterpolator`` is created if not given
        forward_solver: Fixed point solver for the forward pass, if not given an
            ``AndersonSolver`` with ``MaxElementWiseAbsStopCriterion`` is created
        backward_solver: Fixed point solver for the backward pass, if not given an
            ``AndersonSolver`` with ``RelativeL2ErrorStopCriterion`` is created
        forward_dtype: Data type in which the forward pass solver is run, the data
            type of the input is used if not given
        backward_dtype: Data type in which the backward pass solver is run, the data
            type of the input is used if not given
    """

    def __init__(
        self,
        interpolator: Optional[Interpolator] = None,
        forward_solver: Optional[FixedPointSolver] = None,
        backward_solver: Optional[FixedPointSolver] = None,
        forward_dtype: Optional[torch_dtype] = None,
        backward_dtype: Optional[torch_dtype] = None,
    ) -> None:
        # Defaults are instantiated per instance, and only when actually needed.
        if interpolator is None:
            interpolator = LinearInterpolator()
        if forward_solver is None:
            forward_solver = AndersonSolver(stop_criterion=MaxElementWiseAbsStopCriterion())
        if backward_solver is None:
            backward_solver = AndersonSolver(stop_criterion=RelativeL2ErrorStopCriterion())
        self.interpolator = interpolator
        self.forward_solver = forward_solver
        self.backward_solver = backward_solver
        self.forward_dtype = forward_dtype
        self.backward_dtype = backward_dtype


def fixed_point_invert_deformation(
    displacement_field: Tensor,
    arguments: Optional[DeformationInversionArguments] = None,
    initial_guess: Optional[Tensor] = None,
    coordinates: Optional[Tensor] = None,
) -> Tensor:
    """Invert a displacement field using fixed point iteration

    Args:
        displacement_field: Displacement field describing the deformation to invert with shape
            (batch_size, n_dims, dim_1, ..., dim_{n_dims})
        arguments: Arguments for fixed point inversion, default-constructed
            DeformationInversionArguments is used if not given
        initial_guess: Initial guess for inverted displacement field. If not given, negative
            of the displacement field is used when coordinates is None, and zeros are
            used when coordinates is given
        coordinates: Voxel coordinates at which to compute the inverse with
            shape (batch_size, n_dims, *coordinates_shape), by default the inversion is done at
            the voxel coordinates of the input displacement field

    Returns:
        Inverted displacement field with shape (batch_size, n_dims, dim_1, ..., dim_{n_dims}) if
            coordinates is None, otherwise (batch_size, n_dims, *coordinates_shape)
    """
    if arguments is None:
        arguments = DeformationInversionArguments()
    return _FixedPointInvertDisplacementField.apply(
        displacement_field, arguments, initial_guess, coordinates
    )


class _FixedPointInvertDisplacementField(Function):  # pylint: disable=abstract-method
    """Autograd function inverting a displacement field with fixed point iteration

    Forward pass: the forward solver is asked for the fixed point ``v`` of

        v = -interpolator(displacement_field, coordinates + v)

    Backward pass: the solver iterations are not differentiated. With ``f`` denoting
    one evaluation of the right-hand side above at the solution and ``g`` the incoming
    gradient, the backward solver is asked for the fixed point ``w`` of

        w = (df/dv)^T w + g

    and the input gradients are then obtained as the vector-Jacobian products
    ``(df/d displacement_field)^T w`` and ``(df/d coordinates)^T w``.

    Inputs of ``apply`` are, in order: displacement_field, arguments, initial_guess and
    coordinates. Gradients are returned only for displacement_field and coordinates.
    """

    @staticmethod
    def _forward_fixed_point_iteration_step(
        inverted_displacement_field: Tensor,
        displacement_field: Tensor,
        interpolator: Interpolator,
        coordinates: Tensor,
    ) -> Tensor:
        """Evaluate the fixed point mapping once

        Returns the negated displacement field interpolated at
        ``coordinates + inverted_displacement_field``.
        """
        return -interpolator(
            volume=displacement_field,
            coordinates=coordinates + inverted_displacement_field,
        )

    @staticmethod
    def _forward_fixed_point_mapping(
        inverted_displacement_field: Tensor,
        out: Tensor,
        displacement_field: Tensor,
        interpolator: Interpolator,
        coordinates: Tensor,
    ) -> None:
        """Write one forward fixed point step for the current iterate into ``out``

        The first two parameters are left unbound by the ``partial`` handed to the
        forward solver; the result is written into ``out`` in-place.
        """
        out[:] = _FixedPointInvertDisplacementField._forward_fixed_point_iteration_step(
            inverted_displacement_field=inverted_displacement_field,
            displacement_field=displacement_field,
            interpolator=interpolator,
            coordinates=coordinates,
        )

    @staticmethod
    def _backward_fixed_point_mapping(
        vjp_estimate: Tensor,
        out: Tensor,
        inverted_displacement_field: Tensor,
        forward_fixed_point_output: Tensor,
        grad_output: Tensor,
    ) -> None:
        """Write one backward fixed point step for the current iterate into ``out``

        Computes the vector-Jacobian product of ``forward_fixed_point_output`` with
        respect to ``inverted_displacement_field`` for ``vjp_estimate`` and adds
        ``grad_output`` to it.
        """
        out[:] = (
            grad(
                outputs=forward_fixed_point_output,
                inputs=inverted_displacement_field,
                grad_outputs=vjp_estimate,
                # The same graph is differentiated again on every solver iteration
                # and once more for the final input gradients.
                retain_graph=True,
            )[0]
            + grad_output
        )

    @staticmethod
    def forward(  # type: ignore # pylint: disable=arguments-differ
        ctx: FunctionCtx,
        displacement_field: Tensor,
        arguments: DeformationInversionArguments,
        initial_guess: Optional[Tensor],
        coordinates: Optional[Tensor],
    ):
        """Solve the inverse displacement field with the forward solver

        Args:
            ctx: Autograd context
            displacement_field: Displacement field to invert
            arguments: Interpolator, solvers and solver data types
            initial_guess: Starting point for the solver. If None, the negated
                displacement field is used when coordinates is None, and zeros
                shaped like coordinates are used otherwise
            coordinates: Coordinates at which to compute the inverse, if None a
                voxel coordinate grid matching the displacement field is generated

        Returns:
            Inverted displacement field cast back to the data type of displacement_field

        Raises:
            ValueError: If the data type of coordinates or initial_guess differs
                from the data type of displacement_field
        """
        dtype = (
            displacement_field.dtype if arguments.forward_dtype is None else arguments.forward_dtype
        )
        type_converted_displacement_field = displacement_field.to(dtype=dtype)
        if coordinates is None:
            type_converted_coordinates = generate_voxel_coordinate_grid(
                displacement_field.shape[2:], displacement_field.device, dtype=dtype
            )
        elif displacement_field.dtype != coordinates.dtype:
            raise ValueError(
                f'DType {coordinates.dtype} of input "coordinates" does not match '
                f'DType {displacement_field.dtype} of input "displacement_field"'
            )
        else:
            type_converted_coordinates = coordinates.to(dtype=dtype)
        if initial_guess is None and coordinates is None:
            type_converted_initial_guess = -type_converted_displacement_field
        elif initial_guess is None:
            # With custom coordinates the output no longer has the shape of the
            # displacement field, so the negated field cannot be used as a guess.
            type_converted_initial_guess = zeros_like(type_converted_coordinates)
        elif displacement_field.dtype != initial_guess.dtype:
            raise ValueError(
                f'DType {initial_guess.dtype} of input "initial_guess" does not match '
                f'DType {displacement_field.dtype} of input "displacement_field"'
            )
        else:
            type_converted_initial_guess = initial_guess.to(dtype=dtype)
        inverted_displacement_field = arguments.forward_solver.solve(
            partial(
                _FixedPointInvertDisplacementField._forward_fixed_point_mapping,
                displacement_field=type_converted_displacement_field,
                coordinates=type_converted_coordinates,
                interpolator=arguments.interpolator,
            ),
            initial_value=type_converted_initial_guess,
        ).to(displacement_field.dtype)
        (
            displacement_field_grad_needed,
            _,
            _,
            coordinates_grad_needed,
        ) = ctx.needs_input_grad  # type: ignore
        # State is stored only when backward() can actually produce a gradient.
        if displacement_field_grad_needed or coordinates_grad_needed:
            tensors_to_save = [displacement_field, inverted_displacement_field]
            if coordinates is not None:
                tensors_to_save.append(coordinates)
            ctx.save_for_backward(*tensors_to_save)
            ctx.arguments = arguments  # type: ignore
            ctx.dtype = dtype  # type: ignore
            ctx.has_coordinates = coordinates is not None  # type: ignore
        return inverted_displacement_field

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: Tensor):  # type: ignore # pylint: disable=arguments-differ
        """Compute input gradients by solving the backward fixed point problem

        Args:
            ctx: Autograd context populated by forward()
            grad_output: Gradient with respect to the inverted displacement field

        Returns:
            Tuple with one entry per input of forward(): gradient for
            displacement_field, None for arguments, None for initial_guess and
            gradient for coordinates. An entry is None when that gradient is not needed.

        Raises:
            RuntimeError: If arguments.backward_solver is None
        """
        (
            displacement_field_grad_needed,
            _,
            _,
            coordinates_grad_needed,
        ) = ctx.needs_input_grad
        if not (displacement_field_grad_needed or coordinates_grad_needed):
            return None, None, None, None
        displacement_field: Tensor = ctx.saved_tensors[0]
        inverted_displacement_field: Tensor = ctx.saved_tensors[1]
        coordinates: Optional[Tensor] = ctx.saved_tensors[2] if ctx.has_coordinates else None
        arguments: DeformationInversionArguments = ctx.arguments
        del ctx
        if arguments.backward_solver is None:
            raise RuntimeError("Backward solver not specified!")
        dtype = (
            displacement_field.dtype
            if arguments.backward_dtype is None
            else arguments.backward_dtype
        )
        original_dtype = displacement_field.dtype
        displacement_field = displacement_field.to(dtype).detach()
        if coordinates is None:
            coordinates = generate_voxel_coordinate_grid(
                displacement_field.shape[2:], displacement_field.device, dtype=dtype
            )
        else:
            # Detached like the other saved tensors: without it, ``.to`` with an
            # unchanged dtype returns the saved tensor itself, and the local graph
            # built below would be attached to the caller's tensor.
            coordinates = coordinates.to(dtype=dtype).detach()
        inverted_displacement_field = inverted_displacement_field.to(dtype).detach()
        grad_output = grad_output.to(dtype)
        with enable_grad():
            displacement_field.requires_grad_(displacement_field_grad_needed)
            coordinates.requires_grad_(coordinates_grad_needed)
            inverted_displacement_field.requires_grad_(True)
            # One step of the forward mapping at the solution; this small graph is
            # the only thing differentiated in the backward pass.
            forward_fixed_point_output = (
                _FixedPointInvertDisplacementField._forward_fixed_point_iteration_step(
                    inverted_displacement_field=inverted_displacement_field,
                    displacement_field=displacement_field,
                    interpolator=arguments.interpolator,
                    coordinates=coordinates,
                )
            )
            fixed_point_solved_gradient = arguments.backward_solver.solve(
                partial(
                    _FixedPointInvertDisplacementField._backward_fixed_point_mapping,
                    inverted_displacement_field=inverted_displacement_field,
                    forward_fixed_point_output=forward_fixed_point_output,
                    grad_output=grad_output,
                ),
                initial_value=zeros_like(inverted_displacement_field),
            )
            inverted_displacement_field.requires_grad_(False)
            differentiated_inputs = []
            if displacement_field_grad_needed:
                differentiated_inputs.append(displacement_field)
            if coordinates_grad_needed:
                differentiated_inputs.append(coordinates)
            output_grad = grad(
                outputs=forward_fixed_point_output,
                inputs=differentiated_inputs,
                grad_outputs=fixed_point_solved_gradient,
                retain_graph=False,
            )
        # output_grad follows the order of differentiated_inputs, so consuming it
        # sequentially pairs each gradient with the right input.
        gradients = iter(output_grad)
        displacement_field_grad = (
            next(gradients).to(dtype=original_dtype) if displacement_field_grad_needed else None
        )
        coordinates_grad = (
            next(gradients).to(dtype=original_dtype) if coordinates_grad_needed else None
        )
        return displacement_field_grad, None, None, coordinates_grad

Functions

def fixed_point_invert_deformation(displacement_field: torch.Tensor, arguments: Optional[DeformationInversionArguments] = None, initial_guess: Optional[torch.Tensor] = None, coordinates: Optional[torch.Tensor] = None) ‑> torch.Tensor

Fixed point invert displacement field

Args

displacement_field
Displacement field describing the deformation to invert with shape (batch_size, n_dims, dim_1, …, dim_{n_dims})
arguments
Arguments for fixed point inversion
initial_guess
Initial guess for the inverted displacement field. If not given, the negative of the displacement field is used when coordinates is None, and zeros are used when coordinates is given
coordinates
Voxel coordinates at which to compute the inverse with shape (batch_size, n_dims, *coordinates_shape), by default the inversion is done at the voxel coordinates of the input displacement field

Returns

Inverted displacement field with shape (batch_size, n_dims, dim_1, …, dim_{n_dims}) if coordinates is None, otherwise (batch_size, n_dims, *coordinates_shape)

Expand source code
def fixed_point_invert_deformation(
    displacement_field: Tensor,
    arguments: Optional[DeformationInversionArguments] = None,
    initial_guess: Optional[Tensor] = None,
    coordinates: Optional[Tensor] = None,
) -> Tensor:
    """Invert a displacement field using fixed point iteration

    Args:
        displacement_field: Displacement field describing the deformation to invert with shape
            (batch_size, n_dims, dim_1, ..., dim_{n_dims})
        arguments: Arguments for fixed point inversion, default-constructed
            DeformationInversionArguments is used if not given
        initial_guess: Initial guess for inverted displacement field. If not given, negative
            of the displacement field is used when coordinates is None, and zeros are
            used when coordinates is given
        coordinates: Voxel coordinates at which to compute the inverse with
            shape (batch_size, n_dims, *coordinates_shape), by default the inversion is done at
            the voxel coordinates of the input displacement field

    Returns:
        Inverted displacement field with shape (batch_size, n_dims, dim_1, ..., dim_{n_dims}) if
            coordinates is None, otherwise (batch_size, n_dims, *coordinates_shape)
    """
    if arguments is None:
        arguments = DeformationInversionArguments()
    return _FixedPointInvertDisplacementField.apply(
        displacement_field, arguments, initial_guess, coordinates
    )

Classes

class DeformationInversionArguments (interpolator: Optional[Interpolator] = None, forward_solver: Optional[FixedPointSolver] = None, backward_solver: Optional[FixedPointSolver] = None, forward_dtype: Optional[torch.dtype] = None, backward_dtype: Optional[torch.dtype] = None)

Arguments for deformation fixed point inversion

Args

interpolator
Interpolator with which to interpolate the input displacement field
forward_solver
Fixed point solver for the forward pass
backward_solver
Fixed point solver for the backward pass
forward_dtype
Data type to use for the solver in the forward pass, by default the data type of the input is used
backward_dtype
Data type to use for the solver in the backward pass, by default the data type of the input is used
Expand source code
class DeformationInversionArguments:
    """Configuration bundle for fixed point deformation inversion

    Every argument is optional; a ``None`` value selects the default described below.

    Args:
        interpolator: Interpolator used for sampling the input displacement field,
            a ``LinearInterpolator`` is created if not given
        forward_solver: Fixed point solver for the forward pass, if not given an
            ``AndersonSolver`` with ``MaxElementWiseAbsStopCriterion`` is created
        backward_solver: Fixed point solver for the backward pass, if not given an
            ``AndersonSolver`` with ``RelativeL2ErrorStopCriterion`` is created
        forward_dtype: Data type in which the forward pass solver is run, the data
            type of the input is used if not given
        backward_dtype: Data type in which the backward pass solver is run, the data
            type of the input is used if not given
    """

    def __init__(
        self,
        interpolator: Optional[Interpolator] = None,
        forward_solver: Optional[FixedPointSolver] = None,
        backward_solver: Optional[FixedPointSolver] = None,
        forward_dtype: Optional[torch_dtype] = None,
        backward_dtype: Optional[torch_dtype] = None,
    ) -> None:
        # Defaults are instantiated per instance, and only when actually needed.
        if interpolator is None:
            interpolator = LinearInterpolator()
        if forward_solver is None:
            forward_solver = AndersonSolver(stop_criterion=MaxElementWiseAbsStopCriterion())
        if backward_solver is None:
            backward_solver = AndersonSolver(stop_criterion=RelativeL2ErrorStopCriterion())
        self.interpolator = interpolator
        self.forward_solver = forward_solver
        self.backward_solver = backward_solver
        self.forward_dtype = forward_dtype
        self.backward_dtype = backward_dtype