Getting Started
This guide shows you how to compute gradients and use Rune's transformations.
Installation
opam install rune
Or build from source:
git clone https://github.com/raven-ml/raven
cd raven && dune build rune
Add to your dune file:
(executable
 (name main) ; your executable's name
 (libraries rune))
Your First Gradient
Rune operates on Nx tensors directly. Write a function using Nx operations, then use grad to get its derivative:
open Nx
open Rune

let () =
  (* A simple function: f(x) = x² + sin(x) *)
  let f x = add (mul x x) (sin x) in
  (* grad returns a function that computes the derivative *)
  let f' = grad f in
  let x = scalar Float32 2.0 in
  Printf.printf "f(2) = %.4f\n" (item [] (f x));
  Printf.printf "f'(2) = %.4f\n" (item [] (f' x))
  (* f'(x) = 2x + cos(x), so f'(2) ≈ 3.5839 *)
Key points:
- grad f takes a function f : Nx.t -> Nx.t and returns a new function that computes the gradient
- The input function must return a scalar tensor
- The gradient has the same shape as the input
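For instance, taking the gradient at a 2x2 input gives back a 2x2 tensor. A minimal sketch (mean reduces the tensor to the scalar that grad requires):

open Nx
open Rune

let () =
  (* f must return a scalar, so reduce with mean *)
  let f x = mean (mul x x) in
  let x = create Float32 [|2; 2|] [|1.; 2.; 3.; 4.|] in
  (* the gradient has the same [|2; 2|] shape as x; each entry is x_i / 2 *)
  print_data (grad f x)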
Value and Gradient Together
In practice, you usually want both the function value and its gradient. Use value_and_grad to avoid computing the forward pass twice:
open Nx
open Rune

let () =
  let f x = mean (mul x x) in
  let x = create Float32 [|3|] [|1.0; 2.0; 3.0|] in
  let value, gradient = value_and_grad f x in
  Printf.printf "f(x) = %.4f\n" (item [] value);
  print_data gradient
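  (* f(x) = mean(x²) = 14/3 ≈ 4.6667; the gradient is 2x/3 = [0.667; 1.333; 2.0] *)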
Multiple Inputs
When your function takes multiple inputs, use grads or value_and_grads:
open Nx
open Rune

let () =
  let f inputs =
    match inputs with
    | [x; y] -> add (mul x x) (mul y y)
    | _ -> failwith "expected 2 inputs"
  in
  let df = grads f in
  match df [scalar Float32 3.0; scalar Float32 4.0] with
  | [dx; dy] ->
    Printf.printf "df/dx = %.1f\n" (item [] dx);
    Printf.printf "df/dy = %.1f\n" (item [] dy)
  | _ -> assert false
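  (* f(x, y) = x² + y², so df/dx = 2x = 6.0 and df/dy = 2y = 8.0 *)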
Higher-Order Derivatives
Since grad returns a regular function, you can differentiate again:
open Nx
open Rune

let () =
  (* f(x) = x⁴ *)
  let f x = mul x (mul x (mul x x)) in
  let f' = grad f in     (* 4x³ *)
  let f'' = grad f' in   (* 12x² *)
  let f''' = grad f'' in (* 24x *)
  let x = scalar Float32 2.0 in
  Printf.printf "f(2) = %.1f\n" (item [] (f x));
  Printf.printf "f'(2) = %.1f\n" (item [] (f' x));
  Printf.printf "f''(2) = %.1f\n" (item [] (f'' x));
  Printf.printf "f'''(2) = %.1f\n" (item [] (f''' x))
Stopping Gradients
Sometimes you need part of a computation to be treated as a constant:
open Rune

(* no_grad: nothing inside is recorded *)
let baseline = no_grad (fun () ->
  (* compute a baseline value that should not be differentiated;
     predictions here stands for a tensor computed earlier in the program *)
  mean predictions)

(* detach: make a single tensor a constant *)
let target = detach current_prediction
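To see the effect, here is a minimal sketch (assuming detach behaves as described above): because the detached factor is treated as a constant, the derivative of x * detach x is x rather than 2x.

open Nx
open Rune

let () =
  (* the detached copy of x is a constant, so d/dx (x * detach x) = x *)
  let f x = mul x (detach x) in
  let x = scalar Float32 3.0 in
  Printf.printf "%.1f\n" (item [] (grad f x))
  (* expected to print 3.0 rather than 6.0 *)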
A Simple Training Loop
Here is a minimal example that trains a linear model with gradient descent:
open Nx
open Rune

let () =
  (* Data: y = 2x + 1 *)
  let x_data = create Float32 [|4; 1|] [|1.; 2.; 3.; 4.|] in
  let y_data = create Float32 [|4; 1|] [|3.; 5.; 7.; 9.|] in
  (* Parameters *)
  let w = rand Float32 [|1; 1|] in
  let b = zeros Float32 [|1|] in
  let loss_fn params =
    match params with
    | [w; b] ->
      let pred = add (matmul x_data w) b in
      mean (mul (sub pred y_data) (sub pred y_data))
    | _ -> assert false
  in
  let lr = scalar Float32 0.01 in
  for epoch = 1 to 200 do
    let loss, gs = value_and_grads loss_fn [w; b] in
    match gs with
    | [gw; gb] ->
      ignore (sub ~out:w w (mul lr gw));
      ignore (sub ~out:b b (mul lr gb));
      if epoch mod 50 = 0 then
        Printf.printf "epoch %d loss %.6f\n" epoch (item [] loss)
    | _ -> assert false
  done;
  Printf.printf "w = %.3f b = %.3f\n" (item [0; 0] w) (item [0] b)
For real neural networks, use Kaun, which provides layers, optimizers, and training loops built on top of Rune.
Next Steps
- Transformations — complete guide to grad, jvp, vmap, and more
- How It Works — how effects-based autodiff works under the hood
- Kaun Getting Started — high-level neural network training