Note: This is an R port of the official PyTorch tutorial available here. All credit goes to Justin Johnson.

As an example of dynamic graphs and weight sharing, we implement a very strange model: a fully-connected ReLU network that on each forward pass chooses a random number between 1 and 4 and uses that many hidden layers, reusing the same weights multiple times to compute the innermost hidden layers.

For this model we can use normal R flow control to implement the loop, and we can implement weight sharing among the innermost layers by simply reusing the same Module multiple times when defining the forward pass.
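
To make the weight-sharing point concrete before defining the full model, here is a minimal standalone sketch: applying the same nn_linear twice means both applications use the same weight tensor, and during the backward pass that single tensor accumulates gradient contributions from both uses.

library(torch)

shared <- nn_linear(4, 4)
x <- torch_randn(2, 4)

# The same Module (and therefore the same weight tensor) is applied twice.
out <- shared(shared(x))
out$sum()$backward()

# There is only one weight tensor, and its gradient combines both uses.
shared$weight$grad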

We can easily implement this model using nn_module:

dynamic_net <- nn_module(
   "dynamic_net",
   # In the constructor we construct three nn_linear instances that we will use
   # in the forward pass.
   initialize = function(D_in, H, D_out) {
      self$input_linear <- nn_linear(D_in, H)
      self$middle_linear <- nn_linear(H, H)
      self$output_linear <- nn_linear(H, D_out)
   },
   # For the forward pass of the model, we randomly choose 1, 2, 3, or 4 and
   # reuse the middle_linear Module that many times to compute hidden layer
   # representations.
   # 
   # Since each forward pass builds a dynamic computation graph, we can use normal
   # R control-flow operators like loops or conditional statements when
   # defining the forward pass of the model.
   # 
   # Here we also see that it is perfectly safe to reuse the same Module many
   # times when defining a computational graph. This is a big improvement over
   # Lua Torch, where each Module could be used only once.
   forward = function(x) {
      h_relu <- self$input_linear(x)$clamp(min = 0)
      for (i in seq_len(sample.int(4, size = 1))) {
         h_relu <- self$middle_linear(h_relu)$clamp(min = 0)
      }
      y_pred <- self$output_linear(h_relu)
      y_pred
   }
)
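
As a quick sanity check, we can instantiate the module with small, arbitrarily chosen dimensions and confirm that the output shape depends only on the input batch size and D_out, no matter how many times middle_linear happened to be reused on that particular forward pass:

net <- dynamic_net(8, 4, 2)   # D_in = 8, H = 4, D_out = 2
net(torch_randn(3, 8))$shape  # 3 x 2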


if (cuda_is_available()) {
   device <- torch_device("cuda")
} else {
   device <- torch_device("cpu")
}
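
Tensors (and, as we will see below, modules) can be moved to the chosen device with their $to() method; a quick sketch:

t_cpu <- torch_randn(2, 2)
t_dev <- t_cpu$to(device = device)
t_dev$device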
   
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N <- 64
D_in <- 1000
H <- 100
D_out <- 10

# Create random input and output data
# Setting requires_grad=FALSE (the default) indicates that we do not need to 
# compute gradients with respect to these Tensors during the backward pass.
x <- torch_randn(N, D_in, device = device)
y <- torch_randn(N, D_out, device = device)

# Construct our model by instantiating the class defined above, and move its
# parameters to the same device as the data.
model <- dynamic_net(D_in, H, D_out)
model$to(device = device)

# torch also provides popular loss functions as nnf_* functions; in this
# case we will use mean squared error (MSE) as our loss function.
loss_fn <- nnf_mse_loss
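
For reference, with its default reduction nnf_mse_loss is simply the mean of the squared element-wise differences, so the following two expressions compute the same value:

a <- torch_randn(4, 3)
b <- torch_randn(4, 3)
nnf_mse_loss(a, b)
torch_mean((a - b)^2)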

# Use one of torch's optim_* constructors to create an optimizer that will
# update the weights of the model for us. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use SGD with momentum;
# torch provides many other optimization algorithms as well. The first
# argument to optim_sgd tells the optimizer which tensors it should update.
learning_rate <- 1e-4
optimizer <- optim_sgd(model$parameters, lr = learning_rate, momentum = 0.9)
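
As a rough illustration of what a single optimizer step does (ignoring the momentum buffer that optim_sgd maintains for us), plain gradient descent subtracts the learning rate times the gradient from each parameter in place:

# Standalone sketch: one hand-written gradient-descent update on a tiny layer.
lin <- nn_linear(3, 1)
lin(torch_randn(4, 3))$sum()$backward()
with_no_grad({
   for (p in lin$parameters) {
      p$sub_(1e-4 * p$grad)
   }
})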

for (t in seq_len(500)) {
   # Forward pass: compute predicted y by passing x to the model. Module objects
   # can be called like functions. When doing so you pass a Tensor of input
   # data to the Module and it produces a Tensor of output data.
   y_pred <- model(x)
   
   # Compute and print loss. We pass Tensors containing the predicted and true
   # values of y, and the loss function returns a Tensor containing the
   # loss.
   loss <- loss_fn(y_pred, y)
   if (t %% 100 == 0 || t == 1)
      cat("Step:", t, ":", as.numeric(loss), "\n")
   
   # Before the backward pass, use the optimizer object to zero all of the
   # gradients for the variables it will update (which are the learnable
   # weights of the model). This is because by default, gradients are
   # accumulated in buffers (i.e., not overwritten) whenever $backward()
   # is called. Check out the docs of `autograd_backward` for more details.
   optimizer$zero_grad()

   # Backward pass: compute gradient of the loss with respect to model
   # parameters
   loss$backward()

   # Calling the step function on an Optimizer makes an update to its
   # parameters
   optimizer$step()
}
#> Step: 1 : 0.984449 
#> Step: 100 : 0.9808377 
#> Step: 200 : 0.9813455 
#> Step: 300 : 0.9803184 
#> Step: 400 : 0.9797473 
#> Step: 500 : 0.9786749
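
After training, predictions can be computed without building an autograd graph by wrapping the forward pass in with_no_grad(). Note that because forward() samples a random depth, even these inference passes go through a randomly chosen number of hidden layers:

with_no_grad({
   x_new <- torch_randn(5, D_in, device = device)
   y_new <- model(x_new)
})
y_new$shape   # 5 x D_out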