Skip to contents

Tensor creation utilities

Tensor attributes

Serialization

load_state_dict()
Load a state dict file
torch_load()
Loads a saved object
torch_save()
Saves an object to a disk file.
torch_serialize()
Serialize a torch object returning a raw object

Mathematical operations on tensors

torch_set_num_threads() torch_set_num_interop_threads() torch_get_num_interop_threads() torch_get_num_threads()
Number of threads
torch_abs()
Abs
torch_absolute()
Absolute
torch_acos()
Acos
torch_acosh()
Acosh
torch_adaptive_avg_pool1d()
Adaptive_avg_pool1d
torch_add()
Add
torch_addbmm()
Addbmm
torch_addcdiv()
Addcdiv
torch_addcmul()
Addcmul
torch_addmm()
Addmm
torch_addmv()
Addmv
torch_addr()
Addr
torch_allclose()
Allclose
torch_amax()
Amax
torch_amin()
Amin
torch_angle()
Angle
torch_arccos()
Arccos
torch_arccosh()
Arccosh
torch_arcsin()
Arcsin
torch_arcsinh()
Arcsinh
torch_arctan()
Arctan
torch_arctanh()
Arctanh
torch_argmax
Argmax
torch_argmin
Argmin
torch_argsort()
Argsort
torch_as_strided()
As_strided
torch_asin()
Asin
torch_asinh()
Asinh
torch_atan()
Atan
torch_atan2()
Atan2
torch_atanh()
Atanh
torch_atleast_1d()
Atleast_1d
torch_atleast_2d()
Atleast_2d
torch_atleast_3d()
Atleast_3d
torch_avg_pool1d()
Avg_pool1d
torch_baddbmm()
Baddbmm
torch_bartlett_window()
Bartlett_window
torch_bernoulli()
Bernoulli
torch_bincount
Bincount
torch_bitwise_and()
Bitwise_and
torch_bitwise_not()
Bitwise_not
torch_bitwise_or()
Bitwise_or
torch_bitwise_xor()
Bitwise_xor
torch_blackman_window()
Blackman_window
torch_block_diag()
Block_diag
torch_bmm()
Bmm
torch_broadcast_tensors()
Broadcast_tensors
torch_bucketize()
Bucketize
torch_can_cast()
Can_cast
torch_cartesian_prod()
Cartesian_prod
torch_cat()
Cat
torch_cdist()
Cdist
torch_ceil()
Ceil
torch_celu()
Celu
torch_celu_()
Celu_
torch_chain_matmul()
Chain_matmul
torch_channel_shuffle()
Channel_shuffle
torch_cholesky()
Cholesky
torch_cholesky_inverse()
Cholesky_inverse
torch_cholesky_solve()
Cholesky_solve
torch_chunk()
Chunk
torch_clamp()
Clamp
torch_clip()
Clip
torch_clone()
Clone
torch_combinations()
Combinations
torch_complex()
Complex
torch_conj()
Conj
torch_conv1d()
Conv1d
torch_conv2d()
Conv2d
torch_conv3d()
Conv3d
torch_conv_tbc()
Conv_tbc
torch_conv_transpose1d()
Conv_transpose1d
torch_conv_transpose2d()
Conv_transpose2d
torch_conv_transpose3d()
Conv_transpose3d
torch_cos()
Cos
torch_cosh()
Cosh
torch_cosine_similarity()
Cosine_similarity
torch_count_nonzero()
Count_nonzero
torch_cross()
Cross
torch_cummax()
Cummax
torch_cummin()
Cummin
torch_cumprod()
Cumprod
torch_cumsum()
Cumsum
torch_deg2rad()
Deg2rad
torch_dequantize()
Dequantize
torch_det()
Det
torch_device()
Create a Device object
torch_diag()
Diag
torch_diag_embed()
Diag_embed
torch_diagflat()
Diagflat
torch_diagonal()
Diagonal
torch_diff()
Computes the n-th forward difference along the given dimension.
torch_digamma()
Digamma
torch_dist()
Dist
torch_div()
Div
torch_divide()
Divide
torch_dot()
Dot
torch_dstack()
Dstack
torch_eig()
Eig
torch_einsum()
Einsum
torch_empty_strided()
Empty_strided
torch_eq()
Eq
torch_equal()
Equal
torch_erf()
Erf
torch_erfc()
Erfc
torch_erfinv()
Erfinv
torch_exp()
Exp
torch_exp2()
Exp2
torch_expm1()
Expm1
torch_fft_fft()
Fft
torch_fft_fftfreq()
fftfreq
torch_fft_ifft()
Ifft
torch_fft_irfft()
Irfft
torch_fft_rfft()
Rfft
torch_fix()
Fix
torch_flatten()
Flatten
torch_flip()
Flip
torch_fliplr()
Fliplr
torch_flipud()
Flipud
torch_floor()
Floor
torch_floor_divide()
Floor_divide
torch_fmod()
Fmod
torch_frac()
Frac
torch_gather()
Gather
torch_gcd()
Gcd
torch_ge()
Ge
torch_generator()
Create a Generator object
torch_geqrf()
Geqrf
torch_ger()
Ger
torch_greater()
Greater
torch_greater_equal()
Greater_equal
torch_gt()
Gt
torch_hamming_window()
Hamming_window
torch_hann_window()
Hann_window
torch_heaviside()
Heaviside
torch_histc()
Histc
torch_hstack()
Hstack
torch_hypot()
Hypot
torch_i0()
I0
torch_imag()
Imag
torch_index()
Index torch tensors
torch_index_put()
Modify values selected by indices.
torch_index_put_()
In-place version of torch_index_put.
torch_index_select()
Index_select
torch_install_path()
A simple exported version of install_path. Returns the torch installation path.
torch_inverse()
Inverse
torch_is_complex()
Is_complex
torch_is_floating_point()
Is_floating_point
torch_is_installed()
Verifies if torch is installed
torch_is_nonzero()
Is_nonzero
torch_isclose()
Isclose
torch_isfinite()
Isfinite
torch_isinf()
Isinf
torch_isnan()
Isnan
torch_isneginf()
Isneginf
torch_isposinf()
Isposinf
torch_isreal()
Isreal
torch_istft()
Istft
torch_kaiser_window()
Kaiser_window
torch_kron()
Kronecker product
torch_kthvalue()
Kthvalue
torch_strided() torch_sparse_coo()
Creates the corresponding layout
torch_lcm()
Lcm
torch_le()
Le
torch_lerp()
Lerp
torch_less()
Less
torch_less_equal()
Less_equal
torch_lgamma()
Lgamma
torch_log()
Log
torch_log10()
Log10
torch_log1p()
Log1p
torch_log2()
Log2
torch_logaddexp()
Logaddexp
torch_logaddexp2()
Logaddexp2
torch_logcumsumexp()
Logcumsumexp
torch_logdet()
Logdet
torch_logical_and()
Logical_and
torch_logical_not
Logical_not
torch_logical_or()
Logical_or
torch_logical_xor()
Logical_xor
torch_logit()
Logit
torch_logsumexp()
Logsumexp
torch_lstsq()
Lstsq
torch_lt()
Lt
torch_lu()
LU
torch_lu_solve()
Lu_solve
torch_lu_unpack()
Lu_unpack
torch_manual_seed()
Sets the seed for generating random numbers.
torch_masked_select()
Masked_select
torch_matmul()
Matmul
torch_matrix_exp()
Matrix_exp
torch_matrix_power()
Matrix_power
torch_matrix_rank()
Matrix_rank
torch_max
Max
torch_maximum()
Maximum
torch_mean()
Mean
torch_median()
Median
torch_contiguous_format() torch_preserve_format() torch_channels_last_format()
Memory format
torch_meshgrid()
Meshgrid
torch_min
Min
torch_minimum()
Minimum
torch_mm()
Mm
torch_mode()
Mode
torch_movedim()
Movedim
torch_mul()
Mul
torch_multinomial()
Multinomial
torch_multiply()
Multiply
torch_mv()
Mv
torch_mvlgamma()
Mvlgamma
torch_nanquantile()
Nanquantile
torch_nansum()
Nansum
torch_narrow()
Narrow
torch_ne()
Ne
torch_neg()
Neg
torch_negative()
Negative
torch_nextafter()
Nextafter
torch_nonzero()
Nonzero
torch_norm()
Norm
torch_normal()
Normal
torch_not_equal()
Not_equal
torch_orgqr()
Orgqr
torch_ormqr()
Ormqr
torch_outer()
Outer
torch_pdist()
Pdist
torch_pinverse()
Pinverse
torch_pixel_shuffle()
Pixel_shuffle
torch_poisson()
Poisson
torch_polar()
Polar
torch_polygamma()
Polygamma
torch_pow()
Pow
torch_prod()
Prod
torch_promote_types()
Promote_types
torch_qr()
Qr
torch_quantile()
Quantile
torch_quantize_per_channel()
Quantize_per_channel
torch_quantize_per_tensor()
Quantize_per_tensor
torch_rad2deg()
Rad2deg
torch_range()
Range
torch_real()
Real
torch_reciprocal()
Reciprocal
torch_relu()
Relu
torch_relu_()
Relu_
torch_remainder()
Remainder
torch_renorm()
Renorm
torch_repeat_interleave()
Repeat_interleave
torch_reshape()
Reshape
torch_result_type()
Result_type
torch_roll()
Roll
torch_rot90()
Rot90
torch_round()
Round
torch_rrelu_()
Rrelu_
torch_rsqrt()
Rsqrt
torch_scalar_tensor()
Scalar tensor
torch_searchsorted()
Searchsorted
torch_selu()
Selu
torch_selu_()
Selu_
torch_sgn()
Sgn
torch_sigmoid()
Sigmoid
torch_sign()
Sign
torch_signbit()
Signbit
torch_sin()
Sin
torch_sinh()
Sinh
torch_slogdet()
Slogdet
torch_sort
Sort
torch_sparse_coo_tensor()
Sparse_coo_tensor
torch_split()
Split
torch_sqrt()
Sqrt
torch_square()
Square
torch_squeeze()
Squeeze
torch_stack()
Stack
torch_std()
Std
torch_std_mean()
Std_mean
torch_stft()
Stft
torch_sub()
Sub
torch_subtract()
Subtract
torch_sum()
Sum
torch_svd()
Svd
torch_symeig()
Symeig
torch_t()
T
torch_take()
Take
torch_tan()
Tan
torch_tanh()
Tanh
torch_tensor()
Converts R objects to a torch tensor
torch_tensordot()
Tensordot
torch_threshold_()
Threshold_
torch_topk()
Topk
torch_trace()
Trace
torch_transpose()
Transpose
torch_trapz()
Trapz
torch_triangular_solve()
Triangular_solve
torch_tril()
Tril
torch_tril_indices()
Tril_indices
torch_triu()
Triu
torch_triu_indices()
Triu_indices
torch_true_divide()
True_divide
torch_trunc()
Trunc
torch_unbind()
Unbind
torch_unique_consecutive()
Unique_consecutive
torch_unsafe_chunk()
Unsafe_chunk
torch_unsafe_split()
Unsafe_split
torch_unsqueeze()
Unsqueeze
torch_vander()
Vander
torch_var()
Var
torch_var_mean()
Var_mean
torch_vdot()
Vdot
torch_view_as_complex()
View_as_complex
torch_view_as_real()
View_as_real
torch_vstack()
Vstack
torch_where()
Where
broadcast_all()
Given a list of values (possibly containing numbers), returns a list where each value is broadcasted based on the following rules:

