Skip to contents

Flattens a contiguous range of dimensions of a tensor into a single dimension. For use with `nn_sequential()`.

Usage

nn_flatten(start_dim = 2, end_dim = -1)

Arguments

start_dim

First dimension to flatten (default: 2, so the first dimension is kept as is).

end_dim

Last dimension to flatten (default: -1, meaning the last dimension).

Shape

  • Input: $(*, S_{\text{start}}, \ldots, S_i, \ldots, S_{\text{end}}, *)$, where $S_i$ is the size at dimension $i$ and $*$ means any number of dimensions, including none.

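  • Output: $\left(*, \prod_{i=\text{start}}^{\text{end}} S_i, *\right)$, i.e. the sizes $S_{\text{start}} \times \cdots \times S_{\text{end}}$ are multiplied together into a single dimension.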

See also

Examples

if (torch_is_installed()) {
# input has shape (32, 1, 5, 5)
input <- torch_randn(32, 1, 5, 5)
# Defaults start_dim = 2, end_dim = -1: flatten every dimension after the first
m <- nn_flatten()
# Result has shape (32, 25), since 1 * 5 * 5 = 25
m(input)
}
#> torch_tensor
#> Columns 1 to 10-0.8015 -1.0448  0.4715 -0.0156  0.7085  0.6113 -0.5675 -0.7171 -0.2615  1.2159
#>  2.3857  1.2484 -0.5993  0.3149  1.0410 -0.0031 -0.1256 -1.4730 -0.4148  0.2949
#> -0.6647  2.6382 -1.9189  1.5107  0.1844 -0.4771 -0.1499  0.6239 -0.2575  0.1938
#>  2.0218 -0.1541  0.3580 -0.6831 -0.4747  1.5348  0.7168  0.3497 -0.4437  0.5037
#>  1.5272 -1.2138 -2.3443 -0.6669  0.8383  1.6540 -0.5365  0.3552  1.8786 -1.7464
#>  0.6394 -0.0746 -0.3072 -0.5264 -0.9450  1.0271  0.0991  0.0619  0.3204 -0.2456
#>  0.0250  1.0297 -1.4002  1.7605  1.5106 -2.4053 -0.9495 -0.1437  0.0332  1.5030
#> -0.5799  0.0074  1.2746  1.4900 -1.0154  0.3608 -0.8096 -0.9809  0.0256  2.0557
#>  0.3317  0.3536  0.8048  0.7115  0.3078  1.2502  1.4476 -0.4187 -1.0832 -0.8844
#> -1.0850 -0.6192  0.1627 -1.9400 -0.6659  1.2134  0.2179  0.5830 -0.6183  0.0321
#> -1.6692  0.7485 -0.8128  1.0693 -0.9317  0.9081 -0.7930 -0.7436 -0.6327  0.9231
#>  1.7045  0.1461  0.3835  1.6954 -0.1849 -0.2247  0.6355 -0.5568  0.3298  1.1491
#>  0.1119 -0.1293 -0.1382  1.6020 -0.7861  0.8450  1.0183 -0.0673  0.3050  0.9445
#> -1.8077 -0.1964 -0.1349  1.0182  0.7130 -0.4587 -2.3798  0.3252  1.3848 -0.7724
#>  0.2489  0.5840  0.1690 -0.9544  1.1001 -0.4978 -1.2585 -0.9969  0.0469  0.6774
#> -1.7650 -0.4240 -0.3131  2.0314  1.0105  1.0466  0.2259 -0.2248  1.6021  1.5515
#> -0.8189 -0.2861 -0.1885 -0.1888  1.2369 -0.8073  1.0088  0.8614  0.9339  0.4690
#> -0.8568  1.5507 -1.4375  1.7597  0.7363 -1.4567  2.2990  1.3478  1.4151  0.4128
#>  0.1707  0.9011 -1.6190  0.0704  0.8973 -1.6506  0.3608  0.0830  1.6077 -1.5901
#> -1.0287 -0.4341  0.0636 -0.4749  0.1226 -1.1832 -0.0746  0.9812 -1.4514 -0.0938
#>  0.0018  1.7226 -0.3270  1.1839 -0.4936 -0.8200  0.0203 -0.7399  0.3987  0.4709
#>  0.1172  0.9947  0.4600  1.3165 -1.3755  0.4174 -2.3616 -1.1633 -0.0096 -1.6175
#>  0.8800 -0.6760 -0.3397  1.4769 -0.1437  1.3507  0.3871 -0.4962 -1.1812  0.5646
#> -0.3333 -0.1028 -0.1121  0.2401 -0.9883  1.6939 -0.7209  1.0521 -0.3183 -0.0213
#> -2.6093 -1.1583 -0.8108  0.3671 -0.3561  1.3556  0.1852 -0.6264  1.0287  0.8352
#>  0.8922  1.2746  0.8763 -0.0620 -0.4186 -0.4048  0.1820 -2.4252 -0.5739 -0.7585
#>  1.6924 -0.4010 -0.0632 -0.6726  1.6817  0.9732 -0.9486 -1.2342 -1.1307  1.6650
#> -1.1404  1.6418  0.8507 -1.0957 -0.4790  1.2576  2.6430  1.4992  1.5525 -1.8223
#>  0.1169 -1.0268  1.0565 -0.2669 -0.2602  1.5338 -0.1799 -0.2940 -2.9724 -0.5963
#>  0.2384  1.6367  0.3955 -0.0629  1.3515 -0.0084  1.0178 -0.1979  0.3229  0.7691
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{32,25} ]