torch.linalg.householder_product
- torch.linalg.householder_product(A, tau, *, out=None) → Tensor
Computes the first n columns of a product of Householder matrices.
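Let $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, and let $A \in \mathbb{K}^{m \times n}$ be a matrix with columns $a_i \in \mathbb{K}^m$ for $i = 1, \ldots, n$, with $m \geq n$. Denote by $b_i$ the vector obtained from $a_i$ by zeroing out its first $i-1$ components and setting its $i$-th component to 1. For a vector $\tau \in \mathbb{K}^k$ with $k \leq n$, this function computes the first $n$ columns of the matrix
$$H_1 H_2 \cdots H_k \qquad \text{with} \qquad H_i = \mathrm{I}_m - \tau_i b_i b_i^{\mathrm{H}}$$
where $\mathrm{I}_m$ is the m-dimensional identity matrix and $b^{\mathrm{H}}$ is the conjugate transpose when $b$ is complex, and the transpose when $b$ is real-valued. The output matrix is the same size as the input matrix A.
See Representation of Orthogonal or Unitary Matrices for further details.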
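To make the definition concrete, the following is a minimal dependency-free sketch that forms the product H_1 H_2 ... H_k with H_i = I - tau_i b_i b_i^H directly and keeps the first n columns. The name `householder_product_reference` is hypothetical and used only for illustration. The sketch works on nested Python lists and is not how PyTorch implements the function.

```python
def householder_product_reference(A, tau):
    """Illustrative reference: first n columns of H_1 ... H_k.

    A is a list of m rows with n entries each; tau is a list of k scalars.
    Entries may be real or complex Python numbers.
    """
    m, n = len(A), len(A[0])
    k = len(tau)
    if m < n or n < k:
        raise ValueError("expected m >= n >= k")

    # Start from the m x m identity and right-multiply by each H_i in turn.
    Q = [[1.0 if r == c else 0.0 for c in range(m)] for r in range(m)]
    for i in range(k):
        # b_i: column i of A with the entries above the diagonal zeroed
        # and the diagonal entry replaced by 1.
        b = [0.0] * i + [1.0] + [A[r][i] for r in range(i + 1, m)]
        # Q @ H_i = Q - tau_i * (Q @ b_i) b_i^H
        Qb = [sum(Q[r][c] * b[c] for c in range(m)) for r in range(m)]
        for r in range(m):
            for c in range(m):
                Q[r][c] -= tau[i] * Qb[r] * b[c].conjugate()
    # Keep only the first n columns, matching the shape of A.
    return [row[:n] for row in Q]


# One reflector with b = (1, 2, 2) and tau = 2 / ||b||^2 = 2/9.
# The diagonal entry 5.0 of A is ignored because b replaces it with 1.
Q = householder_product_reference([[5.0], [2.0], [2.0]], [2 / 9])
print(Q)  # approximately [[7/9], [-4/9], [-4/9]], a unit-norm column
```

Only the entries of A strictly below the diagonal enter the result, which is why the output of torch.geqrf() can store the reflectors and the triangular factor in the same matrix.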
Supports inputs of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if the inputs are batches of matrices then the output has the same batch dimensions.
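As a small sketch of the batching behaviour, the snippet below passes a batch of four 5 x 3 matrices through torch.geqrf() and then through this function, and checks that the leading batch dimension is preserved. The shapes, seed, and tolerance are arbitrary choices made for illustration.

```python
import torch

torch.manual_seed(0)

# Batch of 4 matrices with m = 5 rows and n = 3 columns.
A = torch.randn(4, 5, 3, dtype=torch.float64)

# geqrf returns the reflectors (below the diagonal of h) and the scalars tau.
h, tau = torch.geqrf(A)                       # h: (4, 5, 3), tau: (4, 3)
Q = torch.linalg.householder_product(h, tau)  # Q: (4, 5, 3)

# Each Q in the batch has orthonormal columns, so Q^T Q is the 3 x 3 identity.
eye = torch.eye(3, dtype=torch.float64).expand(4, 3, 3)
orthonormal = torch.allclose(Q.transpose(-2, -1) @ Q, eye, atol=1e-10)
print(Q.shape, orthonormal)
```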
See also
torch.geqrf() can be used together with this function to form the Q from the qr() decomposition.
torch.ormqr() is a related function that computes the matrix multiplication of a product of Householder matrices with another matrix. However, that function is not supported by autograd.
Warning
Gradient computations are only well-defined if $\tau_i \neq \frac{1}{\|a_i\|^2}$. If this condition is not met, no error will be thrown, but the gradient produced may contain NaN.
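- Parameters:
A (Tensor) – tensor of shape (*, m, n) where * is zero or more batch dimensions.
tau (Tensor) – tensor of shape (*, k) where * is zero or more batch dimensions.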
- Keyword Arguments:
out (Tensor, optional) – output tensor. Ignored if None. Default: None.
- Raises:
RuntimeError – if A doesn’t satisfy the requirement m >= n, or tau doesn’t satisfy the requirement n >= k.
Examples:
>>> A = torch.randn(2, 2)
>>> h, tau = torch.geqrf(A)
>>> Q = torch.linalg.householder_product(h, tau)
>>> torch.dist(Q, torch.linalg.qr(A).Q)
tensor(0.)
>>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
>>> tau = torch.randn(3, 1, dtype=torch.complex128)
>>> Q = torch.linalg.householder_product(h, tau)
>>> Q
tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
         [-0.6853+0.7953j,  2.0790+0.5620j]],
        [[ 1.4581+1.6989j, -1.5360+0.1193j],
         [ 1.3877-0.6691j,  1.3512+1.3024j]],
        [[ 1.4766+0.5783j,  0.0361+0.6587j],
         [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)