# Source code for torch.nn.utils.parametrizations
from enum import Enum, auto
import torch
from torch import Tensor
from ..utils import parametrize
from ..modules import Module
from .. import functional as F
from typing import Optional
# Names exported by a star-import of this module: the three public
# parametrization entry points defined below.
__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']
def _is_orthogonal(Q, eps=None):
n, k = Q.size(-2), Q.size(-1)
Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
# A reasonable eps, but not too large
eps = 10. * n * torch.finfo(Q.dtype).eps
return torch.allclose(Q.mH @ Q, Id, atol=eps)
def _make_orthogonal(A):
""" Assume that A is a tall matrix.
Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
"""
X, tau = torch.geqrf(A)
Q = torch.linalg.householder_product(X, tau)
# The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
return Q
class _OrthMaps(Enum):
matrix_exp = auto()
cayley = auto()
householder = auto()
class _Orthogonal(Module):
    """Parametrization mapping an unconstrained tensor to an orthogonal (or unitary) matrix.

    ``forward`` builds a matrix with orthonormal columns (rows for wide inputs)
    through the selected ``orthogonal_map``; ``right_inverse`` handles
    assignment of a value to the parametrized tensor.
    """

    # Only registered (as a buffer) when ``use_trivialization=True``
    base: Tensor

    def __init__(self,
                 weight,
                 orthogonal_map: _OrthMaps,
                 *,
                 use_trivialization=True) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors the tensor `tau` needed by linalg.householder_product
        # cannot be computed from the reflectors alone. The reflectors look like
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, amounts to n(n-1) (real) parameters, whereas
        # n^2 parameters are needed to parametrize the unitary matrices. Storing tau
        # separately does not help either: not every pair `(A, tau)` yields a unitary
        # matrix, so optimising them as independent tensors would break the constraint.
        # The same reasoning applies to rectangular matrices.
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("The householder parametrization does not support complex tensors.")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        rows, cols = X.size(-2), X.size(-1)
        is_wide = rows < cols
        if is_wide:
            # Work on the transpose and undo it right before returning
            X = X.mT
            rows, cols = cols, rows
        # From here on rows >= cols, i.e. X is tall (or square)

        if self.orthogonal_map in (_OrthMaps.matrix_exp, _OrthMaps.cayley):
            # Only n x k - k(k-1)/2 parameters are needed
            lower = X.tril()
            if rows != cols:
                # Pad with zero columns to embed into a square matrix
                padding = lower.new_zeros(rows, rows - cols).expand(*lower.shape[:-2], -1, -1)
                lower = torch.cat([lower, padding], dim=-1)
            # Skew-symmetric (or skew-hermitian) by construction
            skew = lower - lower.mH
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(skew)
            else:
                # Cayley retraction (I+A/2)(I-A/2)^{-1}
                eye = torch.eye(rows, dtype=skew.dtype, device=skew.device)
                minus_half = torch.add(eye, skew, alpha=-0.5)
                plus_half = torch.add(eye, skew, alpha=0.5)
                Q = torch.linalg.solve(minus_half, plus_half)
            # Q is orthogonal (or unitary) of size (..., n, n); keep the leading
            # columns so that it has the size of X (albeit perhaps transposed)
            if rows != cols:
                Q = Q[..., :cols]
        else:
            # Householder: X is real, complex inputs are rejected in __init__
            reflectors = X.tril(diagonal=-1)
            tau = 2. / (1. + (reflectors * reflectors).sum(dim=-2))
            Q = torch.linalg.householder_product(reflectors, tau)
            # The diagonal of X holds 1's and -1's. The integer cast keeps it out of
            # differentiation, so the diagonal of X is never updated.
            signs = X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
            Q = Q * signs

        if hasattr(self, "base"):
            Q = self.base @ Q
        return Q.mT if is_wide else Q

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        if Q.shape != self.shape:
            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
                             f"Got a tensor of shape {Q.shape}.")

        original = Q
        rows, cols = Q.size(-2), Q.size(-1)
        is_wide = rows < cols
        if is_wide:
            Q = Q.mT
            rows, cols = cols, rows

        # Every path below ends up working on a copy of Q
        if not hasattr(self, "base"):
            # Note [right_inverse expm cayley]
            # Without use_trivialization=True only the inverse of the forward map for
            # the Householder is implemented. For the Cayley map, for instance, we
            # would need to find the matrix X \in R^{n x k} such that:
            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            # A = Y - Y.mH
            # cayley(A)[:, :k]
            # gives the original tensor, and it is not clear how to do this.
            # Perhaps via some algebraic manipulation involving the QR like that of
            # Corollary 2.2 in Edelman, Arias and Smith?
            if self.orthogonal_map in (_OrthMaps.cayley, _OrthMaps.matrix_exp):
                raise NotImplementedError("It is not possible to assign to the matrix exponential "
                                          "or the Cayley parametrizations when use_trivialization=False.")

            # Householder: make Q orthogonal via the QR decomposition. Q is always
            # real here, see note [Householder complex].
            A, tau = torch.geqrf(Q)
            # We want a decomposition X = QR with diag(R) > 0; otherwise an orthogonal
            # Q could be decomposed as Q = (-Q)@(-Id), which is also a valid QR.
            # The diagonal of the geqrf output is the diagonal of R.
            diag = A.diagonal(dim1=-2, dim2=-1)
            diag.sign_()
            # Comparing with exact zero is fine: LAPACK returns exactly zero when it
            # does not want to use a particular reflection
            diag[tau == 0.] *= -1
            return A.mT if is_wide else A

        if rows == cols:
            # Keep Q (copied) when it is already orthogonal, orthogonalize it otherwise
            Q = Q.clone() if _is_orthogonal(Q) else _make_orthogonal(Q)
        else:
            # Complete Q into a full n x n orthogonal matrix using random columns
            filler = torch.randn(*Q.shape[:-2], rows, rows - cols, dtype=Q.dtype, device=Q.device)
            Q = _make_orthogonal(torch.cat([Q, filler], dim=-1))
        self.base = Q

        # Returning -Id is required because the Householder parametrization reads
        # the diagonal. Using -Id makes:
        # householder(torch.zeros(m,n)) == torch.eye(m,n)
        # (poor man's version of eye_like)
        neg_Id = torch.zeros_like(original)
        neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
        return neg_Id
