Module deformation_inversion_layer.interface
Interface definitions
Expand source code
"""Interface definitions"""
from abc import abstractmethod
from typing import Protocol, Sequence
from torch import Tensor
class Interpolator(Protocol):
    """Protocol for interpolating a volume sampled on a regular voxel grid

    Interpolation locations are given in voxel coordinates of the volume.
    """

    @abstractmethod
    def __call__(self, volume: Tensor, coordinates: Tensor) -> Tensor:
        """Interpolate a volume at the given voxel coordinates

        Args:
            volume: Volume to interpolate, with shape
                (batch_size, *channel_dims, dim_1, ..., dim_{n_dims}). Its
                spatial dimensions are ordered like the coordinate axis
                (dimension 1) of `coordinates`.
            coordinates: Voxel coordinates at which to interpolate, with
                shape (batch_size, n_dims, *target_shape).

        Returns:
            Interpolated values with shape
            (batch_size, *channel_dims, *target_shape).
        """
        ...
class FixedPointFunction(Protocol):
    """Protocol for fixed point functions

    A fixed point function is the function a fixed point solver iterates
    until convergence. Instead of returning its output, it writes the output
    in-place into a buffer provided by the caller.
    """

    # Marked abstract for consistency with the other protocols in this module
    # (Interpolator, FixedPointSolver, FixedPointStopCriterion).
    @abstractmethod
    def __call__(
        self,
        iteration_input: Tensor,
        output_buffer: Tensor,
    ) -> None:
        """Evaluate the fixed point function once

        Args:
            iteration_input: Input to the fixed point function
            output_buffer: Tensor with the same shape as the input into
                which the output must be stored in-place

        Returns:
            None; the result is communicated only through `output_buffer`.
        """
        ...
class FixedPointSolver(Protocol):
    """Protocol for solvers of fixed point problems"""

    @abstractmethod
    def solve(
        self,
        fixed_point_function: FixedPointFunction,
        initial_value: Tensor,
    ) -> Tensor:
        """Find a fixed point of the given function

        Args:
            fixed_point_function: Function that is iterated until the
                iteration converges
            initial_value: Starting point of the iteration

        Returns:
            The fixed point found by the iteration
        """
        ...
class FixedPointStopCriterion(Protocol):
    """Protocol for fixed point iteration stopping criteria"""

    @abstractmethod
    def should_stop(
        self,
        current_iteration: Tensor,
        previous_iterations: Sequence[Tensor],
        n_earlier_iterations: int,
    ) -> bool:
        """Return whether iterating should be stopped at the beginning of an iteration

        Args:
            current_iteration: Current output of the fixed point iteration. For
                n_earlier_iterations == 0 this equals the initial guess.
            previous_iterations: Previous outputs of the fixed point iteration,
                starting from the most recent one. The length of this sequence
                may depend on the fixed point iteration solver and is always 0
                when n_earlier_iterations == 0.
            n_earlier_iterations: Number of calls made to the fixed point
                function before the next iteration.

        Returns:
            True if the iteration should be stopped, False if it should be
            continued.
        """
        ...
Classes
class FixedPointFunction (*args, **kwargs)
-
Protocol for fixed point functions
Expand source code
class FixedPointFunction(Protocol):
    """Protocol for fixed point functions

    A fixed point function is the function a fixed point solver iterates
    until convergence. Instead of returning its output, it writes the output
    in-place into a buffer provided by the caller.
    """

    # Marked abstract for consistency with the other protocols in this module
    # (Interpolator, FixedPointSolver, FixedPointStopCriterion).
    @abstractmethod
    def __call__(
        self,
        iteration_input: Tensor,
        output_buffer: Tensor,
    ) -> None:
        """Evaluate the fixed point function once

        Args:
            iteration_input: Input to the fixed point function
            output_buffer: Tensor with the same shape as the input into
                which the output must be stored in-place

        Returns:
            None; the result is communicated only through `output_buffer`.
        """
        ...
Ancestors
- typing.Protocol
- typing.Generic
class FixedPointSolver (*args, **kwargs)
-
Protocol for fixed point solvers
Expand source code
class FixedPointSolver(Protocol):
    """Protocol for solvers of fixed point problems"""

    @abstractmethod
    def solve(
        self,
        fixed_point_function: FixedPointFunction,
        initial_value: Tensor,
    ) -> Tensor:
        """Find a fixed point of the given function

        Args:
            fixed_point_function: Function that is iterated until the
                iteration converges
            initial_value: Starting point of the iteration

        Returns:
            The fixed point found by the iteration
        """
        ...
