Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R (torch_baddbmm.Rd)
Arguments
- self
(Tensor) the tensor to be added
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for self, the input tensor (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2, then adds input to the result. Here input is the tensor passed as self.

batch1 and batch2 must be 3-D tensors, each containing the same number of matrices.

If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and out (the result) will be a \((b \times n \times p)\) tensor. alpha and beta are the same scaling factors used in torch_addbmm.
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
For inputs of type FloatTensor or DoubleTensor, the arguments beta and alpha must be real numbers; otherwise they should be integers.
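To make the batched formula above concrete, here is a minimal base-R sketch that loops over the batch index and computes \(\beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)\) for each matrix. It is a reading aid only, not how torch implements the operation. The helper name baddbmm_ref is hypothetical, and its broadcasting is deliberately simplified (input may be an n x p matrix, a 1 x n x p array or a full b x n x p array), whereas torch accepts any input broadcastable with \((b \times n \times p)\).

```r
# Hypothetical reference helper (not part of torch): restates
#   out_i = beta * input_i + alpha * (batch1_i %*% batch2_i)
# using base R arrays indexed as [batch, row, col].
baddbmm_ref <- function(input, batch1, batch2, beta = 1, alpha = 1) {
  stopifnot(length(dim(batch1)) == 3, length(dim(batch2)) == 3)
  b <- dim(batch1)[1]; n <- dim(batch1)[2]; m <- dim(batch1)[3]
  p <- dim(batch2)[3]
  # Both batches must hold the same number of matrices with matching inner dims.
  stopifnot(dim(batch2)[1] == b, dim(batch2)[2] == m)

  # Simplified broadcasting (assumption): an n x p matrix is reused for every batch.
  if (length(dim(input)) == 2) {
    stopifnot(nrow(input) == n, ncol(input) == p)
    input <- array(input, dim = c(1, n, p))
  }
  stopifnot(length(dim(input)) == 3, dim(input)[1] %in% c(1, b),
            dim(input)[2] == n, dim(input)[3] == p)

  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    k <- if (dim(input)[1] == 1) 1 else i
    # matrix() restores the 2-D shape even when n, m or p equals 1.
    inp_i <- matrix(input[k, , ], n, p)
    a_i   <- matrix(batch1[i, , ], n, m)
    b_i   <- matrix(batch2[i, , ], m, p)
    out[i, , ] <- beta * inp_i + alpha * (a_i %*% b_i)
  }
  out
}
```

With beta = 0.5 and alpha = 2, each slice out[i, , ] equals 0.5 * input[i, , ] + 2 * (batch1[i, , ] %*% batch2[i, , ]); passing a 3 x 5 matrix as input adds the same matrix to every batch.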
Examples
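library(torch)

if (torch_is_installed()) {
  M <- torch_randn(c(10, 3, 5))       # added term: 10 matrices of size 3 x 5
  batch1 <- torch_randn(c(10, 3, 4))  # 10 matrices of size 3 x 4
  batch2 <- torch_randn(c(10, 4, 5))  # 10 matrices of size 4 x 5
  torch_baddbmm(M, batch1, batch2)    # result has shape (10, 3, 5)
}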
#> torch_tensor
#> (1,.,.) =
#> 2.7700 3.7597 -0.5840 -0.9228 1.1850
#> 0.2146 -1.0171 0.1117 0.0613 0.7060
#> -0.8335 -1.2902 0.8080 -0.2263 -0.8075
#>
#> (2,.,.) =
#> -1.3236 0.5524 2.3813 -1.8995 0.2905
#> 5.2242 2.2223 -2.6739 3.2640 1.0878
#> -2.2081 -2.6137 3.7052 -0.2712 0.5142
#>
#> (3,.,.) =
#> 0.0276 -1.2610 -0.1775 -0.9561 0.7384
#> 0.6315 -3.6193 -2.3835 -3.4108 -2.9823
#> -0.4199 2.7762 3.5225 4.3590 1.6647
#>
#> (4,.,.) =
#> 1.9706 0.8897 -1.7094 0.9065 0.6663
#> -0.8588 0.0224 2.6829 2.5514 -2.3433
#> -3.1659 3.5826 0.8190 0.8119 1.0245
#>
#> (5,.,.) =
#> -1.3261 0.1733 -0.5291 -2.3699 -0.6561
#> 1.0878 3.9913 3.1729 1.1845 -0.7739
#> -0.1118 1.0368 -1.6662 0.8605 1.1416
#>
#> (6,.,.) =
#> 0.2404 0.8618 -3.1487 -1.1204 4.1657
#> 0.6510 1.2810 0.9076 2.0881 -1.4185
#> 0.7884 -0.3808 -0.6798 -1.6016 2.4876
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]