Bmm
Note
This function does not broadcast.
For broadcasting matrix products, see torch_matmul.
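The difference from torch_matmul can be seen in a short sketch. It assumes the torch R package is installed, and the variable names are illustrative. A 2-D second argument is broadcast across the batch by torch_matmul but rejected by torch_bmm. The two agree only after the matrix is expanded to an explicit batch with the tensor method `$expand()`.

```r
library(torch)

a <- torch_randn(c(10, 3, 4))   # a batch of 10 matrices, each 3 x 4
w <- torch_randn(c(4, 5))       # a single 4 x 5 matrix with no batch dimension

# torch_matmul broadcasts w across the batch dimension of a.
res_matmul <- torch_matmul(a, w)
res_matmul$shape                # 10 3 5

# torch_bmm does not broadcast, so a 2-D second argument raises an error.
bmm_failed <- tryCatch({ torch_bmm(a, w); FALSE },
                       error = function(e) TRUE)

# Expanding w to an explicit (10 x 4 x 5) batch gives torch_bmm valid input.
res_bmm <- torch_bmm(a, w$expand(c(10, 4, 5)))
torch_allclose(res_bmm, res_matmul, atol = 1e-6)
```

Because `$expand()` returns a view rather than copying w ten times, the explicit batch is cheap to build.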
bmm(input, mat2, out=NULL) -> Tensor
Performs a batch matrix-matrix product of matrices stored in input
and mat2.
input and mat2 must be 3-D tensors each containing
the same number of matrices.
If input is a \((b \times n \times m)\) tensor and mat2 is a
\((b \times m \times p)\) tensor, then out will be a
\((b \times n \times p)\) tensor.
$$ \mbox{out}_i = \mbox{input}_i \mathbin{@} \mbox{mat2}_i $$
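The formula can be read as one ordinary matrix product per batch index i. The sketch below follows that reading. It uses a hypothetical helper, bmm_reference(), which is not part of torch and works on base R 3-D arrays indexed [batch, row, column]. It only illustrates the shape rules and the formula. It is not how torch_bmm computes its result.

```r
# Hypothetical helper (not part of torch): a base-R version of
# out_i = input_i @ mat2_i for 3-D arrays indexed [batch, row, column].
bmm_reference <- function(input, mat2) {
  d1 <- dim(input)
  d2 <- dim(mat2)
  # Both arguments must be 3-D with the same batch size b, and the inner
  # dimensions must agree: (b x n x m) times (b x m x p).
  stopifnot(length(d1) == 3, length(d2) == 3,
            d1[1] == d2[1], d1[3] == d2[2])
  b <- d1[1]
  n <- d1[2]
  m <- d1[3]
  p <- d2[3]
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # One ordinary matrix product per batch element; matrix() keeps the
    # n x m and m x p shapes even when a dimension has length 1.
    out[i, , ] <- matrix(input[i, , ], nrow = n, ncol = m) %*%
      matrix(mat2[i, , ], nrow = m, ncol = p)
  }
  out
}

set.seed(1)
input <- array(rnorm(10 * 3 * 4), dim = c(10, 3, 4))
mat2 <- array(rnorm(10 * 4 * 5), dim = c(10, 4, 5))
res <- bmm_reference(input, mat2)
dim(res)  # 10 3 5
```

The explicit loop shows why both arguments must contain the same number of matrices. There is exactly one product per index i, and no dimension is broadcast.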
Examples
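library(torch)
if (torch_is_installed()) {
  # A batch of 10 matrices of size 3 x 4 and a batch of 10 matrices of size 4 x 5
  input <- torch_randn(c(10, 3, 4))
  mat2 <- torch_randn(c(10, 4, 5))
  # Multiply matching pairs: the result holds 10 matrices of size 3 x 5
  res <- torch_bmm(input, mat2)
  res
}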
#> torch_tensor
#> (1,.,.) =
#> 1.9130 0.6033 0.3888 0.5372 -0.4651
#> -3.1482 1.9479 -1.2457 1.1683 2.4823
#> -0.1129 0.9505 0.3201 -2.4624 -2.6049
#>
#> (2,.,.) =
#> 0.8762 -2.0045 1.3103 -0.4235 1.7788
#> 3.9461 4.8694 3.7881 1.6393 -1.5353
#> 0.6971 0.9571 0.2847 -0.7819 0.1743
#>
#> (3,.,.) =
#> -1.5990 2.3937 1.9001 0.6660 4.9165
#> 1.9670 -1.1894 0.9631 1.7993 0.8822
#> 0.6011 -1.0120 0.2467 -1.6122 -3.6152
#>
#> (4,.,.) =
#> 0.5338 -3.1084 -0.5995 1.0252 1.5329
#> -1.7698 3.4609 0.7087 -1.8167 -3.1960
#> -0.2377 -1.5909 -0.1485 0.1318 0.2653
#>
#> (5,.,.) =
#> 0.1142 1.2295 -1.1071 -0.0130 -0.5943
#> -0.3616 2.6841 -0.9658 0.5320 -1.3257
#> 0.8364 0.6762 1.8894 1.4191 2.2782
#>
#> (6,.,.) =
#> -3.0720 -0.5186 -0.3124 1.1297 -4.3022
#> 2.3711 -0.7139 -0.3969 -2.4186 1.5882
#> -0.4344 -1.0548 0.2255 0.0597 -0.6742
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]