torch.lu_unpack
torch.lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None)

Unpacks the data and pivots from an LU factorization of a tensor into tensors L and U and a permutation tensor P such that LU_data, LU_pivots = (P @ L @ U).lu().

Returns a tuple of tensors as (the P tensor (permutation matrix), the L tensor, the U tensor).
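To make "unpacking the pivots" concrete, the sketch below rebuilds P by hand from LU_pivots and compares it with what torch.lu_unpack returns. It assumes, as with LAPACK's getrf, that the pivots are 1-based row interchanges applied in order (row i swapped with row pivots[i] - 1). It also assumes a PyTorch release that provides torch.linalg.lu_factor, used here in place of A.lu() because it produces the same packed format. The helper pivots_to_permutation is hypothetical and handles a single (non-batched) matrix.

```python
import torch

def pivots_to_permutation(LU_pivots: torch.Tensor, m: int) -> torch.Tensor:
    """Hypothetical helper: build P from 1-based pivots for one m x m matrix."""
    perm = list(range(m))
    for i, p in enumerate(LU_pivots.tolist()):
        j = p - 1  # pivots are 1-based: row i was swapped with row p - 1
        perm[i], perm[j] = perm[j], perm[i]
    # After the swaps, A[perm] == L @ U, i.e. I[perm] @ A == L @ U,
    # so A == P @ L @ U with P = I[perm] transposed.
    return torch.eye(m)[perm].T

torch.manual_seed(0)
A = torch.randn(4, 4)
A_LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(A_LU, pivots)

P_manual = pivots_to_permutation(pivots, A.shape[-2])
print(torch.equal(P, P_manual))                 # True
print(torch.allclose(P @ L @ U, A, atol=1e-6))  # True
```

Because P is a permutation matrix, its transpose is its inverse, which is why the row reordering recorded by the pivots shows up as P transposed.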
Note

P.dtype == LU_data.dtype and P.dtype is not an integer type, so matrix products with P are possible without casting it to a floating type.

Parameters
LU_data (Tensor) – the packed LU factorization data
LU_pivots (Tensor) – the packed LU factorization pivots
unpack_data (bool) – flag indicating if the data should be unpacked. If False, then the returned L and U are None. Default: True
unpack_pivots (bool) – flag indicating if the pivots should be unpacked into a permutation matrix P. If False, then the returned P is None. Default: True
out (tuple, optional) – a tuple of three tensors to use for the outputs (P, L, U).
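A minimal sketch checking the dtype note and the effect of the two flags. It assumes a PyTorch release that provides torch.linalg.lu_factor (producing the same packed format as A.lu()). Depending on the PyTorch version, skipped outputs may come back as None (as documented above) or as empty tensors, so the hypothetical helper skipped accepts either.

```python
import torch

A = torch.randn(3, 3, dtype=torch.float64)
A_LU, pivots = torch.linalg.lu_factor(A)

# Defaults: all three factors are returned, and P shares LU_data's floating
# dtype, so it can be multiplied with L and U without any cast.
P, L, U = torch.lu_unpack(A_LU, pivots)
print(P.dtype)                       # torch.float64
print(torch.allclose(P @ L @ U, A))  # True

def skipped(t):
    # Hypothetical helper: None (as documented) or an empty tensor
    # both mean "this output was not unpacked".
    return t is None or t.numel() == 0

# unpack_data=False: only the permutation matrix is produced.
P_only, L_skip, U_skip = torch.lu_unpack(A_LU, pivots, unpack_data=False)
print(skipped(L_skip) and skipped(U_skip))  # True

# unpack_pivots=False: only the triangular factors are produced.
P_skip, L_only, U_only = torch.lu_unpack(A_LU, pivots, unpack_pivots=False)
print(skipped(P_skip), torch.equal(L_only, L))  # True True
```

Skipping the part you do not need avoids materializing an m x m permutation matrix or copying the triangular factors out of LU_data.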
Examples:
>>> import torch
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>>
>>> # can recover A from factorization
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))

>>> # LU factorization of a rectangular matrix:
>>> A = torch.randn(2, 3, 2)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>> P
tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.]]])
>>> A_L
tensor([[[ 1.0000,  0.0000],
         [ 0.4763,  1.0000],
         [ 0.3683,  0.1135]],

        [[ 1.0000,  0.0000],
         [ 0.2957,  1.0000],
         [-0.9668, -0.3335]]])
>>> A_U
tensor([[[ 2.1962,  1.0881],
         [ 0.0000, -0.8681]],

        [[-1.0947,  0.3736],
         [ 0.0000,  0.5718]]])
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
>>> torch.norm(A_ - A)
tensor(2.9802e-08)
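A common reason to unpack the factors is to reuse them for solving linear systems. Since A = P @ L @ U and P is orthogonal, A x = b becomes L (U x) = P^T b, which two triangular solves handle. The sketch below is an illustration under the assumption that torch.linalg.lu_factor and torch.linalg.solve_triangular are available (recent PyTorch releases); it is not presented as the library's own solver, and the batch of random systems is made up for the example.

```python
import torch

torch.manual_seed(0)
A = torch.randn(2, 3, 3, dtype=torch.float64)  # batch of two square systems
b = torch.randn(2, 3, 1, dtype=torch.float64)  # one right-hand side per system

A_LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(A_LU, pivots)

# Forward substitution: L y = P^T b (L has a unit diagonal).
Pt_b = P.transpose(-2, -1) @ b
y = torch.linalg.solve_triangular(L, Pt_b, upper=False, unitriangular=True)
# Back substitution: U x = y.
x = torch.linalg.solve_triangular(U, y, upper=True)

print(torch.allclose(A @ x, b))                     # True
print(torch.allclose(x, torch.linalg.solve(A, b)))  # True
```

The factorization is computed once, so further right-hand sides only cost the two triangular solves.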