# Rune vs. JAX -- A Practical Comparison

This guide explains how Rune's functional transformations relate to JAX, focusing on:

- How core concepts map (`grad`, `vjp`, `jvp`, `vmap`)
- Where the APIs feel similar vs. deliberately different
- How to translate common JAX patterns into Rune

If you already use JAX, this guide should be enough to make you productive in Rune quickly.
## 1. Big-Picture Differences
| Aspect | JAX (Python) | Rune (OCaml) |
|---|---|---|
| Language | Dynamic, interpreted | Statically typed, compiled |
| Array type | `jax.Array` | `Nx.t` (no separate Rune tensor type) |
| Array library | `jax.numpy` | `Nx` |
| AD mechanism | Tracing + XLA compilation | OCaml 5 effect handlers |
| Reverse-mode AD | `jax.grad`, `jax.value_and_grad` | `grad`, `value_and_grad`, `grads`, `value_and_grads` |
| Forward-mode AD | `jax.jvp` | `jvp`, `jvps` |
| VJP | `jax.vjp` | `vjp`, `vjps` |
| Vectorising map | `jax.vmap` | `vmap`, `vmaps` |
| JIT compilation | `jax.jit` | Not yet implemented |
| Device placement | `jax.device_put`, `device` kwarg | Not yet implemented |
| Gradient stopping | `jax.lax.stop_gradient` | `no_grad`, `detach` |
| Gradient checking | `jax.test_util.check_grads` | `check_gradient`, `check_gradients` |
| Debugging | `jax.debug.print` | `debug` |
| Control flow | Restricted inside `jit` (requires `lax.cond`, `lax.scan`) | Full OCaml control flow (`if`, `match`, loops, recursion) |
| Mutability | Immutable arrays; functional updates | Immutable `Nx` tensors; same model |
Key things to know:

- Rune operates on `Nx.t` directly. There is no separate tensor type, no `rune.numpy`, and no tracing step.
- Because Rune uses effect handlers rather than tracing, ordinary OCaml control flow works inside differentiated functions. No need for `lax.cond` or `lax.scan`.
- JIT compilation and device/GPU placement do not exist yet. All computation runs eagerly on CPU via the Nx C backend.
## 2. Reverse-Mode AD (grad)

### 2.1 Basic gradient
JAX

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))  # [2. 4. 6.]
```
Rune

```ocaml
open Nx
open Rune

let () =
  let f x = sum (mul x x) in
  let grad_f = grad f in
  let x = create Float32 [| 3 |] [| 1.; 2.; 3. |] in
  print_data (grad_f x)
  (* [2. 4. 6.] *)
```
Both `jax.grad` and `Rune.grad` take a function and return a new function that computes the gradient. The input function must return a scalar.
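The gradient both snippets print can be verified without either framework. A plain NumPy sketch (illustration only, not the Rune or JAX API) comparing the analytic gradient `2x` of `sum(x ** 2)` against central differences:

```python
import numpy as np

def f(x):
    return np.sum(x ** 2)

x = np.array([1.0, 2.0, 3.0])
g = 2 * x  # analytic gradient of sum(x^2)

# Central-difference estimate of each partial derivative
eps = 1e-6
fd = np.array([
    (f(x + eps * np.eye(3)[i]) - f(x - eps * np.eye(3)[i])) / (2 * eps)
    for i in range(3)
])
print(g, fd)  # both close to [2. 4. 6.]
```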
### 2.2 Value and gradient

JAX

```python
loss, grads = jax.value_and_grad(loss_fn)(params)
```

Rune

```ocaml
let loss, gradient = value_and_grad loss_fn params
```
Both avoid computing the forward pass twice.
### 2.3 Multiple inputs

JAX

```python
def f(x, y):
    return jnp.sum(x ** 2 + y ** 2)

# argnums selects which arguments to differentiate
dx, dy = jax.grad(f, argnums=(0, 1))(x, y)
```
Rune

```ocaml
open Nx
open Rune

let () =
  let f inputs =
    match inputs with
    | [x; y] -> sum (add (mul x x) (mul y y))
    | _ -> assert false
  in
  let gs = grads f [scalar Float32 3.0; scalar Float32 4.0] in
  List.iter (fun g -> Printf.printf "%.1f " (item [] g)) gs
  (* 6.0 8.0 *)
```
JAX uses `argnums` to select which positional arguments to differentiate. Rune takes a function of `Nx.t list` and differentiates with respect to all inputs. `value_and_grads` combines both:

```ocaml
let loss, gradients = value_and_grads loss_fn [w; b]
```
### 2.4 Auxiliary outputs

JAX

```python
def f(x):
    pred = model(x)
    loss = compute_loss(pred)
    return loss, pred  # pred is auxiliary

(loss, pred), grads = jax.value_and_grad(f, has_aux=True)(x)
```
Rune

```ocaml
let f x =
  let pred = forward_pass x in
  let loss = compute_loss pred in
  (loss, pred) (* pred is auxiliary -- not differentiated *)

let loss, gradient, pred = value_and_grad_aux f x
```
JAX uses a `has_aux=True` flag. Rune has dedicated `_aux` variants: `value_and_grad_aux` and `value_and_grads_aux`.
### 2.5 Higher-order derivatives

JAX

```python
f = lambda x: x ** 4
f_prime = jax.grad(f)
f_double_prime = jax.grad(f_prime)
```
Rune

```ocaml
open Nx
open Rune

let () =
  let f x = mul x (mul x (mul x x)) in
  let f' = grad f in
  let f'' = grad f' in
  let f''' = grad f'' in
  let x = scalar Float32 2.0 in
  Printf.printf "f(2) = %.1f\n" (item [] (f x));
  Printf.printf "f'(2) = %.1f\n" (item [] (f' x));
  Printf.printf "f''(2) = %.1f\n" (item [] (f'' x));
  Printf.printf "f'''(2) = %.1f\n" (item [] (f''' x))
```
Both compose naturally because `grad` returns an ordinary function.
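The derivative tower for f(x) = x^4 can be reproduced without any AD framework: f'(x) = 4x^3, f''(x) = 12x^2, f'''(x) = 24x, so at x = 2 the values are 32, 48, and 48. A plain-Python sketch (illustration only) using nested central differences:

```python
# Nested central differences approximate the same derivative tower
# that nested grad calls compute exactly.
def central_diff(f, h=1e-2):
    return lambda x: (f(x + h) - f(x - h)) / (2 * h)

f = lambda x: x ** 4
f1 = central_diff(f)   # ~ 4x^3
f2 = central_diff(f1)  # ~ 12x^2
f3 = central_diff(f2)  # ~ 24x
print(f1(2.0), f2(2.0), f3(2.0))  # close to 32, 48, 48
```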
## 3. VJP (Vector-Jacobian Product)

JAX

```python
def f(x):
    return x ** 2

primals, vjp_fn = jax.vjp(f, x)
grads = vjp_fn(v)
```
Rune

```ocaml
open Nx
open Rune

let () =
  let f x = mul x x in
  let x = create Float32 [| 3 |] [| 1.; 2.; 3. |] in
  let v = ones Float32 [| 3 |] in
  let y, g = vjp f x v in
  print_data y; (* [1. 4. 9.] *)
  print_data g  (* [2. 4. 6.] *)
```
In JAX, `jax.vjp` returns a closure `vjp_fn` that you call with the cotangent. In Rune, `vjp f x v` takes the cotangent `v` directly and returns `(y, g)` in one call.
For multiple inputs, JAX still uses positional arguments while Rune uses `vjps` with a list:

```ocaml
let y, gs = vjps f [x1; x2] v
```
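In both APIs the VJP computes the cotangent pulled back through the Jacobian, v^T J. For the elementwise square the Jacobian is diagonal, which a small NumPy sketch (illustration only, not either library's API) makes concrete:

```python
import numpy as np

# For f(x) = x**2 elementwise, J = diag(2x), so v^T J = v * 2x --
# the [2. 4. 6.] printed above when v is all ones.
x = np.array([1.0, 2.0, 3.0])
v = np.ones(3)
J = np.diag(2 * x)  # full Jacobian of the elementwise square
vjp_result = v @ J  # v^T J
print(vjp_result)  # [2. 4. 6.]
```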
## 4. Forward-Mode AD (JVP)

JAX

```python
def f(x):
    return x ** 2

primals, tangents = jax.jvp(f, (x,), (v,))
```
Rune

```ocaml
open Nx
open Rune

let () =
  let f x = mul x x in
  let x = create Float32 [| 3 |] [| 1.; 2.; 3. |] in
  let v = ones Float32 [| 3 |] in
  let y, tangent = jvp f x v in
  print_data y;      (* [1. 4. 9.] -- primal *)
  print_data tangent (* [2. 4. 6.] -- directional derivative *)
```
The API shape is nearly identical. JAX takes tuples of primals and tangents; Rune takes them as separate arguments.
For multiple inputs:

```ocaml
let y, tangent = jvps f [x1; x2] [v1; v2]
```

`jvp_aux` carries auxiliary outputs:

```ocaml
let y, tangent, aux = jvp_aux f x v
```
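The JVP is the directional derivative J v, which a directional finite difference approximates. A NumPy sketch (illustration only, not either library's API):

```python
import numpy as np

def f(x):
    return x ** 2

x = np.array([1.0, 2.0, 3.0])
v = np.ones(3)

# Directional finite difference approximates the JVP tangent J @ v;
# for the elementwise square the exact tangent is 2 * x * v.
eps = 1e-6
tangent = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)
print(tangent)  # close to [2. 4. 6.]
```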
## 5. Stopping Gradients

JAX

```python
import jax.lax

def f(x):
    baseline = jax.lax.stop_gradient(running_mean)
    return loss(x) - baseline
```
Rune

There are two options:

```ocaml
(* Option 1: detach a single tensor *)
let baseline = detach running_mean

(* Option 2: block an entire computation *)
let baseline = no_grad (fun () ->
  mean predictions)
```
JAX has a single `stop_gradient` that operates on arrays. Rune offers two mechanisms:

- `detach x` returns a copy of `x` that is treated as a constant during differentiation. Closest to `jax.lax.stop_gradient`.
- `no_grad f` runs `f ()` without recording any operations. Useful when a whole sub-computation should be excluded.
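What detaching buys you is easiest to see numerically: the detached term is a constant, so it contributes nothing to the derivative. A plain-Python sketch (`loss` here is a stand-in, not a Rune or JAX function):

```python
def loss(x):
    return x ** 2

def f(x, baseline):
    return loss(x) - baseline  # baseline is held constant ("detached")

x = 3.0
baseline = loss(x)  # detached snapshot of the running value

# Finite difference w.r.t. x, holding baseline fixed: the derivative
# of f is just dloss/dx = 2x, since the constant drops out.
eps = 1e-6
fd = (f(x + eps, baseline) - f(x - eps, baseline)) / (2 * eps)
print(fd)  # close to 6.0
```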
## 6. Vectorising Map (vmap)

### 6.1 Basic usage

JAX

```python
def f(x):
    return jnp.sum(x ** 2)

f_batched = jax.vmap(f)
batch = jnp.ones((10, 5))
results = f_batched(batch)  # shape (10,)
```
Rune

```ocaml
let f x = sum (mul x x) in
let f_batched = vmap f in
let batch = ones Float32 [| 10; 5 |] in
let results = f_batched batch
(* results has shape [|10|] *)
```
Both map over axis 0 by default and stack outputs on axis 0.
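Semantically, `vmap` with the defaults is the loop below, just executed efficiently. A NumPy sketch (illustration only):

```python
import numpy as np

# Apply f to each slice along axis 0 and stack results on axis 0 --
# the default vmap behaviour in both libraries, minus the fusion.
def f(x):
    return np.sum(x ** 2)

batch = np.ones((10, 5))
results = np.stack([f(row) for row in batch])
print(results.shape)  # (10,)
```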
### 6.2 Axis specifications

JAX

```python
# Map over axis 1
jax.vmap(f, in_axes=1)

# Don't map an input (broadcast it)
jax.vmap(f, in_axes=(0, None))
```
Rune

```ocaml
(* Map over axis 1 *)
let f_axis1 = vmap ~in_axes:(Single (Map 1)) f

(* Don't map an input (broadcast it) *)
let f_shared = vmaps ~in_axes:[Map 0; NoMap] f_multi
```

JAX uses `None` to indicate a non-mapped input and integers for mapped axes. Rune uses `Map n` and `NoMap` constructors. For single-input functions, wrap the axis in `Single`; for multi-input functions, use `vmaps` with a list.
Output axis control:

```ocaml
(* Stack outputs along axis 1 instead of 0 *)
let f' = vmap ~out_axes:(OutSingle (Some 1)) f

(* Discard the batch dimension (e.g., for reductions) *)
let f' = vmap ~out_axes:(OutSingle None) f
```
### 6.3 Composing vmap with grad

JAX

```python
# Per-example gradients
per_example_grad = jax.vmap(jax.grad(loss_fn))
```

Rune

```ocaml
let per_example_grad = vmap (grad loss_fn)
```
Both compose naturally. This gives per-example gradients without writing batch loops.
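The composition is equivalent to looping the gradient function over the batch. A NumPy sketch (illustration only) using the analytic gradient of `sum(x ** 2)`:

```python
import numpy as np

# Per-example gradients, spelled out as the loop that vmap (grad f)
# replaces. For loss(x) = sum(x^2) the per-row gradient is 2 * row.
def grad_loss(x):
    return 2 * x  # analytic gradient of sum(x^2)

batch = np.arange(6.0).reshape(3, 2)
per_example = np.stack([grad_loss(row) for row in batch])
print(per_example.shape)  # (3, 2)
```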
## 7. Gradient Checking

JAX

```python
from jax import test_util as jtu

jtu.check_grads(f, (x,), order=1)
```
Rune

```ocaml
match check_gradient ~verbose:true f x with
| `Pass result -> Printf.printf "max error: %e\n" result.max_abs_error
| `Fail result ->
  Printf.printf "%d of %d elements failed\n"
    result.num_failed result.num_checked
```
Rune provides more detailed results. The `gradient_check_result` record includes:

- `max_abs_error`, `max_rel_error`, `mean_abs_error`, `mean_rel_error`
- `failed_indices` with per-element `(index, autodiff_value, finite_diff_value, abs_error)`
- `passed`, `num_checked`, `num_failed`
Additional utilities:

- `finite_diff f x` -- approximate gradient via finite differences
- `finite_diff_jacobian f x` -- approximate Jacobian for non-scalar outputs
- `check_gradients f xs` -- check a multi-input function
You can control the finite-difference method:

```ocaml
let fd = finite_diff ~method_:`Forward ~eps:1e-5 f x
```

Available methods: `` `Central `` (default), `` `Forward ``, `` `Backward ``.
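The three schemes differ only in which points they sample. A plain-Python sketch of a hypothetical helper mirroring these options (not the Rune implementation): central differences are O(h^2) accurate, forward and backward only O(h).

```python
def finite_diff(f, x, method="central", eps=1e-5):
    # Sketch of the three finite-difference schemes for a scalar
    # function of a scalar input.
    if method == "central":
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    if method == "forward":
        return (f(x + eps) - f(x)) / eps
    return (f(x) - f(x - eps)) / eps  # backward

f = lambda x: x ** 3  # derivative is 3x^2, so 12 at x = 2
print(finite_diff(f, 2.0))             # close to 12.0
print(finite_diff(f, 2.0, "forward"))  # close to 12.0, less accurate
```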
## 8. Debugging

JAX

```python
def f(x):
    y = x ** 2
    jax.debug.print("y = {}", y)
    return y

f(jnp.array(3.0))
```
Rune

```ocaml
let f x = add (mul x x) (sin x) in
let x = scalar Float32 2.0 in
let _ = debug f x in
()
(* Prints each operation, its inputs, and its output *)
```
JAX's `jax.debug.print` is a targeted print inside traced code. Rune's `debug` wraps an entire function and traces every tensor operation, printing the operation name, inputs, and output. It is coarser-grained but requires no instrumentation inside the function.
## 9. Control Flow

This is a fundamental difference.

JAX

Inside `jit`-compiled functions, Python control flow on traced values does not work because JAX traces the function:
```python
# Breaks under jit:
@jax.jit
def f(x):
    if x > 0:  # Error: traced value used in Python conditional
        return x
    else:
        return -x

# Must use JAX primitives:
@jax.jit
def f(x):
    return jax.lax.cond(x > 0, lambda: x, lambda: -x)
```
Rune

OCaml control flow works everywhere, including inside `grad`, `jvp`, and `vmap`:

```ocaml
let f x =
  if item [] x > 0.0 then x
  else neg x

(* Works fine *)
let df = grad f
```
Rune does not trace functions into a graph. It intercepts operations as they execute via effect handlers, so any OCaml expression is valid. No special `cond`, `scan`, or `while_loop` primitives are needed.
## 10. What Rune Does Not Have (Yet)

| JAX feature | Status in Rune |
|---|---|
| `jax.jit` | Not implemented. All operations execute eagerly. |
| Device placement (`jax.device_put`, GPU/TPU) | Not implemented. All computation runs on CPU. |
| `jax.pmap` / distributed | Not implemented. |
| `jax.lax.scan`, `jax.lax.while_loop` | Not needed. Use ordinary OCaml loops and recursion. |
| `jax.custom_vjp`, `jax.custom_jvp` | Not yet exposed. |
| `jax.checkpoint` (gradient checkpointing) | Not implemented. |
| Pytrees / tree utilities | Not needed. Use OCaml data structures directly. |
| `jax.random` (splittable PRNG) | Use `Nx.rand`, `Nx.randn` directly. |
## 11. Quick Cheat Sheet

| Task | JAX | Rune |
|---|---|---|
| Gradient of scalar fn | `jax.grad(f)(x)` | `grad f x` |
| Value + gradient | `jax.value_and_grad(f)(x)` | `value_and_grad f x` |
| Multi-input gradient | `jax.grad(f, argnums=(0, 1))(x, y)` | `grads f [x; y]` |
| Auxiliary output | `jax.value_and_grad(f, has_aux=True)(x)` | `value_and_grad_aux f x` |
| Higher-order deriv | `jax.grad(jax.grad(f))` | `grad (grad f)` |
| VJP | `primals, fn = jax.vjp(f, x); fn(v)` | `vjp f x v` |
| JVP | `jax.jvp(f, (x,), (v,))` | `jvp f x v` |
| Stop gradient | `jax.lax.stop_gradient(x)` | `detach x` |
| Block region from AD | (no direct equivalent) | `no_grad (fun () -> ...)` |
| Batch map | `jax.vmap(f)(batch)` | `vmap f batch` |
| vmap axis control | `jax.vmap(f, in_axes=(0, None))` | `vmaps ~in_axes:[Map 0; NoMap] f` |
| Per-example grad | `jax.vmap(jax.grad(f))` | `vmap (grad f)` |
| Gradient check | `jtu.check_grads(f, (x,), 1)` | `check_gradient f x` |
| Finite differences | (manual) | `finite_diff f x` |
| Debug tracing | `jax.debug.print(...)` | `debug f x` |
| JIT compilation | `jax.jit(f)` | Not yet available |
| GPU placement | `jax.device_put(x, gpu)` | Not yet available |