Neural network modules

nn_adaptive_avg_pool1d()
Applies a 1D adaptive average pooling over an input signal composed of several input planes.
nn_adaptive_avg_pool2d()
Applies a 2D adaptive average pooling over an input signal composed of several input planes.
nn_adaptive_avg_pool3d()
Applies a 3D adaptive average pooling over an input signal composed of several input planes.
nn_adaptive_log_softmax_with_loss()
AdaptiveLogSoftmaxWithLoss module
nn_adaptive_max_pool1d()
Applies a 1D adaptive max pooling over an input signal composed of several input planes.
nn_adaptive_max_pool2d()
Applies a 2D adaptive max pooling over an input signal composed of several input planes.
nn_adaptive_max_pool3d()
Applies a 3D adaptive max pooling over an input signal composed of several input planes.
nn_avg_pool1d()
Applies a 1D average pooling over an input signal composed of several input planes.
nn_avg_pool2d()
Applies a 2D average pooling over an input signal composed of several input planes.
nn_avg_pool3d()
Applies a 3D average pooling over an input signal composed of several input planes.
nn_batch_norm1d()
BatchNorm1D module
nn_batch_norm2d()
BatchNorm2D
nn_batch_norm3d()
BatchNorm3D
nn_bce_loss()
Binary cross entropy loss
nn_bce_with_logits_loss()
BCE with logits loss
nn_bilinear()
Bilinear module
nn_buffer()
Creates an nn_buffer
nn_celu()
CELU module
nn_contrib_sparsemax()
Sparsemax activation
nn_conv1d()
Conv1D module
nn_conv2d()
Conv2D module
nn_conv3d()
Conv3D module
nn_conv_transpose1d()
ConvTranspose1D
nn_conv_transpose2d()
ConvTranspose2D module
nn_conv_transpose3d()
ConvTranspose3D module
nn_cosine_embedding_loss()
Cosine embedding loss
nn_cross_entropy_loss()
CrossEntropyLoss module
nn_ctc_loss()
The Connectionist Temporal Classification loss.
nn_dropout()
Dropout module
nn_dropout2d()
Dropout2D module
nn_dropout3d()
Dropout3D module
nn_elu()
ELU module
nn_embedding()
Embedding module
nn_embedding_bag()
Embedding bag module
nn_flatten()
Flattens a contiguous range of dims into a tensor.
nn_fractional_max_pool2d()
Applies a 2D fractional max pooling over an input signal composed of several input planes.
nn_fractional_max_pool3d()
Applies a 3D fractional max pooling over an input signal composed of several input planes.
nn_gelu()
GELU module
nn_glu()
GLU module
nn_group_norm()
Group normalization
nn_gru()
Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
nn_hardshrink()
Hardshrink module
nn_hardsigmoid()
Hardsigmoid module
nn_hardswish()
Hardswish module
nn_hardtanh()
Hardtanh module
nn_hinge_embedding_loss()
Hinge embedding loss
nn_identity()
Identity module
nn_init_calculate_gain()
Calculate gain
nn_init_constant_()
Constant initialization
nn_init_dirac_()
Dirac initialization
nn_init_eye_()
Eye initialization
nn_init_kaiming_normal_()
Kaiming normal initialization
nn_init_kaiming_uniform_()
Kaiming uniform initialization
nn_init_normal_()
Normal initialization
nn_init_ones_()
Ones initialization
nn_init_orthogonal_()
Orthogonal initialization
nn_init_sparse_()
Sparse initialization
nn_init_trunc_normal_()
Truncated normal initialization
nn_init_uniform_()
Uniform initialization
nn_init_xavier_normal_()
Xavier normal initialization
nn_init_xavier_uniform_()
Xavier uniform initialization
nn_init_zeros_()
Zeros initialization
nn_kl_div_loss()
Kullback-Leibler divergence loss
nn_l1_loss()
L1 loss
nn_layer_norm()
Layer normalization
nn_leaky_relu()
LeakyReLU module
nn_linear()
Linear module
nn_log_sigmoid()
LogSigmoid module
nn_log_softmax()
LogSoftmax module
nn_lp_pool1d()
Applies a 1D power-average pooling over an input signal composed of several input planes.
nn_lp_pool2d()
Applies a 2D power-average pooling over an input signal composed of several input planes.
nn_lstm()
Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
nn_margin_ranking_loss()
Margin ranking loss
nn_max_pool1d()
MaxPool1D module
nn_max_pool2d()
MaxPool2D module
nn_max_pool3d()
Applies a 3D max pooling over an input signal composed of several input planes.
nn_max_unpool1d()
Computes a partial inverse of MaxPool1d.
nn_max_unpool2d()
Computes a partial inverse of MaxPool2d.
nn_max_unpool3d()
Computes a partial inverse of MaxPool3d.
nn_module()
Base class for all neural network modules.
nn_module_list()
Holds submodules in a list.
nn_mse_loss()
MSE loss
nn_multi_margin_loss()
Multi margin loss
nn_multihead_attention()
MultiHead attention
nn_multilabel_margin_loss()
Multilabel margin loss
nn_multilabel_soft_margin_loss()
Multi label soft margin loss
nn_nll_loss()
Nll loss
nn_pairwise_distance()
Pairwise distance
nn_parameter()
Creates an nn_parameter
nn_poisson_nll_loss()
Poisson NLL loss
nn_prelu()
PReLU module
nn_prune_head()
Prune top layer(s) of a network
nn_relu()
ReLU module
nn_relu6()
ReLU6 module
nn_rnn()
RNN module
nn_rrelu()
RReLU module
nn_selu()
SELU module
nn_sequential()
A sequential container
nn_sigmoid()
Sigmoid module
nn_smooth_l1_loss()
Smooth L1 loss
nn_soft_margin_loss()
Soft margin loss
nn_softmax()
Softmax module
nn_softmax2d()
Softmax2d module
nn_softmin()
Softmin
nn_softplus()
Softplus module
nn_softshrink()
Softshrink module
nn_softsign()
Softsign module
nn_tanh()
Tanh module
nn_tanhshrink()
Tanhshrink module
nn_threshold()
Threshold module
nn_triplet_margin_loss()
Triplet margin loss
nn_triplet_margin_with_distance_loss()
Triplet margin with distance loss
nn_unflatten()
Unflattens a tensor dim expanding it to a desired shape. For use with nn_sequential().
nn_upsample()
Upsample module
nn_utils_clip_grad_norm_()
Clips gradient norm of an iterable of parameters.
nn_utils_clip_grad_value_()
Clips gradient of an iterable of parameters at specified value.
nn_utils_rnn_pack_padded_sequence()
Packs a Tensor containing padded sequences of variable length.
nn_utils_rnn_pack_sequence()
Packs a list of variable length Tensors
nn_utils_rnn_pad_packed_sequence()
Pads a packed batch of variable length sequences.
nn_utils_rnn_pad_sequence()
Pad a list of variable length Tensors with padding_value
is_nn_module()
Checks if the object is an nn_module
is_nn_parameter()
Checks if an object is a nn_parameter
is_nn_buffer()
Checks if the object is a nn_buffer