Ancestors
- typing.Protocol
- typing.Generic
Subclasses
Methods
def solve(self, fixed_point_function: FixedPointFunction, initial_value: torch.Tensor) -> torch.Tensor
-
Solve a fixed point problem
Args
fixed_point_function
- Function to be iterated until convergence
initial_value
- Initial iteration value
Returns
Solution of the fixed point iteration
Expand source code
@abstractmethod
def solve(
    self,
    fixed_point_function: FixedPointFunction,
    initial_value: Tensor,
) -> Tensor:
    """Find a fixed point of the given function

    Args:
        fixed_point_function: Function that is iterated until the
            iteration converges
        initial_value: Starting point of the iteration

    Returns:
        The fixed point found by the iteration
    """
    ...
class FixedPointStopCriterion (*args, **kwargs)
-
Protocol for fixed point iteration stopping criteria
Expand source code
class FixedPointStopCriterion(Protocol):
    """Protocol for fixed point iteration stopping criteria"""

    @abstractmethod
    def should_stop(
        self,
        current_iteration: Tensor,
        previous_iterations: Sequence[Tensor],
        n_earlier_iterations: int,
    ) -> bool:
        """Return whether iterating should be stopped at the beginning of an iteration

        Args:
            current_iteration: Current output of the fixed point iteration. For
                n_earlier_iterations == 0 this equals the initial guess.
            previous_iterations: Previous outputs of the fixed point iteration,
                starting from the most recent one. The length of this sequence
                may depend on the fixed point iteration solver and is always 0
                when n_earlier_iterations == 0.
            n_earlier_iterations: Number of calls made to the fixed point
                function before the next iteration.

        Returns:
            True if the iteration should be stopped, False if it should be
            continued.
        """
        ...
Ancestors
- typing.Protocol
- typing.Generic
Subclasses
Methods
def should_stop(self, current_iteration: torch.Tensor, previous_iterations: Sequence[torch.Tensor], n_earlier_iterations: int) -> bool
-
Return whether iterating should be stopped at the beginning of an iteration
Args
current_iteration
- Current output of the fixed point iteration. For n_earlier_iterations == 0 this equals the initial guess.
previous_iterations
- Previous outputs of the fixed point iteration, starting from the most recent one. The length of this sequence may depend on the fixed point iteration solver and is always 0 when n_earlier_iterations == 0.
n_earlier_iterations
- Number of calls made to the fixed point function before the next iteration.
Returns
Whether the iteration should be stopped
Expand source code
@abstractmethod
def should_stop(
    self,
    current_iteration: Tensor,
    previous_iterations: Sequence[Tensor],
    n_earlier_iterations: int,
) -> bool:
    """Return whether iterating should be stopped at the beginning of an iteration

    Args:
        current_iteration: Current output of the fixed point iteration. For
            n_earlier_iterations == 0 this equals the initial guess.
        previous_iterations: Previous outputs of the fixed point iteration,
            starting from the most recent one. The length of this sequence
            may depend on the fixed point iteration solver and is always 0
            when n_earlier_iterations == 0.
        n_earlier_iterations: Number of calls made to the fixed point
            function before the next iteration.

    Returns:
        True if the iteration should be stopped, False if it should be
        continued.
    """
    ...
class Interpolator (*args, **kwargs)
-
Interpolates values on regular grid in voxel coordinates
Expand source code
class Interpolator(Protocol):
    """Protocol for interpolating a volume sampled on a regular voxel grid

    Interpolation locations are given in voxel coordinates of the volume.
    """

    @abstractmethod
    def __call__(self, volume: Tensor, coordinates: Tensor) -> Tensor:
        """Interpolate a volume at the given voxel coordinates

        Args:
            volume: Volume to interpolate, with shape
                (batch_size, *channel_dims, dim_1, ..., dim_{n_dims}). Its
                spatial dimensions are ordered like the coordinate axis
                (dimension 1) of `coordinates`.
            coordinates: Voxel coordinates at which to interpolate, with
                shape (batch_size, n_dims, *target_shape).

        Returns:
            Interpolated values with shape
            (batch_size, *channel_dims, *target_shape).
        """
        ...
Ancestors
- typing.Protocol
- typing.Generic
Subclasses