
Bmm

Usage

torch_bmm(self, mat2)

Arguments

self

(Tensor) the first batch of matrices to be multiplied

mat2

(Tensor) the second batch of matrices to be multiplied

Note

This function does not broadcast. For broadcasting matrix products, see torch_matmul.
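The difference matters when one operand is a single matrix shared across the batch. The sketch below uses an illustrative batch of 10 (3 x 4) matrices and a single (4 x 5) matrix `w`, with arbitrary shapes chosen only for the demonstration. torch_matmul broadcasts `w` over the batch. torch_bmm rejects the 2-D `w`, so the shared matrix has to be given an explicit batch dimension first.

```r
library(torch)

# A batch of 10 (3 x 4) matrices and one shared (4 x 5) matrix.
input <- torch_randn(c(10, 3, 4))
w <- torch_randn(c(4, 5))

# torch_matmul broadcasts the 2-D `w` across the batch dimension.
res_matmul <- torch_matmul(input, w)

# torch_bmm does not broadcast: passing the 2-D `w` raises an error.
bmm_2d_failed <- tryCatch(
  {
    torch_bmm(input, w)
    FALSE
  },
  error = function(e) TRUE
)

# Add a leading batch dimension (1-based in R torch) and expand it,
# without copying, to the 10 matrices in `input`.
w_batched <- w$unsqueeze(1)$expand(c(10, 4, 5))
res_bmm <- torch_bmm(input, w_batched)

print(bmm_2d_failed)
print(res_bmm$shape)
print(torch_allclose(res_matmul, res_bmm, atol = 1e-6))
```

The small `atol` allows for rounding differences, because the two functions may use different kernels.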

bmm(input, mat2, out=NULL) -> Tensor

Performs a batch matrix-matrix product of matrices stored in input and mat2.

input and mat2 must be 3-D tensors, each containing the same number of matrices.
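The batch sizes are not reconciled automatically. The sketch below pairs a batch of 10 matrices with a batch of 8, using shapes that are arbitrary illustrations, and captures the resulting error. The exact wording of the message depends on the installed libtorch version, so the sketch only checks that an error occurred.

```r
library(torch)

# Two batches with different numbers of matrices: 10 versus 8.
a <- torch_randn(c(10, 3, 4))
b <- torch_randn(c(8, 4, 5))

# torch_bmm refuses to pair them up; keep the error message if one is raised.
msg <- tryCatch(
  {
    torch_bmm(a, b)
    NA_character_
  },
  error = function(e) conditionMessage(e)
)
print(msg)
```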

If input is a (b×n×m) tensor and mat2 is a (b×m×p) tensor, then out will be a (b×n×p) tensor.
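The shape rule can be checked directly. The sketch below picks illustrative sizes b = 2, n = 3, m = 4 and p = 5 and inspects the `$shape` of the result.

```r
library(torch)

# Illustrative sizes: b batches of (n x m) times (m x p).
b <- 2; n <- 3; m <- 4; p <- 5
input <- torch_randn(c(b, n, m))
mat2 <- torch_randn(c(b, m, p))

out <- torch_bmm(input, mat2)
print(out$shape)  # (b, n, p) = 2 3 5
```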

$$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$$
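The formula says that each output slice is an ordinary matrix product of the matching input slices. The sketch below verifies this for a small random batch. It multiplies every pair of slices separately with torch_mm and compares the products with the corresponding slices of the torch_bmm result. The batch size of 4 and the matrix shapes are arbitrary choices for the example.

```r
library(torch)

input <- torch_randn(c(4, 3, 2))
mat2 <- torch_randn(c(4, 2, 5))
res <- torch_bmm(input, mat2)

# Reference: multiply each pair of slices on its own with torch_mm.
# R torch indexing is 1-based, and input[i, , ] drops the batch dimension.
slices_match <- vapply(
  seq_len(4),
  function(i) torch_allclose(res[i, , ], torch_mm(input[i, , ], mat2[i, , ]), atol = 1e-6),
  logical(1)
)
print(all(slices_match))
```

A per-slice loop like this is useful only as a correctness reference. torch_bmm performs all the products in a single call.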

Examples

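if (torch_is_installed()) {

# A batch of 10 (3 x 4) matrices and a batch of 10 (4 x 5) matrices
input <- torch_randn(c(10, 3, 4))
mat2 <- torch_randn(c(10, 4, 5))
# The result is a batch of 10 (3 x 5) matrices
res <- torch_bmm(input, mat2)
res
}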
#> torch_tensor
#> (1,.,.) = 
#>  -2.8534  0.4345  0.4712 -0.4744  1.3620
#>   0.4181  1.9234 -3.9197  0.4809 -0.8014
#>  -0.5927 -2.9012  3.5623 -0.0814  0.2617
#> 
#> (2,.,.) = 
#>   1.9650  0.3391  3.0543  1.7688 -1.4351
#>  -0.8712  0.1365 -1.3379 -0.6271 -1.1187
#>  -1.8277  0.1141 -2.9571 -1.2596  3.1393
#> 
#> (3,.,.) = 
#>   1.3240  0.9878  1.4821 -1.6199  1.3402
#>   3.5044 -3.2130  1.6134 -1.1278  2.3436
#>   0.6480 -1.3304  0.1500  0.0666  0.4061
#> 
#> (4,.,.) = 
#>  -1.0794 -1.4471  1.4382  1.3145  0.6798
#>   2.2487 -0.0596 -4.6165  1.2463 -4.3305
#>   0.8640 -0.4978  0.5152  0.5171 -1.4735
#> 
#> (5,.,.) = 
#>   5.2092 -1.5554  6.7863  2.8892 -3.3742
#>   0.7760 -1.9957 -4.9847 -2.1459  0.0122
#>   1.1338  0.7593  2.0044  0.5594  0.5299
#> 
#> (6,.,.) = 
#>   0.7099 -0.1965  0.1813 -0.0309  0.5217
#>   0.4900 -0.2486  2.1104 -1.8801  0.3585
#>   1.4203 -0.9232  3.3366 -0.7276  1.5617
#> 
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]