Neural networks functional module

nnf_adaptive_avg_pool1d()
Adaptive_avg_pool1d
nnf_adaptive_avg_pool2d()
Adaptive_avg_pool2d
nnf_adaptive_avg_pool3d()
Adaptive_avg_pool3d
nnf_adaptive_max_pool1d()
Adaptive_max_pool1d
nnf_adaptive_max_pool2d()
Adaptive_max_pool2d
nnf_adaptive_max_pool3d()
Adaptive_max_pool3d
nnf_affine_grid()
Affine_grid
nnf_alpha_dropout()
Alpha_dropout
nnf_avg_pool1d()
Avg_pool1d
nnf_avg_pool2d()
Avg_pool2d
nnf_avg_pool3d()
Avg_pool3d
nnf_batch_norm()
Batch_norm
nnf_bilinear()
Bilinear
nnf_binary_cross_entropy()
Binary_cross_entropy
nnf_binary_cross_entropy_with_logits()
Binary_cross_entropy_with_logits
nnf_celu() nnf_celu_()
Celu
nnf_contrib_sparsemax()
Sparsemax
nnf_conv1d()
Conv1d
nnf_conv2d()
Conv2d
nnf_conv3d()
Conv3d
nnf_conv_tbc()
Conv_tbc
nnf_conv_transpose1d()
Conv_transpose1d
nnf_conv_transpose2d()
Conv_transpose2d
nnf_conv_transpose3d()
Conv_transpose3d
nnf_cosine_embedding_loss()
Cosine_embedding_loss
nnf_cosine_similarity()
Cosine_similarity
nnf_cross_entropy()
Cross_entropy
nnf_ctc_loss()
Ctc_loss
nnf_dropout()
Dropout
nnf_dropout2d()
Dropout2d
nnf_dropout3d()
Dropout3d
nnf_elu() nnf_elu_()
Elu
nnf_embedding()
Embedding
nnf_embedding_bag()
Embedding_bag
nnf_fold()
Fold
nnf_fractional_max_pool2d()
Fractional_max_pool2d
nnf_fractional_max_pool3d()
Fractional_max_pool3d
nnf_gelu()
Gelu
nnf_glu()
Glu
nnf_grid_sample()
Grid_sample
nnf_group_norm()
Group_norm
nnf_gumbel_softmax()
Gumbel_softmax
nnf_hardshrink()
Hardshrink
nnf_hardsigmoid()
Hardsigmoid
nnf_hardswish()
Hardswish
nnf_hardtanh() nnf_hardtanh_()
Hardtanh
nnf_hinge_embedding_loss()
Hinge_embedding_loss
nnf_instance_norm()
Instance_norm
nnf_interpolate()
Interpolate
nnf_kl_div()
Kl_div
nnf_l1_loss()
L1_loss
nnf_layer_norm()
Layer_norm
nnf_leaky_relu()
Leaky_relu
nnf_linear()
Linear
nnf_local_response_norm()
Local_response_norm
nnf_log_softmax()
Log_softmax
nnf_logsigmoid()
Logsigmoid
nnf_lp_pool1d()
Lp_pool1d
nnf_lp_pool2d()
Lp_pool2d
nnf_margin_ranking_loss()
Margin_ranking_loss
nnf_max_pool1d()
Max_pool1d
nnf_max_pool2d()
Max_pool2d
nnf_max_pool3d()
Max_pool3d
nnf_max_unpool1d()
Max_unpool1d
nnf_max_unpool2d()
Max_unpool2d
nnf_max_unpool3d()
Max_unpool3d
nnf_mse_loss()
Mse_loss
nnf_multi_head_attention_forward()
Multi head attention forward
nnf_multi_margin_loss()
Multi_margin_loss
nnf_multilabel_margin_loss()
Multilabel_margin_loss
nnf_multilabel_soft_margin_loss()
Multilabel_soft_margin_loss
nnf_nll_loss()
Nll_loss
nnf_normalize()
Normalize
nnf_one_hot()
One_hot
nnf_pad()
Pad
nnf_pairwise_distance()
Pairwise_distance
nnf_pdist()
Pdist
nnf_pixel_shuffle()
Pixel_shuffle
nnf_poisson_nll_loss()
Poisson_nll_loss
nnf_prelu()
Prelu
nnf_relu() nnf_relu_()
Relu
nnf_relu6()
Relu6
nnf_rrelu() nnf_rrelu_()
Rrelu
nnf_selu() nnf_selu_()
Selu
nnf_sigmoid()
Sigmoid
nnf_smooth_l1_loss()
Smooth_l1_loss
nnf_soft_margin_loss()
Soft_margin_loss
nnf_softmax()
Softmax
nnf_softmin()
Softmin
nnf_softplus()
Softplus
nnf_softshrink()
Softshrink
nnf_softsign()
Softsign
nnf_tanhshrink()
Tanhshrink
nnf_threshold() nnf_threshold_()
Threshold
nnf_triplet_margin_loss()
Triplet_margin_loss
nnf_triplet_margin_with_distance_loss()
Triplet margin with distance loss
nnf_unfold()
Unfold

