Transformations

Rune provides functional transformations that operate on Nx tensor functions. This guide covers every transformation available.

Reverse-Mode AD

Reverse-mode AD (backpropagation) is efficient when you have many inputs and a scalar output — the typical case in machine learning.

grad

grad f returns a function that computes the gradient of scalar-valued f.

open Nx
open Rune

let () =
  let f x = sum (mul x x) in
  let df = grad f in
  let x = create Float32 [|3|] [|1.; 2.; 3.|] in
  print_data (df x)
  (* gradient: [2. 4. 6.] *)

grads

grads differentiates with respect to multiple inputs:

open Nx
open Rune

let () =
  let f inputs =
    match inputs with
    | [x; y] -> sum (add (mul x x) (mul y y))
    | _ -> assert false
  in
  let gs = grads f [scalar Float32 3.0; scalar Float32 4.0] in
  List.iter (fun g -> Printf.printf "%.1f " (item [] g)) gs
  (* 6.0 8.0 *)

value_and_grad

Computes both the function value and gradient in a single forward-backward pass, avoiding redundant computation:

let loss, gradient = value_and_grad loss_fn params

value_and_grads does the same for multiple inputs.

value_and_grad_aux

When your function returns auxiliary data alongside the loss (e.g., predictions, metrics), use the _aux variants to carry it through without differentiating it:

let f x =
  let pred = forward_pass x in
  let loss = compute_loss pred in
  (loss, pred)  (* pred is auxiliary — not differentiated *)

let loss, gradient, pred = value_and_grad_aux f x

value_and_grads_aux does the same for multiple inputs.

vjp

Vector-Jacobian product. Unlike grad, the function does not need to return a scalar — you provide a cotangent vector:

open Nx
open Rune

let () =
  let f x = mul x x in
  let x = create Float32 [|3|] [|1.; 2.; 3.|] in
  let v = ones Float32 [|3|] in
  let y, g = vjp f x v in
  print_data y;  (* [1. 4. 9.] *)
  print_data g   (* [2. 4. 6.] *)

vjps handles multiple inputs.

Forward-Mode AD

Forward-mode AD propagates tangent vectors alongside primal values. It is efficient when the number of inputs is small relative to the number of outputs.

jvp

Jacobian-vector product. Provide a tangent vector with the same shape as the input:

open Nx
open Rune

let () =
  let f x = mul x x in
  let x = create Float32 [|3|] [|1.; 2.; 3.|] in
  let v = ones Float32 [|3|] in
  let y, tangent = jvp f x v in
  print_data y;       (* [1. 4. 9.] — primal *)
  print_data tangent  (* [2. 4. 6.] — directional derivative *)

jvps handles multiple inputs. jvp_aux carries auxiliary outputs.

Choosing Between Forward and Reverse Mode

  • Reverse mode (grad, vjp): One backward pass gives gradients for all inputs. Best when outputs << inputs (typical in ML: scalar loss, many parameters).
  • Forward mode (jvp): One forward pass gives one directional derivative. Best when inputs << outputs (e.g., sensitivity analysis with few parameters).

Stopping Gradients

no_grad

Evaluate a computation without recording it for differentiation:

let baseline = no_grad (fun () ->
  mean predictions
)

Everything computed inside no_grad is treated as a constant by enclosing gradient computations.

detach

Make a single tensor a constant:

let target = detach current_value
(* target has the same values but is not differentiated *)

Vectorising Map

vmap

vmap transforms a function that operates on single examples into one that operates on batches:

(* Function that works on a single vector *)
let f x = sum (mul x x)

(* Automatically batched: maps over axis 0 of the input *)
let f_batched = vmap f

(* Process a batch of 10 vectors at once *)
let batch = rand Float32 [|10; 5|] in
let results = f_batched batch
(* results has shape [|10|] — one scalar per example *)

By default, vmap maps over axis 0 of inputs and stacks outputs on axis 0. You can customize this:

(* Map over axis 1 instead *)
let f_axis1 = vmap ~in_axes:(Single (Map 1)) f

(* Don't map an input (broadcast it) *)
let f_shared = vmap ~in_axes:(Single NoMap) f

vmaps handles functions with multiple inputs, with per-input axis specifications.

Composing vmap with grad

Since transformations are composable, you can compute per-example gradients:

(* Per-example gradient (no manual batching needed) *)
let per_example_grad = vmap (grad loss_fn)

Gradient Checking

Rune provides utilities for verifying that autodiff gradients are correct by comparing them against finite-difference approximations.

finite_diff

Approximate the gradient using finite differences:

open Nx
open Rune

let () =
  let f x = sum (mul x x) in
  let x = create Float32 [|3|] [|1.; 2.; 3.|] in
  let fd_grad = finite_diff f x in
  let ad_grad = grad f x in
  print_data fd_grad;
  print_data ad_grad
  (* both approximately [2. 4. 6.] *)

The default method is central differences ((f(x+h) - f(x-h)) / 2h). You can choose Forward or Backward methods and adjust eps (default 1e-4).

check_gradient

Automated comparison of autodiff vs finite-difference gradients:

match check_gradient ~verbose:true my_function x with
| `Pass result -> Printf.printf "max error: %e\n" result.max_abs_error
| `Fail result ->
  Printf.printf "%d of %d elements failed\n"
    result.num_failed result.num_checked

check_gradients handles functions with multiple inputs.

Debugging

debug

Print every tensor operation as it executes:

let () =
  let f x = add (mul x x) (sin x) in
  let x = scalar Float32 2.0 in
  let _ = debug f x in
  ()
(* Prints each operation, its inputs, and its output *)

This is useful for understanding what operations a function performs, especially when debugging unexpected gradients.

Summary

Transform Purpose When to use
grad Gradient of scalar function Training loss → parameter gradients
value_and_grad Value + gradient together Avoid duplicate forward pass
vjp Vector-Jacobian product Non-scalar outputs
jvp Jacobian-vector product Few inputs, many outputs
vmap Vectorise over a batch dimension Per-example computation
no_grad / detach Stop gradient propagation Baselines, targets, constants
check_gradient Verify gradient correctness Testing custom operations
debug Trace all operations Understanding/debugging