Bmm
Note
This function does not broadcast. For broadcasting matrix products, see torch_matmul().
bmm(input, mat2, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices stored in input and mat2.
input and mat2 must be 3-D tensors, each containing the same number of matrices. If input is a \((b \times n \times m)\) tensor and mat2 is a \((b \times m \times p)\) tensor, out will be a \((b \times n \times p)\) tensor.
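The shape rule above can be expressed as a small check on dimension vectors. The helper `bmm_output_dim()` below is hypothetical (it is not part of torch) and works on plain base-R numeric vectors, so it runs without torch installed. It only mirrors the documented requirements: both operands are 3-D, the batch sizes are equal (no broadcasting), and the inner dimension m agrees.

```r
# Hypothetical helper (not part of torch): given the dimensions of input and
# mat2, return the dimensions torch_bmm() would produce, or stop with an error.
bmm_output_dim <- function(input_dim, mat2_dim) {
  # Both operands must be 3-D: (batch, rows, cols)
  if (length(input_dim) != 3 || length(mat2_dim) != 3) {
    stop("input and mat2 must both be 3-D")
  }
  # No broadcasting: the batch sizes must match exactly
  if (input_dim[1] != mat2_dim[1]) {
    stop("input and mat2 must contain the same number of matrices")
  }
  # Inner dimensions must agree: m in (b x n x m) and (b x m x p)
  if (input_dim[3] != mat2_dim[2]) {
    stop("inner dimensions of input and mat2 do not match")
  }
  # Result is (b x n x p)
  c(input_dim[1], input_dim[2], mat2_dim[3])
}

bmm_output_dim(c(10, 3, 4), c(10, 4, 5))
#> [1] 10  3  5

# Batch sizes 10 and 1 differ, so this signals an error instead of broadcasting
try(bmm_output_dim(c(10, 3, 4), c(1, 4, 5)))
```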
$$ \mbox{out}_i = \mbox{input}_i \mathbin{@} \mbox{mat2}_i $$
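The formula says each output matrix is the ordinary matrix product of the corresponding pair of input matrices. A minimal base-R sketch of that semantics follows, using plain arrays indexed batch-first as \((b, n, m)\) to match the tensor shapes above. `bmm_reference()` is a hypothetical name used for illustration only; it is not how torch implements torch_bmm() and is far slower.

```r
# Hypothetical base-R reference for out_i = input_i %*% mat2_i
# (illustration only; not how torch implements torch_bmm()).
bmm_reference <- function(input, mat2) {
  di <- dim(input)
  dm <- dim(mat2)
  stopifnot(length(di) == 3, length(dm) == 3,
            di[1] == dm[1],  # same number of matrices (no broadcasting)
            di[3] == dm[2])  # inner dimension m must agree
  b <- di[1]; n <- di[2]; m <- di[3]; p <- dm[3]
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # matrix() restores the n x m and m x p shapes in case R dropped a
    # dimension of size 1 when extracting the i-th slice
    x_i <- matrix(input[i, , ], nrow = n, ncol = m)
    y_i <- matrix(mat2[i, , ], nrow = m, ncol = p)
    out[i, , ] <- x_i %*% y_i
  }
  out
}

set.seed(42)
input <- array(rnorm(10 * 3 * 4), dim = c(10, 3, 4))
mat2 <- array(rnorm(10 * 4 * 5), dim = c(10, 4, 5))
res <- bmm_reference(input, mat2)
dim(res)
#> [1] 10  3  5
```

Because the batch sizes are compared exactly, a mat2 with a different number of matrices fails the check rather than being broadcast, in line with the note above.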
Examples
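library(torch)

if (torch_is_installed()) {
  # A batch of 10 matrices of size 3 x 4 and a batch of 10 matrices of size 4 x 5
  input <- torch_randn(c(10, 3, 4))
  mat2 <- torch_randn(c(10, 4, 5))
  # Multiply the i-th matrix of input by the i-th matrix of mat2
  res <- torch_bmm(input, mat2)
  res  # a 10 x 3 x 5 tensor
}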
#> torch_tensor
#> (1,.,.) =
#> -3.9779 -4.0447 2.1614 2.1942 -0.4003
#> 0.3096 -0.9555 0.8257 0.1269 2.7483
#> 3.3655 1.0202 -0.5988 -1.8782 0.1603
#>
#> (2,.,.) =
#> -0.1271 1.2540 -2.4086 0.3651 -0.7772
#> -0.0002 -0.4978 1.4643 0.6693 0.5395
#> 1.4081 -1.8877 -1.6298 -3.7741 -0.8088
#>
#> (3,.,.) =
#> 1.4005 1.4783 -0.4995 -1.1390 3.2663
#> 0.8963 1.0888 1.2733 0.6346 -2.2302
#> 0.6499 -0.1899 1.2917 2.4539 -2.8574
#>
#> (4,.,.) =
#> -1.4033 0.9887 3.1516 -2.2690 0.4593
#> 1.0713 -1.5447 -1.7932 -0.0224 0.6681
#> -0.7326 3.4533 3.6727 -0.2067 -2.8119
#>
#> (5,.,.) =
#> -0.5138 -0.0948 1.8494 1.1469 1.8914
#> -0.0875 -2.4284 -2.2063 -0.0906 0.6199
#> 0.8737 2.5458 -0.0385 -0.7725 -2.5906
#>
#> (6,.,.) =
#> 2.1755 1.7435 -1.5009 0.8825 0.6605
#> 0.4012 0.0050 -0.7586 -0.9768 0.6997
#> -1.5066 -0.4932 1.0696 -1.0605 1.0048
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]