Optimizers

optimizer()
Creates a custom optimizer
optim_adadelta()
Adadelta optimizer
optim_adagrad()
Adagrad optimizer
optim_adam()
Implements Adam algorithm.
optim_asgd()
Averaged Stochastic Gradient Descent optimizer
optim_lbfgs()
LBFGS optimizer
optim_required()
Dummy value indicating a required value.
optim_rmsprop()
RMSprop optimizer
optim_rprop()
Implements the resilient backpropagation algorithm.
optim_sgd()
SGD optimizer
is_optimizer()
Checks if the object is a torch optimizer

Learning rate schedulers

lr_lambda()
Sets the learning rate of each parameter group to the initial lr times a given function. When last_epoch=-1, sets initial lr as lr.
lr_multiplicative()
Multiply the learning rate of each parameter group by the factor given in the specified function. When last_epoch=-1, sets initial lr as lr.
lr_one_cycle()
One cycle learning rate
lr_reduce_on_plateau()
Reduce learning rate on plateau
lr_scheduler()
Creates learning rate schedulers
lr_step()
Step learning rate decay

Datasets

dataset()
Helper function to create a function that generates R6 instances of class dataset
dataset_subset()
Dataset Subset
dataloader()
Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.
dataloader_make_iter()
Creates an iterator from a DataLoader
dataloader_next()
Get the next element of a dataloader iterator
enumerate()
Enumerate an iterator
enumerate(<dataloader>)
Enumerate an iterator
tensor_dataset()
Dataset wrapping tensors.
is_dataloader()
Checks if the object is a dataloader
sampler()
Creates a new Sampler

