Quickstart
This gets you from zero to computing gradients and training a model in five minutes.
Setup
opam install raven
Create a dune-project and dune file:
; dune-project
(lang dune 3.0)
; dune
(executable
 (name main)
 (libraries kaun))
Installing kaun pulls in nx and rune automatically.
Step 1: Arrays with Nx
Nx provides n-dimensional arrays. Every value has a data type and a shape.
open Nx
let () =
  (* Construction: one array from explicit data and one filled with ones;
     both are given the shape [|2; 3|]. *)
  let values = create Float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
  let all_ones = ones Float32 [| 2; 3 |] in
  (* Element-wise addition; the result is printed directly. *)
  print_data (add values all_ones);
  (* Reductions: [item []] reads the reduced result back as a plain float so
     it can be formatted with [%.1f]. *)
  let total = item [] (sum values) in
  Printf.printf "sum = %.1f\n" total;
  let average = item [] (mean values) in
  Printf.printf "mean = %.1f\n" average;
  (* Matrix multiplication of two random arrays created with shapes
     [|3; 4|] and [|4; 2|]. The operands are drawn in this order on purpose:
     swapping the two [rand] calls could change which random values each one
     receives. *)
  let lhs = rand Float32 [| 3; 4 |] in
  let rhs = rand Float32 [| 4; 2 |] in
  (* Render the product's shape as dimensions joined by "x". *)
  let dims =
    shape (matmul lhs rhs)
    |> Array.map string_of_int
    |> Array.to_list
    |> String.concat "x"
  in
  Printf.printf "matmul shape: %s\n" dims
Step 2: Gradients with Rune
Rune computes derivatives of Nx functions automatically. Write a function using Nx operations, then use grad to differentiate it.
open Nx
open Rune
let () =
  (* The function being differentiated: f(x) = x² + sin(x). *)
  let square_plus_sin v = add (mul v v) (sin v) in
  (* [grad] turns a function into its derivative function. *)
  let first_derivative = grad square_plus_sin in
  (* The point at which the function and its derivatives are evaluated. *)
  let point = scalar Float32 2.0 in
  (* Apply [g] at [point] and read the result back as a plain float. *)
  let value_at_point g = item [] (g point) in
  Printf.printf "f(2) = %.4f\n" (value_at_point square_plus_sin);
  Printf.printf "f'(2) = %.4f\n" (value_at_point first_derivative);
  (* Higher-order: differentiating the derivative gives the second
     derivative. *)
  let second_derivative = grad first_derivative in
  Printf.printf "f''(2) = %.4f\n" (value_at_point second_derivative)
Step 3: Training with Kaun
Kaun provides layers, optimizers, and training loops built on Rune.
open Kaun
let () =
  (* Everything runs inside [Nx.Rng.run] with a fixed seed. *)
  Nx.Rng.run ~seed:42 (fun () ->
      (* XOR dataset: four 2-element inputs and their four 1-element
         targets. *)
      let inputs =
        Nx.create Nx.Float32 [| 4; 2 |] [| 0.; 0.; 0.; 1.; 1.; 0.; 1.; 1. |]
      in
      let targets = Nx.create Nx.Float32 [| 4; 1 |] [| 0.; 1.; 1.; 0. |] in
      (* Model: linear 2 -> 8, tanh, linear 8 -> 1. *)
      let model =
        Layer.sequential
          [
            Layer.linear ~in_features:2 ~out_features:8 ();
            Layer.tanh ();
            Layer.linear ~in_features:8 ~out_features:1 ();
          ]
      in
      (* Adam with a constant learning-rate schedule of 0.01. *)
      let optimizer = Optim.adam ~lr:(Optim.Schedule.constant 0.01) () in
      let trainer = Train.make ~model ~optimizer in
      let initial_state = Train.init trainer ~dtype:Nx.Float32 in
      (* Progress callback: prints the loss whenever the step number is a
         multiple of 250. The state argument is ignored. *)
      let report ~step ~loss _state =
        if step mod 250 = 0 then
          Printf.printf "step %4d loss %.6f\n" step loss
      in
      (* Loss closure over the fixed targets.
         NOTE(review): predictions are passed through [Nx.sigmoid] only at
         prediction time below, so this loss presumably expects raw model
         outputs — confirm against the Kaun documentation. *)
      let loss_of_prediction pred = Loss.binary_cross_entropy pred targets in
      (* Train on the same (inputs, loss) pair repeated 1000 times. *)
      let trained_state =
        Train.fit trainer initial_state ~report
          (Data.repeat 1000 (inputs, loss_of_prediction))
      in
      (* Predict: model outputs squashed through a sigmoid. *)
      let predictions =
        Nx.sigmoid (Train.predict trainer trained_state inputs)
      in
      Printf.printf "\npredictions (expected 0 1 1 0):\n";
      (* Print one line per dataset row: both input values and the
         prediction. *)
      let print_row row =
        let first_input = Nx.item [ row; 0 ] inputs in
        let second_input = Nx.item [ row; 1 ] inputs in
        let prediction = Nx.item [ row; 0 ] predictions in
        Printf.printf " [%.0f, %.0f] -> %.3f\n" first_input second_input
          prediction
      in
      List.iter print_row [ 0; 1; 2; 3 ])
Next Steps
- Nx — full guide to arrays, slicing, broadcasting, linear algebra
- Rune — all transformations: grad, jvp, vmap, and more
- Kaun — layers, optimizers, training loops, pretrained models
- Ecosystem Overview — how all 9 libraries fit together