def orthogonal(module: Module,
               name: str = 'weight',
               orthogonal_map: Optional[str] = None,
               *,
               use_trivialization: bool = True) -> Module:
    r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.

    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as

    .. math::

        \begin{align*}
            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
        \end{align*}

    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
    and the transpose when :math:`Q` is real-valued, and
    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
    and orthonormal rows otherwise.

    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.

    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor:

    - ``"matrix_exp"``/``"cayley"``:
      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
      :math:`A` to give an orthogonal matrix.
    - ``"householder"``: computes a product of Householder reflectors
      (:func:`~torch.linalg.householder_product`).

    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
    ``"householder"``, but they are slower to compute for very thin or very wide matrices.

    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
    ``module.parametrizations.weight[0].base``. This helps the
    convergence of the parametrized layer at the expense of some extra memory use.
    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .

    Initial value of :math:`Q`:
    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
    Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
    Otherwise, the initial value is the result of the composition of all the registered
    parametrizations applied to the original tensor.

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`.

    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501

    Args:
        module (nn.Module): module on which to register the parametrization.
        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
            Default: ``True``.

    Returns:
        The original module with an orthogonal parametrization registered to the specified
        weight

    Raises:
        ValueError: if ``module`` has no tensor attribute called ``name``, if that
            tensor has fewer than two dimensions, or if ``orthogonal_map`` is not
            one of the supported names.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> orth_linear = orthogonal(nn.Linear(20, 40))
        >>> orth_linear
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _Orthogonal()
            )
          )
        )
        >>> # xdoctest: +IGNORE_WANT
        >>> Q = orth_linear.weight
        >>> torch.dist(Q.T @ Q, torch.eye(20))
        tensor(4.9332e-07)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    # We could implement this for 1-dim tensors as the maps on the sphere
    # but I believe it'd bite more people than it'd help
    if weight.ndim < 2:
        raise ValueError("Expected a matrix or batch of matrices. "
                         f"Got a tensor of {weight.ndim} dimensions.")

    if orthogonal_map is None:
        orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"

    # Look the name up among the enum *members* only. A plain ``getattr`` on the
    # Enum class would also accept unrelated class attributes (e.g. "__name__"
    # or "mro") and let a non-member value slip through the validation.
    orth_enum = _OrthMaps.__members__.get(orthogonal_map)
    if orth_enum is None:
        raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
                         f'Got: {orthogonal_map}')
    orth = _Orthogonal(weight,
                       orth_enum,
                       use_trivialization=use_trivialization)
    parametrize.register_parametrization(module, name, orth, unsafe=True)
    return module
class _WeightNorm(Module):
def __init__(
self,
dim: Optional[int] = 0,
) -> None:
super().__init__()
if dim is None:
dim = -1
self.dim = dim
def forward(self, weight_g, weight_v):
return torch._weight_norm(weight_v, weight_g, self.dim)
def right_inverse(self, weight):
weight_g = torch.norm_except_dim(weight, 2, self.dim)
weight_v = weight
return weight_g, weight_v
def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])

    """
    parametrize.register_parametrization(module, name, _WeightNorm(dim), unsafe=True)

    def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Rename ``<name>_g`` / ``<name>_v`` entries of an incoming state dict to
        # the ``original0`` / ``original1`` keys used by the parametrization.
        legacy_g = f"{prefix}{name}_g"
        legacy_v = f"{prefix}{name}_v"
        if legacy_g not in state_dict or legacy_v not in state_dict:
            return
        magnitude = state_dict.pop(legacy_g)
        direction = state_dict.pop(legacy_v)
        state_dict[f"{prefix}parametrizations.{name}.original0"] = magnitude
        state_dict[f"{prefix}parametrizations.{name}.original1"] = direction

    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module
class _SpectralNorm(Module):
def __init__(
self,
weight: torch.Tensor,
n_power_iterations: int = 1,
dim: int = 0,
eps: float = 1e-12
) -> None:
super().__init__()
ndim = weight.ndim
if dim >= ndim or dim < -ndim:
raise IndexError("Dimension out of range (expected to be in range of "
f"[-{ndim}, {ndim - 1}] but got {dim})")
if n_power_iterations <= 0:
raise ValueError('Expected n_power_iterations to be positive, but '
f'got n_power_iterations={n_power_iterations}')
self.dim = dim if dim >= 0 else dim + ndim
self.eps = eps
if ndim > 1:
# For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
self.n_power_iterations = n_power_iterations
weight_mat = self._reshape_weight_to_matrix(weight)
h, w = weight_mat.size()
u = weight_mat.new_empty(h).normal_(0, 1)
v = weight_mat.new_empty(w).normal_(0, 1)
self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))
# Start with u, v initialized to some reasonable values by performing a number
# of iterations of the power method
self._power_method(weight_mat, 15)
def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
# Precondition
assert weight.ndim > 1
if self.dim != 0:
# permute dim to front
weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))
return weight.flatten(1)
@torch.autograd.no_grad()
def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
# See original note at torch/nn/utils/spectral_norm.py
# NB: If `do_power_iteration` is set, the `u` and `v` vectors are
# updated in power iteration **in-place**. This is very important
# because in `DataParallel` forward, the vectors (being buffers) are
# broadcast from the parallelized module to each module replica,
# which is a new module object created on the fly. And each replica
# runs its own spectral norm power iteration. So simply assigning
# the updated vectors to the module this function runs on will cause
# the update to be lost forever. And the next time the parallelized
# module is replicated, the same randomly initialized vectors are
# broadcast and used!
#
# Therefore, to make the change propagate back, we rely on two
# important behaviors (also enforced via tests):
# 1. `DataParallel` doesn't clone storage if the broadcast tensor
# is already on correct device; and it makes sure that the
# parallelized module is already on `device[0]`.
# 2. If the out tensor in `out=` kwarg has correct shape, it will
# just fill in the values.
# Therefore, since the same power iteration is performed on all
# devices, simply updating the tensors in-place will make sure that
# the module replica on `device[0]` will update the _u vector on the
# parallelized module (by shared storage).
#
# However, after we update `u` and `v` in-place, we need to **clone**
# them before using them to normalize the weight. This is to support
# backproping through two forward passes, e.g., the common pattern in
# GAN training: loss = D(real) - D(fake). Otherwise, engine will
# complain that variables needed to do backward for the first forward
# (i.e., the `u` and `v` vectors) are changed in the second forward.
# Precondition
assert weight_mat.ndim > 1
for _ in range(n_power_iterations):
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
# are the first left and right singular vectors.
# This power iteration produces approximations of `u` and `v`.
self._u = F.normalize(torch.mv(weight_mat, self._v), # type: ignore[has-type]
dim=0, eps=self.eps, out=self._u) # type: ignore[has-type]
self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
dim=0, eps=self.eps, out=self._v) # type: ignore[has-type]
def forward(self, weight: torch.Tensor) -> torch.Tensor:
if weight.ndim == 1:
# Faster and more exact path, no need to approximate anything
return F.normalize(weight, dim=0, eps=self.eps)
else:
weight_mat = self._reshape_weight_to_matrix(weight)
if self.training:
self._power_method(weight_mat, self.n_power_iterations)
# See above on why we need to clone
u = self._u.clone(memory_format=torch.contiguous_format)
v = self._v.clone(memory_format=torch.contiguous_format)
# The proper way of computing this should be through F.bilinear, but
# it seems to have some efficiency issues:
# https://github.com/pytorch/pytorch/issues/58093
sigma = torch.dot(u, torch.mv(weight_mat, v))
return weight / sigma
def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
# we may want to assert here that the passed value already
# satisfies constraints
return value
def spectral_norm(module: Module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> Module:
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    When applied on a vector, it simplifies to

    .. math::
        \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated performing one iteration of the
    `power method`_ every time the weight is accessed. If the dimension of the
    weight tensor is greater than 2, it is reshaped to 2D in power iteration
    method to get spectral norm.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated to the largest
        singular value are estimated rather than sampled at random. These are then updated
        performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed with the module on `training` mode.

    .. note::
        If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter. Default: ``"weight"``.
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm. Default: ``1``.
        eps (float, optional): epsilon for numerical stability in
            calculating norms. Default: ``1e-12``.
        dim (int, optional): dimension corresponding to number of outputs.
            Default: ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0081, grad_fn=<AmaxBackward0>)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    if dim is None:
        # Transposed convolutions get dim 1; every other module gets dim 0
        transposed_convs = (torch.nn.ConvTranspose1d,
                            torch.nn.ConvTranspose2d,
                            torch.nn.ConvTranspose3d)
        dim = 1 if isinstance(module, transposed_convs) else 0

    parametrization = _SpectralNorm(weight, n_power_iterations, dim, eps)
    parametrize.register_parametrization(module, name, parametrization)
    return module