Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R
torch_baddbmm.Rd
Arguments
- self
(Tensor) the tensor to be added (referred to as input in the description below)
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for input (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a \((b \times n \times m)\) tensor, batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor and out will be a \((b \times n \times p)\) tensor. Both alpha and beta mean the same as the scaling factors used in torch_addbmm.
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
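The formula above can be checked numerically. The following is a minimal sketch, assuming the torch package is installed and loaded; the values chosen for beta and alpha and the tolerance passed to torch_allclose are illustrative only.

```r
library(torch)

# Shapes follow the rule above: (b x n x m) @ (b x m x p) -> (b x n x p)
input  <- torch_randn(c(10, 3, 5))
batch1 <- torch_randn(c(10, 3, 4))
batch2 <- torch_randn(c(10, 4, 5))

# Illustrative scaling factors (not defaults)
beta  <- 0.5
alpha <- 2

out <- torch_baddbmm(input, batch1, batch2, beta = beta, alpha = alpha)

# Reference result built directly from the formula, using torch_bmm
# for the batched matrix product
expected <- beta * input + alpha * torch_bmm(batch1, batch2)

# TRUE when both results agree up to floating-point tolerance
matches <- torch_allclose(out, expected, atol = 1e-4)
```

Here matches should be TRUE, and out has shape (10, 3, 5).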
For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers; otherwise, they should be integers.
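As a small deterministic sketch of the points above, assuming the torch package is loaded: float tensors are combined with non-integer beta and alpha, and the added tensor has shape (2, 2), so it relies on the broadcasting rule described earlier. All values are chosen for illustration.

```r
library(torch)

# Float tensors filled with ones so the result can be worked out by hand
input  <- torch_ones(c(2, 2))      # broadcast against the (2 x 2 x 2) result
batch1 <- torch_ones(c(2, 2, 3))
batch2 <- torch_ones(c(2, 3, 2))

# Each entry of batch1 @ batch2 is 1 + 1 + 1 = 3, so every output entry is
# beta * 1 + alpha * 3 = 0.5 + 1.5 * 3 = 5
out <- torch_baddbmm(input, batch1, batch2, beta = 0.5, alpha = 1.5)

# Convert to an R array for inspection
result <- as_array(out)
```

Every entry of result is 5, and its dimensions are 2 x 2 x 2.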
Examples
library(torch)

if (torch_is_installed()) {
  # M is added to the batched product: (10 x 3 x 4) @ (10 x 4 x 5) -> (10 x 3 x 5)
  M <- torch_randn(c(10, 3, 5))
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  torch_baddbmm(M, batch1, batch2)
}
#> torch_tensor
#> (1,.,.) =
#> 1.6657 -0.6742 0.4512 -3.2314 0.8295
#> 2.0840 3.4537 1.5246 0.4369 -1.3417
#> 0.8246 -0.4030 1.1787 1.0525 0.6282
#>
#> (2,.,.) =
#> 1.1660 1.2199 -1.4092 1.6581 -2.3506
#> -0.3510 -0.2972 1.0932 3.2922 -1.8609
#> 0.0902 -0.7679 0.4225 2.6301 -1.6637
#>
#> (3,.,.) =
#> 1.5026 -1.2179 -3.3466 1.0576 -1.6512
#> -2.7543 0.9854 -1.4523 1.0961 2.5613
#> 3.3064 -1.2371 -3.1061 1.9371 -0.7384
#>
#> (4,.,.) =
#> -3.8763 4.0324 0.5153 -3.7112 1.9464
#> 2.9007 -1.3636 0.9369 3.3250 -1.9141
#> -1.3285 3.9644 -1.3838 -0.5099 -0.7736
#>
#> (5,.,.) =
#> 0.6634 1.0391 -1.9443 1.2481 0.5775
#> 1.2183 3.6301 -0.8679 -1.9750 -0.5477
#> 1.9005 -0.3253 -2.8810 1.6825 0.8057
#>
#> (6,.,.) =
#> 0.1856 2.1649 1.2783 1.2708 -0.4000
#> -1.1717 -2.7041 0.5744 -0.5783 -1.6064
#> 1.2987 -1.4754 3.0119 -1.4063 0.6659
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]