
Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation. When a module is passed to jit_trace(), only the forward method is run and traced. With jit_trace_module(), you can choose which methods to trace and supply example inputs for each one by method name (see the ... argument below).

Usage

jit_trace_module(mod, ..., strict = TRUE)

Arguments

mod

A torch nn_module() containing the methods named by the arguments passed in .... All of these methods will be compiled as part of a single ScriptModule.

...

Named arguments giving sample inputs for the methods of mod to trace. Each argument name is a method name in mod, and each value is a list of example inputs that is passed to that method while tracing, e.g. forward = list(example_forward_input), method2 = list(example_method2_input).

strict

Whether to run the tracer in strict mode (default: TRUE). Set this to FALSE only when you want the tracer to record mutable container types (currently list and dict) and you are sure that the containers used in your model have a constant structure and are not used in control-flow conditions (if, for).
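The main reason to use jit_trace_module() instead of jit_trace() is tracing more than one method. The sketch below is illustrative: the module and its extra method predict_proba are hypothetical, and it assumes that the traced ScriptModule exposes a traced non-forward method through $, the same way the original module does.

```r
library(torch)

# Hypothetical module with an extra method besides forward
net <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$linear(x)
  },
  # Hypothetical extra method: squashes the linear output into (0, 1)
  predict_proba = function(x) {
    torch_sigmoid(self$linear(x))
  }
)

model <- net()

# One named argument per method to trace; each value is a list of example inputs
traced <- jit_trace_module(
  model,
  forward = list(torch_randn(10, 10)),
  predict_proba = list(torch_randn(10, 10))
)

x <- torch_randn(10, 10)
# Calling the traced module directly runs the traced forward method
forward_ok <- torch_allclose(model(x), traced(x))
# Assumption: the traced extra method is reachable with `$`
proba_ok <- torch_allclose(model$predict_proba(x), traced$predict_proba(x))
```

Both methods are compiled into the same ScriptModule, so they share the module's parameters.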

Details

See jit_trace for more information on tracing.
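When only forward needs tracing, jit_trace() and jit_trace_module() should produce equivalent modules. The minimal sketch below assumes, as the jit_trace() reference describes, that passing an nn_module() to jit_trace() traces its forward method via jit_trace_module(); the variable names are illustrative.

```r
library(torch)

linear <- nn_linear(10, 1)
example_input <- torch_randn(10, 10)

# Trace forward only: jit_trace() also accepts an nn_module()
tr_a <- jit_trace(linear, example_input)
# Trace the same method explicitly, naming it in jit_trace_module()
tr_b <- jit_trace_module(linear, forward = list(example_input))

# Both traced modules should agree on a fresh input
x <- torch_randn(10, 10)
same_output <- torch_allclose(tr_a(x), tr_b(x))
```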

Examples

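library(torch)

if (torch_is_installed()) {
  # A single linear layer: 10 input features -> 1 output
  linear <- nn_linear(10, 1)
  # Trace only the forward method, using a 10 x 10 example input
  tr_linear <- jit_trace_module(linear, forward = list(torch_randn(10, 10)))

  # The traced module should reproduce the original module's output
  x <- torch_randn(10, 10)
  torch_allclose(linear(x), tr_linear(x))
}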
#> [1] TRUE