Module Rune
Functional transformations for Nx tensors.
Rune provides automatic differentiation (forward and reverse mode), vectorising maps, and gradient checking. It operates by intercepting Nx tensor operations via OCaml 5 effect handlers — no special tensor type is needed.
Terminology.
- Primal: the input value at which a derivative is evaluated.
- Tangent: the directional derivative seed (forward mode).
- Cotangent: the adjoint seed propagated backward (reverse mode).
- JVP: Jacobian-vector product (forward-mode AD).
- VJP: vector-Jacobian product (reverse-mode AD).
Reverse-mode AD
Compute gradients of scalar-valued functions via reverse-mode AD (backpropagation). The function f must return a scalar tensor; the gradient has the same shape as the input.
grad f x is the gradient of scalar-valued f at x.
Equivalent to snd (value_and_grad f x).
See also grads, value_and_grad.
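A minimal sketch of grad on a quadratic, assuming the usual Raven-style Nx helpers (Nx.create, Nx.mul, Nx.sum — exact names may differ in your Nx version):

```ocaml
(* Gradient of f(x) = sum(x * x); the analytic gradient is 2x. *)
let f x = Nx.sum (Nx.mul x x)

let () =
  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |] in
  let g = Rune.grad f x in
  (* g has the same shape as x: the gradient of sum(x*x) is 2x,
     i.e. [| 2.; 4.; 6. |] *)
  ignore g
```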
grads f xs is the list of gradients of scalar-valued f with respect to each tensor in xs. The i-th element of the result has the same shape as the i-th element of xs.
See also grad, value_and_grads.
val value_and_grad :
(('a, 'b) Nx.t -> ('c, 'd) Nx.t) ->
('a, 'b) Nx.t ->
('c, 'd) Nx.t * ('a, 'b) Nx.t

value_and_grad f x is (f x, grad f x), computed in a single forward-backward pass.
See also value_and_grad_aux.
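A sketch of getting the loss and its gradient in one pass, avoiding a second evaluation of f (assumes the same hedged Nx helpers as above):

```ocaml
(* One forward-backward pass returns both the value and the gradient. *)
let loss x = Nx.sum (Nx.mul x x)

let () =
  let x = Nx.create Nx.float32 [| 2 |] [| 3.; 4. |] in
  let y, g = Rune.value_and_grad loss x in
  (* y is the scalar 25. and g is [| 6.; 8. |] *)
  ignore (y, g)
```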
val value_and_grad_aux :
(('a, 'b) Nx.t -> ('c, 'd) Nx.t * 'e) ->
('a, 'b) Nx.t ->
('c, 'd) Nx.t * ('a, 'b) Nx.t * 'e

value_and_grad_aux f x is (y, g, aux) where (y, aux) = f x and g is the gradient of y with respect to x. The auxiliary output aux is carried through but not differentiated.
See also value_and_grads_aux.
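A common use of the aux channel is returning intermediate values (here, the raw predictions) alongside a scalar loss. A hedged sketch, again assuming Nx.create / Nx.mul / Nx.sum:

```ocaml
(* Carry a non-differentiated auxiliary value out of the loss function. *)
let loss_with_preds x =
  let preds = Nx.mul x x in
  (Nx.sum preds, preds)  (* aux = preds, carried through untouched *)

let () =
  let x = Nx.create Nx.float32 [| 2 |] [| 1.; 2. |] in
  let y, g, preds = Rune.value_and_grad_aux loss_with_preds x in
  ignore (y, g, preds)
```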
val value_and_grads :
(('a, 'b) Nx.t list -> ('c, 'd) Nx.t) ->
('a, 'b) Nx.t list ->
('c, 'd) Nx.t * ('a, 'b) Nx.t list

value_and_grads f xs is (f xs, grads f xs), computed in a single forward-backward pass.
See also value_and_grads_aux.
val value_and_grads_aux :
(('a, 'b) Nx.t list -> ('c, 'd) Nx.t * 'e) ->
('a, 'b) Nx.t list ->
('c, 'd) Nx.t * ('a, 'b) Nx.t list * 'e

value_and_grads_aux f xs is (y, gs, aux) where (y, aux) = f xs and gs is the list of gradients of y with respect to each tensor in xs. The auxiliary output aux is carried through but not differentiated.
See also value_and_grad_aux.
val vjps :
(('a, 'b) Nx.t list -> ('c, 'd) Nx.t) ->
('a, 'b) Nx.t list ->
('c, 'd) Nx.t ->
('c, 'd) Nx.t * ('a, 'b) Nx.t list

vjps f xs v is like vjp for functions with multiple inputs. Returns (y, gs) where each gradient in gs corresponds to one input in xs.
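A sketch of pulling a cotangent seed back through a two-input function. For f (x, w) = sum(x * w), the VJP with seed v yields v-scaled gradients w and x (assumes the hedged Nx helpers Nx.create / Nx.mul / Nx.sum / Nx.scalar):

```ocaml
(* Reverse mode with an explicit cotangent seed. *)
let f = function
  | [ x; w ] -> Nx.sum (Nx.mul x w)
  | _ -> invalid_arg "f expects exactly two inputs"

let () =
  let x = Nx.create Nx.float32 [| 2 |] [| 1.; 2. |] in
  let w = Nx.create Nx.float32 [| 2 |] [| 3.; 4. |] in
  let v = Nx.scalar Nx.float32 1.0 in  (* cotangent seed for the scalar output *)
  let y, gs = Rune.vjps f [ x; w ] v in
  (* gs = [ dy/dx = w; dy/dw = x ], each scaled by v *)
  ignore (y, gs)
```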
Forward-mode AD
Compute Jacobian-vector products by propagating tangent vectors alongside primal values. Forward mode is efficient when the number of inputs is small relative to the number of outputs.
val jvp_aux :
(('a, 'b) Nx.t -> ('c, 'd) Nx.t * 'e) ->
('a, 'b) Nx.t ->
('a, 'b) Nx.t ->
('c, 'd) Nx.t * ('c, 'd) Nx.t * 'e

jvp_aux f x v is like jvp but for functions with auxiliary output. Returns (y, t, aux) where aux is carried through but not differentiated.
val jvps :
(('a, 'b) Nx.t list -> ('c, 'd) Nx.t) ->
('a, 'b) Nx.t list ->
('a, 'b) Nx.t list ->
('c, 'd) Nx.t * ('c, 'd) Nx.t

jvps f xs vs is like jvp for functions with multiple inputs. Each tangent in vs must have the same shape as the corresponding primal in xs.
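A sketch of pushing tangents through a two-input function. For f (x, w) = sum(x * w), the JVP along (vx, vw) is sum(vx * w) + sum(x * vw) (same hedged Nx helpers as above):

```ocaml
(* Forward mode: propagate one tangent per primal input. *)
let f = function
  | [ x; w ] -> Nx.sum (Nx.mul x w)
  | _ -> invalid_arg "f expects exactly two inputs"

let () =
  let x = Nx.create Nx.float32 [| 2 |] [| 1.; 2. |] in
  let w = Nx.create Nx.float32 [| 2 |] [| 3.; 4. |] in
  let vx = Nx.create Nx.float32 [| 2 |] [| 1.; 0. |] in  (* tangent for x *)
  let vw = Nx.create Nx.float32 [| 2 |] [| 0.; 0. |] in  (* tangent for w *)
  let y, t = Rune.jvps f [ x; w ] [ vx; vw ] in
  (* t is the directional derivative of f at (x, w) along (vx, vw):
     here sum(vx * w) = 3. *)
  ignore (y, t)
```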
Stopping gradients
no_grad f evaluates f () without recording operations for automatic differentiation. All tensors produced inside f are treated as constants by enclosing gradient computations.
detach x is a copy of x that is treated as a constant with respect to automatic differentiation.
See also no_grad.
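A sketch of how detach blocks gradient flow through one branch of a computation (hedged Nx helpers as above):

```ocaml
(* Without detach, d/dx sum(x * x) = 2x. Detaching one factor treats
   it as a constant, so the gradient is just that constant copy. *)
let f x =
  let c = Rune.detach x in   (* c is a constant copy of x *)
  Nx.sum (Nx.mul x c)

let () =
  let x = Nx.create Nx.float32 [| 2 |] [| 3.; 5. |] in
  let g = Rune.grad f x in
  (* g = c = [| 3.; 5. |], not 2x = [| 6.; 10. |] *)
  ignore g
```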
Gradient checking
Compare autodiff gradients against finite-difference approximations. Useful for testing custom operations.
The type for finite difference methods.
- `Central: (f(x+h) - f(x-h)) / 2h. Most accurate; requires two evaluations per element.
- `Forward: (f(x+h) - f(x)) / h.
- `Backward: (f(x) - f(x-h)) / h.
val finite_diff :
?eps:float ->
?method_:method_ ->
(('a, 'b) Nx.t -> ('c, 'd) Nx.t) ->
('a, 'b) Nx.t ->
('a, 'b) Nx.t

finite_diff f x is the gradient of scalar-valued f at x approximated by finite differences.
eps defaults to 1e-4. method_ defaults to `Central.
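A sketch comparing the finite-difference gradient against autodiff (hedged Nx helpers as above):

```ocaml
(* Central differences approximate the gradient of sum(x * x). *)
let f x = Nx.sum (Nx.mul x x)

let () =
  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |] in
  let g_fd = Rune.finite_diff ~eps:1e-3 ~method_:`Central f x in
  let g_ad = Rune.grad f x in
  (* Both approximate 2x = [| 2.; 4.; 6. |]; central differences
     carry an O(eps^2) truncation error. *)
  ignore (g_fd, g_ad)
```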
val finite_diff_jacobian :
?eps:float ->
?method_:method_ ->
(('a, 'b) Nx.t -> ('c, 'd) Nx.t) ->
('a, 'b) Nx.t ->
('c, 'd) Nx.t

finite_diff_jacobian f x is the Jacobian of f at x approximated by finite differences.
eps defaults to 1e-4. method_ defaults to `Central.
type gradient_check_result = {
  max_abs_error : float;   (* Largest absolute error across all elements. *)
  max_rel_error : float;   (* Largest relative error across all elements. *)
  mean_abs_error : float;  (* Mean absolute error. *)
  mean_rel_error : float;  (* Mean relative error. *)
  failed_indices : (int array * float * float * float) list;
      (* (index, autodiff, finite_diff, abs_error) for each failed element. *)
  passed : bool;           (* true iff no element exceeded the tolerances. *)
  num_checked : int;       (* Number of elements checked. *)
  num_failed : int;        (* Number of elements that exceeded tolerances. *)
}

The type for gradient check results.
val check_gradient :
?eps:float ->
?rtol:float ->
?atol:float ->
?verbose:bool ->
?check_indices:int list option ->
?method_:[ `Central | `Forward | `Backward ] ->
((float, 'a) Nx.t -> ('b, 'c) Nx.t) ->
(float, 'a) Nx.t ->
[ `Pass of gradient_check_result | `Fail of gradient_check_result ]

check_gradient f x compares the autodiff gradient of f at x against a finite-difference approximation.
An element passes when abs_error <= atol or rel_error <= rtol.
eps defaults to 1e-4. rtol defaults to 2e-3. atol defaults to 2e-3. verbose defaults to false; when true, prints per-element failures and a summary to standard output. check_indices defaults to None (check all elements); when Some indices, only the listed flat indices are checked. method_ defaults to `Central.
See also check_gradients.
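A sketch of checking a custom function's gradient and pattern-matching on the result (hedged Nx helpers as above; record fields are from gradient_check_result):

```ocaml
(* Verify the autodiff gradient of f(x) = sum(x^3) numerically. *)
let f x = Nx.sum (Nx.mul x (Nx.mul x x))

let () =
  let x = Nx.create Nx.float32 [| 4 |] [| 0.5; -1.; 2.; 0.1 |] in
  match Rune.check_gradient ~rtol:1e-2 ~verbose:true f x with
  | `Pass r -> Printf.printf "ok: %d elements checked\n" r.num_checked
  | `Fail r ->
      Printf.printf "FAIL: %d of %d elements (max abs error %g)\n"
        r.num_failed r.num_checked r.max_abs_error
```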
val check_gradients :
?eps:float ->
?rtol:float ->
?atol:float ->
?verbose:bool ->
?method_:[ `Central | `Forward | `Backward ] ->
((float, 'a) Nx.t list -> ('b, 'c) Nx.t) ->
(float, 'a) Nx.t list ->
[ `Pass of gradient_check_result list | `Fail of gradient_check_result list ]

check_gradients f xs is like check_gradient for functions with multiple inputs. Returns one gradient_check_result per input tensor.
Optional parameters have the same defaults as check_gradient.
Vectorising map
Map a computation over a batch dimension. vmap transforms a function that operates on single examples into one that operates on batches, without the user writing explicit batch loops.
type 'a in_axes_spec =
  | Single of axis_spec   (* Apply to all inputs. *)
  | Container of 'a       (* Per-input specifications. *)
The type for input axis specifications.
val vmap :
?in_axes:'a in_axes_spec ->
?out_axes:'b out_axes_spec ->
?axis_name:string ->
?axis_size:int ->
(('c, 'd) Nx.t -> ('e, 'f) Nx.t) ->
('c, 'd) Nx.t ->
('e, 'f) Nx.t

vmap f x is a vectorised version of f applied to x.
in_axes defaults to Single (Map 0). out_axes defaults to OutSingle (Some 0). axis_name is an optional label for the mapped axis (used in error messages). axis_size overrides the batch size inferred from the input shape; required when all inputs use NoMap.
See also vmaps.
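A sketch of lifting a per-example function over a leading batch axis, assuming an Nx.div helper alongside the other hedged Nx names:

```ocaml
(* normalise operates on a single vector; vmap maps it over axis 0. *)
let normalise x = Nx.div x (Nx.sum x)

let () =
  let batch =
    Nx.create Nx.float32 [| 4; 3 |]
      (Array.init 12 (fun i -> float_of_int (i + 1)))
  in
  let batched_normalise = Rune.vmap normalise in  (* in_axes defaults to axis 0 *)
  let y = batched_normalise batch in
  (* y has shape [| 4; 3 |]; each row sums to 1. *)
  ignore y
```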
val vmaps :
?in_axes:Rune__.Vmap.axis_spec list ->
?out_axes:'b Rune__.Vmap.out_axes_spec ->
?axis_name:string ->
?axis_size:int ->
(('c, 'd) Nx.t list -> ('e, 'f) Nx.t) ->
('c, 'd) Nx.t list ->
('e, 'f) Nx.t

vmaps f xs is like vmap for functions with multiple inputs. Each element of in_axes corresponds to one input in xs.
in_axes defaults to Map 0 for every input.