Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R
torch_baddbmm.Rd
Arguments
- self: (Tensor) the tensor to be added (called input in the description below)
- batch1: (Tensor) the first batch of matrices to be multiplied
- batch2: (Tensor) the second batch of matrices to be multiplied
- beta: (Number, optional) multiplier for input (\(\beta\))
- alpha: (Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2, then adds input to the result. batch1 and batch2 must be 3-D tensors, each containing the same number of matrices.
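To make the batched product concrete, here is a minimal base-R sketch, assuming the batch index is the first dimension of a 3-D array (the same \((b \times n \times m)\) layout used in this page). The helper batch_matmul is hypothetical and not part of torch; it multiplies the i-th matrix of batch1 by the i-th matrix of batch2.

```r
# Hypothetical helper (not part of torch): batched matrix product of two
# 3-D arrays whose first dimension indexes the batch.
batch_matmul <- function(batch1, batch2) {
  d1 <- dim(batch1)  # c(b, n, m)
  d2 <- dim(batch2)  # c(b, m, p)
  stopifnot(length(d1) == 3, length(d2) == 3,
            d1[1] == d2[1],  # same number of matrices in each batch
            d1[3] == d2[2])  # inner dimensions must agree
  out <- array(0, dim = c(d1[1], d1[2], d2[3]))
  for (i in seq_len(d1[1])) {
    # matrix() restores the 2-D shape in case indexing dropped a length-1 dimension
    a <- matrix(batch1[i, , ], nrow = d1[2], ncol = d1[3])
    b <- matrix(batch2[i, , ], nrow = d2[2], ncol = d2[3])
    out[i, , ] <- a %*% b
  }
  out
}

set.seed(1)
b1 <- array(rnorm(2 * 3 * 4), dim = c(2, 3, 4))
b2 <- array(rnorm(2 * 4 * 5), dim = c(2, 4, 5))
res <- batch_matmul(b1, b2)
dim(res)  # 2 3 5
```

The explicit loop only mirrors the definition; torch_baddbmm performs these per-batch products itself, so you would not write such a loop in practice.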
If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and out will be a \((b \times n \times p)\) tensor. alpha and beta have the same meaning as the scaling factors used in torch_addbmm.
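As an illustration of the broadcasting rule, the sketch below passes a 2-D \((n \times p)\) tensor as input; it broadcasts against the \((b \times n \times p)\) product, so the same matrix is added to every batch entry. It assumes the torch package and its backend are installed (the guard skips the code otherwise), and the name bias is only an illustrative variable.

```r
# Guarded so the sketch only runs where torch and its backend are available.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)
  batch1 <- torch_randn(c(10, 3, 4))  # (b x n x m)
  batch2 <- torch_randn(c(10, 4, 5))  # (b x m x p)
  bias <- torch_randn(c(3, 5))        # (n x p), broadcast across all 10 batches
  out <- torch_baddbmm(bias, batch1, batch2)
  print(out$shape)                    # 10 3 5
  # The first batch entry minus its own matrix product should recover bias
  print(torch_allclose(out[1, , ] - torch_mm(batch1[1, , ], batch2[1, , ]),
                       bias, atol = 1e-5))
}
```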
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
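The formula can be checked by hand on a single \(2 \times 2\) batch. The sketch below is a plain base-R reference implementation, not torch's; the function baddbmm_ref is hypothetical, and it assumes input already has the full \((b \times n \times p)\) shape, so broadcasting is not handled.

```r
# Hypothetical reference implementation (not part of torch) of
#   out_i = beta * input_i + alpha * (batch1_i @ batch2_i)
# Assumes input already has shape (b x n x p); broadcasting is not handled.
baddbmm_ref <- function(input, batch1, batch2, beta = 1, alpha = 1) {
  b <- dim(batch1)[1]; n <- dim(batch1)[2]; m <- dim(batch1)[3]
  p <- dim(batch2)[3]
  stopifnot(dim(batch2)[1] == b, dim(batch2)[2] == m,
            identical(dim(input), c(b, n, p)))
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    prod_i <- matrix(batch1[i, , ], n, m) %*% matrix(batch2[i, , ], m, p)
    out[i, , ] <- beta * matrix(input[i, , ], n, p) + alpha * prod_i
  }
  out
}

# One batch (b = 1) of 2 x 2 matrices
batch1 <- array(matrix(c(1, 2, 3, 4), 2, byrow = TRUE), dim = c(1, 2, 2))
batch2 <- array(matrix(c(5, 6, 7, 8), 2, byrow = TRUE), dim = c(1, 2, 2))
input  <- array(1, dim = c(1, 2, 2))
out <- baddbmm_ref(input, batch1, batch2, beta = 2, alpha = 0.5)
out[1, , ]  # rbind(c(11.5, 13), c(23.5, 27))
```

Here \(\mbox{batch1}_1 \mathbin{@} \mbox{batch2}_1 = \begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}\), so with \(\beta = 2\), \(\alpha = 0.5\) and an input of ones each entry is \(2 + 0.5 \times\) the corresponding product entry.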
For inputs of type FloatTensor or DoubleTensor, the arguments beta and alpha must be real numbers; otherwise they should be integers.
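Because torch_randn creates float tensors, beta and alpha may be non-integer values here. The sketch below assumes torch and its backend are installed; it compares torch_baddbmm with the formula written out using torch_bmm, and uses torch_allclose because the two computations may round slightly differently.

```r
# Guarded so the sketch only runs where torch and its backend are available.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)
  M <- torch_randn(c(10, 3, 5))  # float32, so non-integer beta and alpha are allowed
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  out <- torch_baddbmm(M, batch1, batch2, beta = 0.5, alpha = 2)
  # The same result written out as beta * input + alpha * (batch1 @ batch2)
  expected <- 0.5 * M + 2 * torch_bmm(batch1, batch2)
  print(torch_allclose(out, expected, atol = 1e-5))
}
```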
Examples
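library(torch)
if (torch_is_installed()) {
  # Shapes follow the description: M is (b x n x p), batch1 is (b x n x m), batch2 is (b x m x p)
  M <- torch_randn(c(10, 3, 5))
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  # With the defaults beta = 1 and alpha = 1 this is M + batch1 @ batch2 for each of the 10 batches
  torch_baddbmm(M, batch1, batch2)
}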
#> torch_tensor
#> (1,.,.) =
#> 2.8642 -0.6452 -1.6961 2.0009 -2.6657
#> -1.8359 -2.3786 -0.9628 -0.0016 -0.2856
#> -2.1721 -1.9670 -2.7795 -1.0340 0.0047
#>
#> (2,.,.) =
#> -0.9847 2.4049 -0.1452 4.2340 -6.2758
#> 1.5184 -1.3777 1.9689 -1.7716 5.5948
#> 1.6697 4.0008 -1.1465 3.2824 -2.3229
#>
#> (3,.,.) =
#> -0.8041 4.4811 1.7899 2.5888 -0.7763
#> 2.9867 2.1121 0.0483 -2.9491 2.2222
#> -2.2620 3.1658 3.4327 3.1481 0.0484
#>
#> (4,.,.) =
#> 2.3108 -0.6733 -0.8447 -4.0547 0.3361
#> 0.0457 -3.6878 0.5342 1.5557 -0.7195
#> -2.8151 0.4931 1.6969 3.8646 0.9910
#>
#> (5,.,.) =
#> -1.2385 -0.7696 -3.0007 -0.5918 -1.1452
#> 2.6435 -2.0913 1.2918 -1.0463 0.4882
#> -2.0015 -1.7780 -1.1738 1.2719 -1.7094
#>
#> (6,.,.) =
#> 0.9015 1.0750 0.2324 -0.2374 -0.1188
#> -3.5077 -3.5626 -0.0457 -0.9870 2.9761
#> 1.8342 -1.3462 -1.6882 -0.4784 0.7568
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]