Training
This guide covers optimizers, learning-rate schedules, loss functions, data pipelines, the high-level training loop, metrics, and custom training.
Optimizers
An Optim.algorithm pairs a learning-rate schedule with an update rule.
All optimizers take ~lr as a Schedule.t:
(* SGD with momentum *)
Optim.sgd ~lr:(Optim.Schedule.constant 0.1) ~momentum:0.9 ()
(* Adam *)
Optim.adam ~lr:(Optim.Schedule.constant 1e-3) ()
(* AdamW with weight decay *)
Optim.adamw ~lr:(Optim.Schedule.constant 1e-3) ~weight_decay:0.01 ()
(* RMSprop *)
Optim.rmsprop ~lr:(Optim.Schedule.constant 1e-3) ()
(* Adagrad *)
Optim.adagrad ~lr:(Optim.Schedule.constant 0.01) ()
sgd supports optional ~momentum (default 0.0) and ~nesterov
(default false). adam and adamw support ~b1 (default 0.9), ~b2
(default 0.999), and ~eps (default 1e-8). rmsprop supports ~decay
(default 0.9), ~eps, and ~momentum.
Learning-Rate Schedules
A schedule is a function int -> float mapping 1-based step numbers to
learning rates:
(* Fixed learning rate *)
Optim.Schedule.constant 1e-3
(* Cosine decay from 1e-3 to 0 over 10000 steps *)
Optim.Schedule.cosine_decay ~init_value:1e-3 ~decay_steps:10000 ()
(* Cosine decay with minimum alpha *)
Optim.Schedule.cosine_decay ~init_value:1e-3 ~decay_steps:10000 ~alpha:1e-5 ()
(* Linear warmup from 0 to 1e-3 over 1000 steps *)
Optim.Schedule.warmup_linear ~init_value:0. ~peak_value:1e-3 ~warmup_steps:1000
(* Cosine warmup *)
Optim.Schedule.warmup_cosine ~init_value:0. ~peak_value:1e-3 ~warmup_steps:1000
(* Exponential decay *)
Optim.Schedule.exponential_decay ~init_value:1e-3 ~decay_rate:0.96 ~decay_steps:1000
Compose schedules by writing a custom function:
let warmup_then_cosine step =
  if step <= 1000 then
    Optim.Schedule.warmup_linear ~init_value:0. ~peak_value:1e-3 ~warmup_steps:1000 step
  else
    Optim.Schedule.cosine_decay ~init_value:1e-3 ~decay_steps:9000 () (step - 1000)
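Because the contract is just int -> float, the same shape can also be written from scratch. A self-contained sketch (the constants mirror the example above; the exact Optim.Schedule curves may differ in edge details):

```ocaml
(* Hand-rolled warmup-then-cosine schedule: linear ramp to the peak over
   the first 1000 steps, then cosine decay to zero over the next 9000. *)
let pi = 4.0 *. atan 1.0

let warmup_then_cosine step =
  let peak = 1e-3 in
  let warmup = 1000 and decay = 9000 in
  if step <= warmup then
    peak *. float_of_int step /. float_of_int warmup
  else
    let t = float_of_int (min (step - warmup) decay) /. float_of_int decay in
    0.5 *. peak *. (1.0 +. cos (pi *. t))
```

At step 500 this yields 5e-4, at step 1000 the peak 1e-3, and by step 10000 it has decayed to (numerically) zero.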
Loss Functions
All loss functions return scalar tensors that are differentiable through Rune's autodiff:
(* Multi-class: logits [batch; num_classes], one-hot labels [batch; num_classes] *)
Loss.cross_entropy logits one_hot_labels
(* Multi-class with integer labels: logits [batch; num_classes], labels [batch] *)
Loss.cross_entropy_sparse logits class_indices
(* Binary: raw logits (not sigmoid), labels in {0, 1} *)
Loss.binary_cross_entropy logits labels
(* Regression *)
Loss.mse predictions targets
Loss.mae predictions targets
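As a point of reference, the arithmetic behind Loss.mse is the mean of squared differences. A plain-float sketch of that semantics (the real function operates on tensors and is differentiable; arrays are assumed equal-length):

```ocaml
(* Mean squared error over plain float arrays -- semantics only,
   not the tensor implementation. *)
let mse_scalar predictions targets =
  let n = Array.length predictions in
  let sum = ref 0.0 in
  for i = 0 to n - 1 do
    let d = predictions.(i) -. targets.(i) in
    sum := !sum +. (d *. d)
  done;
  !sum /. float_of_int n
```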
Data Pipelines
Data.t is a lazy, composable iterator. Build pipelines by chaining
constructors, transformers, and consumers.
Constructors
(* From arrays *)
Data.of_array [| example1; example2; example3 |]
(* From tensors: slices along first dimension *)
Data.of_tensor x (* yields x[0], x[1], ... *)
Data.of_tensors (x, y) (* yields (x[0], y[0]), (x[1], y[1]), ... *)
(* From a function *)
Data.of_fn 1000 (fun i -> generate_example i)
(* Repeat a value *)
Data.repeat 1000 (x, loss_fn)
Transformers
(* Map each element *)
Data.map (fun (x, y) -> (preprocess x, y)) data
(* Batch into arrays of size n *)
Data.batch 32 data (* yields arrays of 32 elements *)
Data.batch ~drop_last:true 32 data
(* Batch and map in one step *)
Data.map_batch 32 collate_fn data
(* Shuffle *)
Data.shuffle rng_key data
Consumers
Data.iter (fun x -> process x) data
Data.iteri (fun i x -> Printf.printf "%d: %f\n" i x) data
Data.fold (fun acc x -> acc +. x) 0. data
Data.to_array data
Data.to_seq data
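To make the batching semantics concrete, here is a pure-OCaml sketch of what Data.batch does with and without ~drop_last, over a plain list. This is illustration only: the library works lazily and yields arrays, not lists.

```ocaml
(* Group a list into chunks of size n. A final partial chunk is kept
   unless ~drop_last:true is given -- mirroring Data.batch. *)
let batch ?(drop_last = false) n xs =
  let rec go acc cur k = function
    | [] ->
        let acc =
          if k = 0 || (drop_last && k < n) then acc
          else List.rev cur :: acc
        in
        List.rev acc
    | x :: rest ->
        if k + 1 = n then go (List.rev (x :: cur) :: acc) [] 0 rest
        else go acc (x :: cur) (k + 1) rest
  in
  go [] [] 0 xs
```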
The prepare Shortcut
Data.prepare combines tensor slicing, optional shuffle, and batching
into one call. It is the standard way to feed tensor data to training:
let train_data =
  Data.prepare ~shuffle:rng_key ~batch_size:64 (x_train, y_train)
  |> Data.map (fun (x, y) ->
       (x, fun logits -> Loss.cross_entropy_sparse logits y))
Data.prepare yields (x_batch, y_batch) tensor pairs. The Data.map
step attaches the loss function, producing the (input, loss_fn) pairs
that Train.fit expects.
~drop_last defaults to true in prepare.
Resetting
Pipelines are single-pass. Call Data.reset to iterate again:
Data.reset test_batches;
let acc = Metric.eval eval_fn test_batches
High-Level Training
Train.make and Train.init
Create a trainer by pairing a model with an optimizer, then initialize:
let trainer = Train.make ~model
  ~optimizer:(Optim.adam ~lr:(Optim.Schedule.constant 1e-3) ())
let st = Train.init trainer ~dtype:Nx.Float32
Train.fit
Train.fit trains over a data pipeline and returns the final state:
let st = Train.fit trainer st
  ~report:(fun ~step ~loss _st ->
    Printf.printf "step %d loss %.4f\n" step loss)
  data
Each element of data is (input, loss_fn) where loss_fn takes the
model output and returns a scalar loss.
The optional ~report callback is called after every step. The ~step
number is 1-based.
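Since pipelines are single-pass, a multi-epoch run combines Train.fit with Data.reset (described under Resetting above). A sketch built only from the calls documented in this guide; the epoch count, the reporting interval, and the trainer, st, and train_data bindings are illustrative:

```ocaml
(* Multi-epoch training: reset the pipeline before each pass. *)
let rec run_epochs st epoch =
  if epoch > 10 then st
  else begin
    Data.reset train_data;
    let st =
      Train.fit trainer st
        ~report:(fun ~step ~loss _st ->
          if step mod 100 = 0 then
            Printf.printf "epoch %d step %d loss %.4f\n" epoch step loss)
        train_data
    in
    run_epochs st (epoch + 1)
  end

let st = run_epochs st 1
```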
Early Stopping
Raise Train.Early_stop inside ~report to end training early.
Train.fit catches it and returns the current state:
let st = Train.fit trainer st
  ~report:(fun ~step:_ ~loss _st ->
    if loss < 0.001 then raise Train.Early_stop)
  data
Train.predict
Run inference in evaluation mode (no dropout, no state updates):
let logits = Train.predict trainer st x
Train.step
For manual control over single training steps:
let loss, st' = Train.step trainer st ~training:true
  ~loss:(fun logits -> Loss.cross_entropy_sparse logits y)
  x
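Combined with Data.fold, Train.step gives a fully manual training loop. A sketch composed from the calls documented in this guide (train_epoch is a hypothetical helper name; data is assumed to yield (x, y) batch pairs):

```ocaml
(* One Train.step per batch, threading the training state through
   Data.fold as the accumulator. *)
let train_epoch trainer st data =
  Data.fold
    (fun st (x, y) ->
      let _loss, st' =
        Train.step trainer st ~training:true
          ~loss:(fun logits -> Loss.cross_entropy_sparse logits y)
          x
      in
      st')
    st data
```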
Starting from Pretrained Weights
Use Train.make_state to create training state from externally loaded
weights instead of random initialization:
let vars = (* load from checkpoint *) in
let st = Train.make_state trainer vars
Metrics
Metric Functions
Metric functions are ordinary functions of type predictions -> targets -> float:
(* Multi-class: logits [batch; num_classes], labels [batch] *)
Metric.accuracy logits targets
(* Binary classification *)
Metric.binary_accuracy ~threshold:0.5 predictions targets
(* Precision, recall, F1 with averaging mode *)
Metric.precision Metric.Macro logits targets
Metric.recall Metric.Micro logits targets
Metric.f1 Metric.Weighted logits targets
Averaging modes: Macro (unweighted mean of per-class scores), Micro
(global aggregation), Weighted (mean weighted by class support).
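The difference between the modes is easy to see on hand-picked numbers. A plain-float sketch of macro versus weighted averaging, using hypothetical per-class F1 scores and supports (not Metric's implementation, which starts from logits):

```ocaml
(* Two classes: F1 = 0.9 with 90 examples, F1 = 0.5 with 10 examples. *)
let f1 = [| 0.9; 0.5 |]
let support = [| 90.0; 10.0 |]

(* Macro: unweighted mean of per-class scores. *)
let macro =
  Array.fold_left ( +. ) 0.0 f1 /. float_of_int (Array.length f1)

(* Weighted: mean weighted by class support. *)
let weighted =
  let total = Array.fold_left ( +. ) 0.0 support in
  let acc = ref 0.0 in
  Array.iteri (fun i s -> acc := !acc +. (s /. total) *. f1.(i)) support;
  !acc
```

Here macro is 0.7 while weighted is 0.86: the majority class dominates the weighted score.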
Dataset Evaluation
Metric.eval folds a function over a data pipeline and returns the
mean:
Data.reset test_batches;
let test_acc =
  Metric.eval
    (fun (x, y) ->
      let logits = Train.predict trainer st x in
      Metric.accuracy logits y)
    test_batches
Metric.eval_many evaluates multiple named metrics at once:
let results =
  Metric.eval_many
    (fun (x, y) ->
      let logits = Train.predict trainer st x in
      [ ("accuracy", Metric.accuracy logits y);
        ("f1", Metric.f1 Metric.Macro logits y) ])
    test_batches
(* results : (string * float) list *)
Running Tracker
Metric.tracker accumulates running means during training:
let tracker = Metric.tracker () in
(* In the training loop: *)
Metric.observe tracker "loss" loss_value;
Metric.observe tracker "accuracy" acc_value;
(* After an epoch: *)
Printf.printf "%s\n" (Metric.summary tracker);
(* "accuracy: 0.9150 loss: 0.4231" *)
Metric.reset tracker
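Under the hood a tracker is just a running mean per name. A minimal pure-OCaml sketch of those semantics (not the library's implementation):

```ocaml
(* Running means keyed by metric name: store (sum, count) per key. *)
let stats : (string, float * int) Hashtbl.t = Hashtbl.create 8

let observe name value =
  let sum, count =
    Option.value ~default:(0.0, 0) (Hashtbl.find_opt stats name)
  in
  Hashtbl.replace stats name (sum +. value, count + 1)

let mean name =
  let sum, count = Hashtbl.find stats name in
  sum /. float_of_int count
```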
Gradient Utilities
Gradient Clipping
Clip gradients by global L2 norm to prevent exploding gradients. Use
this with Train.step in custom training loops:
let clipped_grads = Optim.clip_by_global_norm 1.0 grads
Global Norm
Compute the L2 norm across all leaf tensors:
let norm = Optim.global_norm grads
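Concretely, the global norm is the L2 norm of all gradient entries taken together, and clipping rescales every leaf by the same factor whenever that norm exceeds the bound. A sketch over plain float arrays standing in for leaf tensors (the library operates on a gradient Ptree.t):

```ocaml
(* Global L2 norm across a list of "leaves" (plain float arrays here). *)
let global_norm leaves =
  sqrt
    (List.fold_left
       (fun acc leaf ->
         Array.fold_left (fun acc x -> acc +. (x *. x)) acc leaf)
       0.0 leaves)

(* Rescale every leaf by max_norm /. norm when the norm exceeds max_norm,
   so the clipped gradients keep their direction. *)
let clip_by_global_norm max_norm leaves =
  let norm = global_norm leaves in
  if norm <= max_norm then leaves
  else
    let scale = max_norm /. norm in
    List.map (Array.map (fun x -> x *. scale)) leaves
```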
Manual Gradient Computation
Grad.value_and_grad differentiates a function with respect to a
Ptree.t:
let loss, grads = Grad.value_and_grad
  (fun params ->
    let output, _state = model.apply ~params ~state ~dtype ~training:true x in
    Loss.mse output y)
  params
Grad.value_and_grad_aux returns auxiliary data alongside the loss:
let loss, grads, new_state = Grad.value_and_grad_aux
  (fun params ->
    let output, new_state = model.apply ~params ~state ~dtype ~training:true x in
    (Loss.mse output y, new_state))
  params
Next Steps
- Layers and Models — full layer catalog, composition, custom layers
- Checkpoints and Pretrained Models — saving, loading, HuggingFace Hub