torch.lu_solve
torch.lu_solve(b, LU_data, LU_pivots, *, out=None) → Tensor

Returns the LU solve of the linear system $Ax = b$ using the partially pivoted LU factorization of A from torch.lu().

This function supports float, double, cfloat and cdouble dtypes for its inputs.

- Parameters
  - b (Tensor) – the right-hand side (RHS) tensor of size $(*, m, k)$, where $*$ is zero or more batch dimensions.
  - LU_data (Tensor) – the pivoted LU factorization of A from torch.lu(), of size $(*, m, m)$, where $*$ is zero or more batch dimensions.
  - LU_pivots (IntTensor) – the pivots of the LU factorization from torch.lu(), of size $(*, m)$, where $*$ is zero or more batch dimensions. The batch dimensions of LU_pivots must be equal to the batch dimensions of LU_data.
- Keyword Arguments
out (Tensor, optional) – the output tensor.
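The size rules in the parameter list can be checked with a small batched sketch. It assumes a batch of 4 systems with m = 3 unknowns and k = 2 right-hand sides each, and it writes the result into a preallocated tensor to show the out keyword argument. The sizes, the seed, and the identity shift (used only to keep the random matrices well conditioned) are illustrative choices, not part of the API.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes: 4 batched systems, m = 3 unknowns, k = 2 right-hand sides.
batch, m, k = 4, 3, 2

# Shifting by a multiple of the identity keeps the random matrices well conditioned.
A = torch.randn(batch, m, m, dtype=torch.float64) + 3 * torch.eye(m, dtype=torch.float64)
b = torch.randn(batch, m, k, dtype=torch.float64)  # RHS of size (*, m, k)

# torch.lu returns the packed LU factors and the pivots.
LU_data, LU_pivots = torch.lu(A)
print(LU_data.shape)    # torch.Size([4, 3, 3]) -> (*, m, m)
print(LU_pivots.shape)  # torch.Size([4, 3])    -> (*, m)

# Write the solution into a preallocated tensor via the out keyword argument.
x = torch.empty_like(b)
torch.lu_solve(b, LU_data, LU_pivots, out=x)

# Each batch entry should now satisfy A[i] @ x[i] approximately equal to b[i].
residual = torch.linalg.norm(A @ x - b)
print(residual)
```

Because the factorization is computed once, the same LU_data and LU_pivots can be reused for further right-hand sides without refactoring A.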
Example:
    >>> A = torch.randn(2, 3, 3)
    >>> b = torch.randn(2, 3, 1)
    >>> A_LU = torch.lu(A)
    >>> x = torch.lu_solve(b, *A_LU)
    >>> torch.norm(torch.bmm(A, x) - b)
    tensor(1.00000e-07 * 2.8312)
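The description states that cfloat and cdouble inputs are supported, and that the batch dimensions may be empty. A minimal sketch of the complex double case, assuming a single unbatched 3 × 3 system, cross-checks the result against torch.linalg.solve, which solves $Ax = b$ in a single call. The seed, size and identity shift are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

# A single cdouble (complex128) system; * is empty here, so b has size (m, k).
m = 3
A = torch.randn(m, m, dtype=torch.cdouble) + 3 * torch.eye(m, dtype=torch.cdouble)
b = torch.randn(m, 1, dtype=torch.cdouble)

# Factor once with torch.lu, then solve using the packed factors and pivots.
A_LU, pivots = torch.lu(A)
x = torch.lu_solve(b, A_LU, pivots)

# Cross-check against a direct solve of the same system.
x_ref = torch.linalg.solve(A, b)
print(torch.allclose(x, x_ref))
```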