Distributions

Distribution
Generic R6 class representing distributions
distr_bernoulli()
Creates a Bernoulli distribution parameterized by probs or logits (but not both). Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
distr_categorical()
Creates a categorical distribution parameterized by either probs or logits (but not both).
distr_chi2()
Creates a Chi2 distribution parameterized by shape parameter df. This is exactly equivalent to distr_gamma(alpha=0.5*df, beta=0.5)
distr_gamma()
Creates a Gamma distribution parameterized by shape concentration and rate.
distr_mixture_same_family()
Mixture of components in the same family
distr_multivariate_normal()
Multivariate normal (Gaussian) distribution
distr_normal()
Creates a normal (also called Gaussian) distribution parameterized by loc and scale.
distr_poisson()
Creates a Poisson distribution parameterized by rate, the rate parameter.
Constraint
Abstract base class for constraints.

Autograd

autograd_backward()
Computes the sum of gradients of given tensors w.r.t. graph leaves.
autograd_function()
Records operation history and defines formulas for differentiating ops.
autograd_grad()
Computes and returns the sum of gradients of outputs w.r.t. the inputs.
autograd_set_grad_mode()
Set grad mode
with_no_grad()
Temporarily modify gradient recording.
with_enable_grad()
Enable grad
with_detect_anomaly()
Context-manager that enables anomaly detection for the autograd engine.
AutogradContext
Class representing the context.

