Computes the solution X to the system torch_tensordot(A, X) = B.
Source: R/linalg.R (documentation: linalg_tensorsolve.Rd)
If m is the product of the first B$ndim dimensions of A and n is the product of the remaining dimensions, this function expects m and n to be equal.
The returned tensor x satisfies torch_tensordot(A, x, dims = x$ndim) == B.
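The shape requirement and the identity above can be checked by hand. The sketch below is an illustration, not the package's implementation: it assumes the torch package is installed, flattens A into an m x n matrix and B into a length-m vector, solves that ordinary linear system with linalg_solve(), and reshapes the result to the trailing dimensions of A. The names lead, m, n, x_flat and X_manual are illustrative only.

```r
library(torch)
torch_manual_seed(1)

# A well-conditioned 24 x 24 system viewed as a 5-d tensor:
# the first B$ndim = 2 dims (6, 4) multiply to m = 24,
# the remaining dims (2, 3, 4) multiply to n = 24.
A <- (torch_eye(24) + 0.05 * torch_randn(24, 24))$reshape(c(6, 4, 2, 3, 4))
B <- torch_randn(6, 4)

lead <- seq_len(B$ndim)
m <- as.integer(prod(A$shape[lead]))
n <- as.integer(prod(A$shape[-lead]))
stopifnot(m == n)  # the requirement stated above

# Illustrative manual route: flatten, solve the m x n system, restore the shape.
x_flat <- linalg_solve(A$reshape(c(m, n)), B$reshape(m))
X_manual <- x_flat$reshape(A$shape[-lead])

# Built-in route, compared against the manual one.
X <- linalg_tensorsolve(A, B)
torch_allclose(X, X_manual, atol = 1e-5)
```

Because torch stores tensors in row-major order, flattening the leading and trailing dimensions of A in this way lines up exactly with the contraction performed by torch_tensordot(A, X, dims = X$ndim).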
Arguments
- A (Tensor): the coefficient tensor of the system.
- B (Tensor): the right-hand side of the system.
- dims (Tuple[int], optional): dimensions of A to be moved to the end before solving. If NULL, no dimensions are moved. Default: NULL.
Details
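If dims is specified, A is first rearranged as
A <- torch_movedim(A, dims, seq(A$ndim - length(dims) + 1, A$ndim))
that is, the dimensions listed in dims are moved, in order, to the last positions of A.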
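To illustrate the Details formula, the sketch below moves dimensions 1 and 3 of A to the end explicitly and checks that solving the moved tensor without dims gives the same result as passing dims = c(1, 3). It assumes the torch package is installed and that torch_movedim() takes 1-based source and destination positions, like the dims argument itself; dest, A_moved and X_moved are illustrative names.

```r
library(torch)
torch_manual_seed(1)

A <- torch_randn(6, 4, 4, 3, 2)
B <- torch_randn(4, 3, 2)
dims <- c(1, 3)

# Let linalg_tensorsolve() move the dimensions itself.
X <- linalg_tensorsolve(A, B, dims = dims)

# Move dims 1 and 3 to the last two positions by hand
# (assumes torch_movedim() uses 1-based positions).
dest <- seq(A$ndim - length(dims) + 1, A$ndim)
A_moved <- torch_movedim(A, dims, dest)
A_moved$shape  # dims 2, 4, 5 of A, followed by dims 1, 3

# Solving the moved tensor without dims should give the same X.
X_moved <- linalg_tensorsolve(A_moved, B)
torch_allclose(X, X_moved, atol = 1e-5)
```

The explicit move here is the same rearrangement the Examples section performs with A$permute(c(2, 4, 5, 1, 3)), written in terms of dims instead of a hand-computed permutation.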
Supports inputs of float, double, cfloat and cdouble dtypes.
See also
linalg_tensorinv() computes the multiplicative inverse of torch_tensordot().
Other linalg: linalg_cholesky_ex(), linalg_cholesky(), linalg_det(), linalg_eigh(), linalg_eigvalsh(), linalg_eigvals(), linalg_eig(), linalg_householder_product(), linalg_inv_ex(), linalg_inv(), linalg_lstsq(), linalg_matrix_norm(), linalg_matrix_power(), linalg_matrix_rank(), linalg_multi_dot(), linalg_norm(), linalg_pinv(), linalg_qr(), linalg_slogdet(), linalg_solve_triangular(), linalg_solve(), linalg_svdvals(), linalg_svd(), linalg_tensorinv(), linalg_vector_norm()
Examples
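if (torch_is_installed()) {
  # A has shape (6, 4, 2, 3, 4) and B has shape (6, 4)
  A <- torch_eye(2 * 3 * 4)$reshape(c(2 * 3, 4, 2, 3, 4))
  B <- torch_randn(2 * 3, 4)
  X <- linalg_tensorsolve(A, B)
  X$shape # the trailing dimensions of A
  torch_allclose(torch_tensordot(A, X, dims = X$ndim), B)

  # Ask linalg_tensorsolve() to move dimensions 1 and 3 of A to the end
  A <- torch_randn(6, 4, 4, 3, 2)
  B <- torch_randn(4, 3, 2)
  X <- linalg_tensorsolve(A, B, dims = c(1, 3))
  # Apply the same rearrangement explicitly to verify the solution
  A <- A$permute(c(2, 4, 5, 1, 3))
  torch_allclose(torch_tensordot(A, X, dims = X$ndim), B, atol = 1e-6)
}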
#> [1] TRUE