Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R
torch_baddbmm.Rd
Arguments
- self
(Tensor) the tensor to be added
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for input (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2. input is added to the final result.
batch1 and batch2 must be 3-D tensors, each containing the same number of matrices.
If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and out will be a \((b \times n \times p)\) tensor. Both alpha and beta have the same meaning as the scaling factors used in torch_addbmm.
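As a minimal sketch of these shape rules (assuming the torch R package is installed; the sizes and variable names are illustrative only), the code below uses b = 10, n = 3, m = 4, p = 5 and passes a 2-D \((n \times p)\) tensor as self, which is broadcast across the batch dimension. Expanding it explicitly to \((b \times n \times p)\) first should give the same result.

```r
library(torch)

# Illustrative sizes: b matrices, each product is (n x m) %*% (m x p)
b <- 10; n <- 3; m <- 4; p <- 5
batch1 <- torch_randn(c(b, n, m))  # b matrices of shape n x m
batch2 <- torch_randn(c(b, m, p))  # b matrices of shape m x p

# self has shape (n, p); it broadcasts against the (b, n, p) product
bias <- torch_randn(c(n, p))
out <- torch_baddbmm(bias, batch1, batch2)

# Output shape is (b, n, p) = (10, 3, 5)
print(out$shape)

# Expanding self to (b, n, p) by hand should match the broadcast version
expanded <- torch_baddbmm(bias$expand(c(b, n, p)), batch1, batch2)
print(torch_allclose(out, expanded, atol = 1e-6))
```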
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
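To make the per-slice formula concrete, the sketch below (assuming the torch R package is installed; the sizes and the values beta = 0.5 and alpha = 2 are arbitrary) rebuilds each \(\mbox{out}_i\) with an ordinary matrix product via torch_mm and compares the stacked result with torch_baddbmm. Indexing in R torch is 1-based, so i runs from 1 to b.

```r
library(torch)

M      <- torch_randn(c(4, 2, 3))  # input: (b, n, p)
batch1 <- torch_randn(c(4, 2, 5))  # (b, n, m)
batch2 <- torch_randn(c(4, 5, 3))  # (b, m, p)
beta  <- 0.5
alpha <- 2

out <- torch_baddbmm(M, batch1, batch2, beta = beta, alpha = alpha)

# out_i = beta * input_i + alpha * (batch1_i @ batch2_i), one slice at a time
manual <- torch_stack(lapply(1:4, function(i) {
  beta * M[i, , ] + alpha * torch_mm(batch1[i, , ], batch2[i, , ])
}))

print(torch_allclose(out, manual, atol = 1e-5))
```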
For inputs of type FloatTensor or DoubleTensor, the arguments beta and alpha must be real numbers; otherwise, they must be integers.
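A small sketch of this rule, assuming the torch R package is installed and that the CPU build supports batched products on int64 tensors: with float tensors, fractional beta and alpha are fine, while with integer (torch_long()) tensors the scaling factors are passed as R integers. All-ones inputs are used so each entry can be checked by hand (each product entry is a sum of three ones).

```r
library(torch)

# Float tensors: beta and alpha may be any real numbers
x  <- torch_ones(c(2, 2, 2))
b1 <- torch_ones(c(2, 2, 3))
b2 <- torch_ones(c(2, 3, 2))
out_float <- torch_baddbmm(x, b1, b2, beta = 0.5, alpha = 1.5)
# each entry: 0.5 * 1 + 1.5 * 3 = 5

# Integer (int64) tensors: keep beta and alpha as integers
xi  <- torch_ones(c(2, 2, 2), dtype = torch_long())
b1i <- torch_ones(c(2, 2, 3), dtype = torch_long())
b2i <- torch_ones(c(2, 3, 2), dtype = torch_long())
out_int <- torch_baddbmm(xi, b1i, b2i, beta = 1L, alpha = 2L)
# each entry: 1 * 1 + 2 * 3 = 7

print(out_float)
print(out_int)
```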
Examples
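library(torch)

if (torch_is_installed()) {
  # M is added to the result; it already has the output shape (b, n, p) = (10, 3, 5)
  M <- torch_randn(c(10, 3, 5))
  # 10 matrices of shape 3 x 4 and 10 matrices of shape 4 x 5
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  torch_baddbmm(M, batch1, batch2)
}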
#> torch_tensor
#> (1,.,.) =
#> 0.8327 1.4081 1.4228 1.2399 -2.8737
#> -0.2565 0.0368 0.6259 -3.9232 -0.4630
#> -0.5861 -1.2092 -0.1388 -4.4102 -1.5230
#>
#> (2,.,.) =
#> 0.6239 -3.5683 0.6935 1.6923 3.6257
#> -0.0053 -4.4689 -3.1876 -1.5333 -0.6221
#> -1.0120 -1.0548 0.1891 0.3091 2.8736
#>
#> (3,.,.) =
#> -2.4318 0.3211 -1.6522 0.8343 -1.7426
#> 0.2908 -0.3338 -2.0635 -0.6237 0.7619
#> 0.0415 0.0592 -3.5628 -1.1115 0.0642
#>
#> (4,.,.) =
#> 2.6377 1.3936 0.5833 -0.7960 6.1028
#> 4.7581 0.5937 1.2806 -0.8959 5.5142
#> 0.5030 -1.4242 0.5065 2.2072 1.1472
#>
#> (5,.,.) =
#> 0.3124 0.6058 -4.0384 -1.4171 0.9059
#> -2.1184 -0.1158 3.1254 -0.8333 -0.7524
#> -1.4068 1.7623 -0.6860 0.1140 -1.2178
#>
#> (6,.,.) =
#> -1.3656 -0.7483 -0.1196 0.3634 4.7731
#> -1.6870 0.4042 0.6884 -2.4956 -3.3773
#> 0.8650 1.4072 -0.6556 -0.8160 -4.4744
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]