Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R (torch_baddbmm.Rd)
Arguments
- self
(Tensor) the tensor to be added (called input in the description below)
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for input (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2. input, scaled by beta, is added to the final result. batch1 and batch2 must be 3-D tensors, each containing the same number of matrices.

If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and out will be a \((b \times n \times p)\) tensor. alpha and beta have the same meaning as the scaling factors used in torch_addbmm.
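The shape rules above can be written out as a small base-R function. This is a sketch, not part of torch: `baddbmm_out_shape()` is a hypothetical helper that takes the three shapes as numeric vectors, checks that batch1 and batch2 are compatible, and returns the shape of the result. For input it assumes the usual trailing-dimension broadcasting rule (each dimension of input must be 1 or equal to the matching output dimension).

```r
# Hypothetical helper, not part of torch: computes the result shape of
# torch_baddbmm() from the shapes of input, batch1 and batch2.
baddbmm_out_shape <- function(input_shape, batch1_shape, batch2_shape) {
  if (length(batch1_shape) != 3 || length(batch2_shape) != 3) {
    stop("batch1 and batch2 must be 3-D")
  }
  b <- batch1_shape[1]
  n <- batch1_shape[2]
  m <- batch1_shape[3]
  p <- batch2_shape[3]
  if (batch2_shape[1] != b) stop("batch1 and batch2 must hold the same number of matrices")
  if (batch2_shape[2] != m) stop("inner dimensions of batch1 and batch2 differ")
  out <- c(b, n, p)
  # input is broadcast to `out`: align trailing dimensions; each
  # dimension of input must be 1 or equal to the matching output dimension
  k <- length(input_shape)
  if (k > 3) stop("input has more dimensions than the output")
  if (k > 0) {
    target <- out[seq.int(4 - k, 3)]
    if (!all(input_shape == 1 | input_shape == target)) {
      stop("input is not broadcastable to c(", paste(out, collapse = ", "), ")")
    }
  }
  out
}

baddbmm_out_shape(c(10, 3, 5), c(10, 3, 4), c(10, 4, 5))  # c(10, 3, 5)
baddbmm_out_shape(c(3, 5), c(10, 3, 4), c(10, 4, 5))      # also c(10, 3, 5)
```

The second call shows why input only needs to be broadcastable: a single (3 x 5) matrix is reused for every one of the 10 products.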
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
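To make the formula concrete, here is a plain base-R reference version. It is a hypothetical sketch (`baddbmm_ref()` is not the torch implementation) that loops over the batch index \(i\) and applies \(\beta\) and \(\alpha\) exactly as written; to stay short, it assumes input already has the full \((b \times n \times p)\) shape rather than broadcasting it.

```r
# Hypothetical base-R reference for
#   out_i = beta * input_i + alpha * (batch1_i %*% batch2_i)
# Arrays are indexed [i, row, col]. For brevity, input must already have
# the full (b, n, p) shape; broadcasting is not handled here.
baddbmm_ref <- function(input, batch1, batch2, beta = 1, alpha = 1) {
  b <- dim(batch1)[1]
  n <- dim(batch1)[2]
  m <- dim(batch1)[3]
  p <- dim(batch2)[3]
  stopifnot(dim(batch2)[1] == b, dim(batch2)[2] == m,
            identical(dim(input), c(b, n, p)))
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # matrix() restores the 2-D shape that `[i, , ]` drops when n, m or p is 1
    A <- matrix(batch1[i, , ], nrow = n, ncol = m)
    B <- matrix(batch2[i, , ], nrow = m, ncol = p)
    C <- matrix(input[i, , ], nrow = n, ncol = p)
    out[i, , ] <- beta * C + alpha * (A %*% B)
  }
  out
}

# One (1 x 2) %*% (2 x 1) product: 0.5 * 10 + 2 * (1 * 3 + 2 * 4) = 27
baddbmm_ref(array(10, c(1, 1, 1)), array(c(1, 2), c(1, 1, 2)),
            array(c(3, 4), c(1, 2, 1)), beta = 0.5, alpha = 2)
```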
For inputs of type FloatTensor or DoubleTensor, the arguments beta and alpha must be real numbers; otherwise they should be integers.
Examples
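library(torch)
if (torch_is_installed()) {
  # M is added to the result; its shape matches the (10, 3, 5) output
  M <- torch_randn(c(10, 3, 5))
  batch1 <- torch_randn(c(10, 3, 4))  # 10 matrices of size 3 x 4
  batch2 <- torch_randn(c(10, 4, 5))  # 10 matrices of size 4 x 5
  torch_baddbmm(M, batch1, batch2)
}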
#> torch_tensor
#> (1,.,.) =
#> 0.5111 -3.1852 3.7126 -1.3524 3.9963
#> 4.7965 -0.0369 -0.3263 -1.2753 0.0556
#> 2.3793 0.1486 0.1915 -0.4350 1.2126
#>
#> (2,.,.) =
#> 1.5858 -1.5813 -0.6729 -1.2245 -1.2957
#> 2.2479 0.2300 0.6143 1.7530 0.5099
#> -2.0699 2.4970 0.9025 -0.7765 1.3482
#>
#> (3,.,.) =
#> 1.8821 0.1650 2.2522 -0.3192 -0.3841
#> -3.8732 -1.4290 0.6768 -1.7308 3.1746
#> 0.3767 -0.8012 -0.3291 0.9759 -1.7650
#>
#> (4,.,.) =
#> 0.0757 1.7507 2.5247 -2.5280 0.1248
#> 2.0840 1.4858 0.4577 -1.2264 1.5897
#> 2.7945 0.7971 -0.6055 -0.2607 0.7605
#>
#> (5,.,.) =
#> 0.7051 0.2898 1.5222 -0.2621 -2.5265
#> -0.4269 -1.8950 -2.1202 1.3455 0.5589
#> 3.2963 5.4668 4.0805 -4.1564 -2.0503
#>
#> (6,.,.) =
#> 0.5190 -0.5352 -1.4169 -0.1021 2.8163
#> -0.1018 0.0950 1.8032 -1.3999 -1.6487
#> -0.2259 2.6865 4.0377 1.6081 1.1882
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]