Linear Algebra

linalg_cholesky()
Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
linalg_cholesky_ex()
Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
linalg_cond()
Computes the condition number of a matrix with respect to a matrix norm.
linalg_det()
Computes the determinant of a square matrix.
linalg_eig()
Computes the eigenvalue decomposition of a square matrix if it exists.
linalg_eigh()
Computes the eigenvalue decomposition of a complex Hermitian or real symmetric matrix.
linalg_eigvals()
Computes the eigenvalues of a square matrix.
linalg_eigvalsh()
Computes the eigenvalues of a complex Hermitian or real symmetric matrix.
linalg_householder_product()
Computes the first n columns of a product of Householder matrices.
linalg_inv()
Computes the inverse of a square matrix if it exists.
linalg_inv_ex()
Computes the inverse of a square matrix if it is invertible.
linalg_lstsq()
Computes a solution to the least squares problem of a system of linear equations.
linalg_matrix_norm()
Computes a matrix norm.
linalg_matrix_power()
Computes the n-th power of a square matrix for an integer n.
linalg_matrix_rank()
Computes the numerical rank of a matrix.
linalg_multi_dot()
Efficiently multiplies two or more matrices
linalg_norm()
Computes a vector or matrix norm.
linalg_pinv()
Computes the pseudoinverse (Moore-Penrose inverse) of a matrix.
linalg_qr()
Computes the QR decomposition of a matrix.
linalg_slogdet()
Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.
linalg_solve()
Computes the solution of a square system of linear equations with a unique solution.
linalg_svd()
Computes the singular value decomposition (SVD) of a matrix.
linalg_svdvals()
Computes the singular values of a matrix.
linalg_tensorinv()
Computes the multiplicative inverse of torch_tensordot()
linalg_tensorsolve()
Computes the solution X to the system torch_tensordot(A, X) = B.
linalg_vector_norm()
Computes a vector norm.

