Bmm
Note
This function does not broadcast. For broadcasting matrix products, see torch_matmul.
bmm(input, mat2, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices stored in input and mat2.
input and mat2 must be 3-D tensors, each containing the same number of matrices.
If input is a \((b \times n \times m)\) tensor and mat2 is a \((b \times m \times p)\) tensor, out will be a \((b \times n \times p)\) tensor.
$$ \mbox{out}_i = \mbox{input}_i \mathbin{@} \mbox{mat2}_i $$
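To make the shape rule and the formula concrete, here is a minimal base-R sketch of the same semantics, assuming plain R arrays rather than torch tensors. bmm_reference() is a hypothetical helper written for illustration, not part of torch. It multiplies the i-th matrix of the first array by the i-th matrix of the second and, like torch_bmm, rejects batches of different sizes instead of broadcasting them.

```r
# Hypothetical helper (not a torch function): batched matrix product on R arrays.
# a has dim c(b, n, m), b2 has dim c(b, m, p); the result has dim c(b, n, p).
bmm_reference <- function(a, b2) {
  da <- dim(a)
  db <- dim(b2)
  stopifnot(length(da) == 3, length(db) == 3)
  # Same number of matrices and matching inner dimensions; no broadcasting.
  stopifnot(da[1] == db[1], da[3] == db[2])
  out <- array(0, dim = c(da[1], da[2], db[3]))
  for (i in seq_len(da[1])) {
    # matrix() restores the 2-D shape if indexing dropped a length-1 dimension
    ai <- matrix(a[i, , ], nrow = da[2], ncol = da[3])
    bi <- matrix(b2[i, , ], nrow = db[2], ncol = db[3])
    out[i, , ] <- ai %*% bi  # out_i = input_i @ mat2_i
  }
  out
}

# Usage with the same shapes as the torch example below
set.seed(1)
x <- array(rnorm(10 * 3 * 4), dim = c(10, 3, 4))
y <- array(rnorm(10 * 4 * 5), dim = c(10, 4, 5))
res <- bmm_reference(x, y)
dim(res)  # 10 3 5
```

The explicit loop is chosen for readability; it mirrors the per-slice definition in the formula rather than aiming for speed.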
Examples
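if (torch_is_installed()) {
input = torch_randn(c(10, 3, 4))  # batch of 10 matrices, each 3 x 4
mat2 = torch_randn(c(10, 4, 5))   # batch of 10 matrices, each 4 x 5
res = torch_bmm(input, mat2)      # batch of 10 matrices, each 3 x 5
res
}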
#> torch_tensor
#> (1,.,.) =
#> -0.0843 -2.1627 -1.6424 -0.4507 4.1480
#> -0.1213 -0.8690 -4.4613 2.1322 3.0673
#> 0.0768 0.8064 -0.6706 1.2668 -1.3280
#>
#> (2,.,.) =
#> 0.3162 1.7651 0.8488 1.5031 -1.0899
#> -0.4573 -0.7125 -0.2178 -1.1378 -2.4903
#> -3.0860 -1.8934 -2.5964 1.4745 3.4209
#>
#> (3,.,.) =
#> 0.7205 -0.3173 -1.2553 -2.3448 1.0693
#> 0.9858 6.5661 -0.0020 -0.3931 -0.1510
#> -0.5032 -1.5543 0.0991 -0.8854 0.4289
#>
#> (4,.,.) =
#> 1.2022 0.7230 0.5549 -0.4991 -1.0222
#> -0.1239 0.9057 2.8032 0.2359 -1.3709
#> -1.1100 -4.6168 3.5286 2.7430 2.3721
#>
#> (5,.,.) =
#> 0.6792 -0.7432 0.5184 0.0772 -1.6586
#> 4.0507 -0.9044 -1.0465 0.9289 -0.7139
#> -1.5169 -2.0995 4.3298 -0.2039 -0.7965
#>
#> (6,.,.) =
#> 0.6674 2.3772 -0.4940 0.5409 -1.7737
#> -2.1742 -7.8698 0.9262 4.9745 5.0967
#> -0.7233 -3.3676 0.0787 1.9661 2.3312
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]
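To illustrate the note that torch_bmm does not broadcast, the sketch below (assuming the torch package and its backend are installed) checks torch_bmm against a per-slice torch_mm and against torch_matmul, then compares how each handles a batch of size 1. The variable names are illustrative only, and the sketch records only whether torch_bmm raised an error, not the exact message.

```r
library(torch)

if (torch_is_installed()) {
  input <- torch_randn(c(10, 3, 4))
  mat2 <- torch_randn(c(10, 4, 5))
  res <- torch_bmm(input, mat2)

  # Each batch slice equals an ordinary 2-D product (R torch indexing is 1-based)
  slice_ok <- torch_allclose(res[1, , ], torch_mm(input[1, , ], mat2[1, , ]))

  # For 3-D inputs with equal batch sizes, torch_matmul gives the same result
  matmul_ok <- torch_allclose(res, torch_matmul(input, mat2))

  # torch_bmm does not broadcast a batch of 1 against a batch of 10 ...
  one <- torch_randn(c(1, 3, 4))
  bmm_err <- tryCatch({ torch_bmm(one, mat2); FALSE }, error = function(e) TRUE)

  # ... whereas torch_matmul broadcasts the batch dimension
  bcast_shape <- torch_matmul(one, mat2)$shape  # 10 3 5
}
```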