Module Kaun.Train
High-level training loop.
Train composes Layer, Grad, and Optim into a single training driver. Users never touch parameter trees, optimizer state, or gradient computation directly.
For advanced use, step exposes a single training step and vars gives access to the underlying model variables.
Types
Core
val make : model:('i, 'o) Layer.t -> optimizer:Optim.algorithm -> ('i, 'o) t
make ~model ~optimizer creates a trainer.
init trainer ~dtype initializes model variables and optimizer state.
Random keys for weight initialization are drawn from the implicit RNG scope.
val vars : 'l state -> 'l Layer.vars
vars st is the current model variables (params + state + dtype).
val make_state : ('i, 'o) t -> 'l Layer.vars -> 'l state
make_state trainer vars is a training state with vars and freshly initialized optimizer state.
Use this to start training from pretrained or externally loaded weights instead of init.
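A minimal sketch of the two ways to obtain an initial training state. Here my_model, Optim.sgd, and load_vars are hypothetical stand-ins for your own model, an optimizer constructor, and weight-loading code; only make, init, and make_state come from this module.

```ocaml
(* Hypothetical model and optimizer; substitute your own. *)
let trainer = Train.make ~model:my_model ~optimizer:(Optim.sgd ~lr:0.01)

(* Fresh initialization (random keys come from the implicit RNG scope): *)
let st = Train.init trainer ~dtype:Nx.float32

(* Or start from pretrained weights instead of [init];
   [load_vars] is assumed to produce a ['l Layer.vars]: *)
let st' = Train.make_state trainer (load_vars "checkpoint.bin")
```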
Training
Raise Early_stop inside report to end training early. fit catches this exception and returns the current state.
val step :
('i, 'o) t ->
'l state ->
training:bool ->
?ctx:Context.t ->
loss:(('o, 'l) Nx.t -> (float, 'l) Nx.t) ->
('i, 'in_elt) Nx.t ->
(float, 'l) Nx.t * 'l state
step trainer st ~training ?ctx ~loss x performs one training step.
Computes the forward pass, differentiates the loss with respect to trainable parameters, applies the optimizer, and threads updated layer state.
ctx is forwarded to the model's forward pass. See Context.
When training = false, gradients are still computed and the optimizer is still applied; the flag only affects the model's forward pass. Use predict for pure inference.
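For cases where fit is too rigid, a manual loop can be built on step directly. This is a hedged sketch: trainer, st0, batches, and the mse loss are assumed to exist, and the fold simply threads the returned state through each step.

```ocaml
(* [batches] : (input, target) pairs; each step closes the loss over
   the batch's target, mirroring what [fit] does with (x, loss_fn). *)
let final_st =
  List.fold_left
    (fun st (x, y) ->
      let loss_fn pred = mse pred y in
      let _loss, st' = Train.step trainer st ~training:true ~loss:loss_fn x in
      st')
    st0 batches
```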
val fit :
('i, 'o) t ->
'l state ->
?ctx:Context.t ->
?report:(step:int -> loss:float -> 'l state -> unit) ->
(('i, 'in_elt) Nx.t * (('o, 'l) Nx.t -> (float, 'l) Nx.t)) Data.t ->
'l state
fit trainer st ?ctx ?report data trains the model over data and returns the final state.
Each element of data is a pair (x, loss_fn) where x is the input tensor and loss_fn computes the scalar loss from the model output. This allows the loss to depend on per-batch labels.
ctx is forwarded to the model's forward pass on each step. See Context.
When provided, report is called after every step with the step number (1-based), scalar loss, and training state. Raise Early_stop inside report to end training early.
For fixed-data training (same input every step), use Data.repeat:
Train.fit trainer st (Data.repeat 1000 (x, loss_fn))
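Combining repeat with report gives loss-based early stopping. A sketch, assuming trainer, st, x, and loss_fn from the surrounding examples; the Train.Early_stop qualification is an assumption based on the exception documented above.

```ocaml
(* Log the 1-based step and scalar loss, and stop once the loss
   falls below a threshold; [fit] returns the state at that point. *)
let trained =
  Train.fit trainer st
    ~report:(fun ~step ~loss _st ->
      Printf.printf "step %d: loss %g\n" step loss;
      if loss < 1e-3 then raise Train.Early_stop)
    (Data.repeat 1000 (x, loss_fn))
```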