Cuda utilities

cuda_current_device()
Returns the index of a currently selected device.
cuda_device_count()
Returns the number of GPUs available.
cuda_empty_cache()
Empty cache
cuda_get_device_capability()
Returns the major and minor CUDA capability of device
cuda_is_available()
Returns a bool indicating if CUDA is currently available.
cuda_memory_stats() cuda_memory_summary()
Returns a dictionary of CUDA memory allocator statistics for a given device.
cuda_runtime_version()
Returns the CUDA runtime version
cuda_synchronize()
Waits for all kernels in all streams on a CUDA device to complete.

JIT

jit_compile()
Compile TorchScript code into a graph
jit_load()
Loads a script_function or script_module previously saved with jit_save
jit_save()
Saves a script_function to a path
jit_save_for_mobile()
Saves a script_function or script_module in bytecode form, to be loaded on a mobile device
jit_scalar()
Adds the 'jit_scalar' class to the input
jit_trace()
Trace a function and return an executable script_function.
jit_trace_module()
Trace a module
jit_tuple()
Adds the 'jit_tuple' class to the input

Backends

backends_cudnn_is_available()
CuDNN is available
backends_cudnn_version()
CuDNN version
backends_mkl_is_available()
MKL is available
backends_mkldnn_is_available()
MKLDNN is available
backends_mps_is_available()
MPS is available
backends_openmp_is_available()
OpenMP is available

Installation

install_torch()
Install Torch
get_install_libs_url() install_torch_from_file()
Install Torch from files

Contrib

contrib_sort_vertices()
Contrib sort vertices