Bmm
Note
This function does not broadcast. For broadcasting matrix products, see torch_matmul.
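Because bmm does not broadcast, the two batch sizes must match exactly. For example, a (1 × 3 × 4) input paired with a (10 × 4 × 5) mat2 is rejected rather than expanded. The sketch below uses a hypothetical base-R helper, bmm_result_shape (it is not part of torch), to apply the shape rule on this page. It returns the output shape, or stops in the cases where bmm would fail.

```r
# Hypothetical helper (not part of torch): returns the shape that a batch
# matrix product would produce, or stops where bmm would fail.
# No dimension is broadcast, so a batch of size 1 is NOT expanded.
bmm_result_shape <- function(input_dim, mat2_dim) {
  if (length(input_dim) != 3 || length(mat2_dim) != 3)
    stop("both arguments must be 3-D")
  if (input_dim[1] != mat2_dim[1])
    stop("batch sizes differ; bmm does not broadcast")
  if (input_dim[3] != mat2_dim[2])
    stop("inner matrix dimensions differ")
  c(input_dim[1], input_dim[2], mat2_dim[3])
}

bmm_result_shape(c(10, 3, 4), c(10, 4, 5))       # 10 3 5
try(bmm_result_shape(c(1, 3, 4), c(10, 4, 5)))   # error: batch sizes differ
```

A function that broadcasts, such as torch_matmul, would expand the size-1 batch dimension in the second call instead of raising an error.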
bmm(input, mat2, out=NULL) -> Tensor
Performs a batch matrix-matrix product of matrices stored in input and mat2.
input and mat2 must be 3-D tensors, each containing the same number of matrices.
If input is a \((b \times n \times m)\) tensor and mat2 is a \((b \times m \times p)\) tensor, then out will be a \((b \times n \times p)\) tensor.
$$ \mbox{out}_i = \mbox{input}_i \mathbin{@} \mbox{mat2}_i $$
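The formula applies an ordinary matrix product independently to each batch index i. The following base-R sketch spells this out. It assumes plain 3-D arrays laid out as (batch, rows, cols) to mirror the shapes above. bmm_ref is a hypothetical name, and the loop only illustrates the semantics; it does not show how torch computes the product internally.

```r
# Hypothetical reference implementation of out_i = input_i %*% mat2_i
# on base-R arrays with dim (b, n, m) and (b, m, p).
bmm_ref <- function(input, mat2) {
  di <- dim(input); dm <- dim(mat2)
  stopifnot(length(di) == 3, length(dm) == 3,
            di[1] == dm[1],   # same number of matrices (no broadcasting)
            di[3] == dm[2])   # inner dimensions agree
  b <- di[1]; n <- di[2]; m <- di[3]; p <- dm[3]
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # matrix() restores the 2-D shape even if a dimension of size 1 was dropped
    a <- matrix(input[i, , ], nrow = n, ncol = m)
    w <- matrix(mat2[i, , ], nrow = m, ncol = p)
    out[i, , ] <- a %*% w
  }
  out
}

set.seed(1)
input <- array(rnorm(10 * 3 * 4), dim = c(10, 3, 4))
mat2 <- array(rnorm(10 * 4 * 5), dim = c(10, 4, 5))
res <- bmm_ref(input, mat2)
dim(res)  # 10 3 5
```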
Examples
library(torch)
if (torch_is_installed()) {
  input <- torch_randn(c(10, 3, 4))
  mat2 <- torch_randn(c(10, 4, 5))
  res <- torch_bmm(input, mat2)
  res
}
#> torch_tensor
#> (1,.,.) =
#> 2.7168 -1.2141 -0.5394 1.7222 1.0226
#> -2.8314 0.3601 1.4570 -0.7288 -0.9739
#> 4.7212 -2.3189 -3.2758 1.9125 1.2320
#>
#> (2,.,.) =
#> 3.7468 0.5798 -1.7872 -0.6398 1.5678
#> -1.5621 -0.8541 1.0281 -0.6850 0.1060
#> -0.3863 -0.4421 1.1998 -0.2861 -0.1437
#>
#> (3,.,.) =
#> 1.6474 2.7037 -4.0628 -3.0125 -3.0695
#> 1.4426 3.2565 -5.3501 -2.8497 -2.9017
#> -1.0122 -3.1213 3.8614 1.0956 1.1245
#>
#> (4,.,.) =
#> 4.3702 -6.1084 3.6425 -2.4940 -5.5863
#> -1.4022 -2.9472 -1.3945 0.9971 -0.0288
#> 1.2110 -3.2717 0.9385 -0.7434 -2.2789
#>
#> (5,.,.) =
#> -4.6309 -3.4303 -0.0761 -0.4952 2.4886
#> -2.0501 0.9634 0.3599 2.2724 2.8026
#> -2.0736 -0.1793 -0.1479 1.1329 2.0929
#>
#> (6,.,.) =
#> -2.0675 -2.2060 -2.4834 0.2732 1.9481
#> 3.0875 3.8344 2.7201 -0.5610 1.1073
#> -1.7108 -4.0030 -2.2508 -0.9731 -3.8781
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]