Module Kaun.Optim
Optimizers and learning-rate schedules.
An algorithm combines a learning-rate schedule and an update rule. state stores optimizer-specific accumulators and the step count.
Types
Core
step algo state params grads is (updates, state') where updates are additive parameter deltas.
The step count is incremented before the learning-rate schedule is evaluated. Use apply_updates to apply updates to params.
apply_updates params updates is params + updates element-wise.
update algo st params grads is let u, st' = step algo st params grads in (apply_updates params u, st').
Convenience for the common case where you want updated parameters directly rather than additive deltas.
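The relationship between the two entry points can be sketched as follows; this is a hedged illustration using only the functions documented on this page, with `algo`, `st`, `params`, and `grads` assumed to be in scope:

```ocaml
open Kaun

(* [update] is [step] followed by [apply_updates]: *)
let step_then_apply algo st params grads =
  let updates, st' = Optim.step algo st params grads in
  (Optim.apply_updates params updates, st')

(* ... which should behave the same as the one-call form: *)
let via_update algo st params grads =
  Optim.update algo st params grads
```

Use `step` directly when you need to inspect or transform the raw deltas (e.g. for logging or gradient surgery) before applying them; otherwise `update` is the simpler call.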
Learning-Rate Schedules
module Schedule : sig ... end

Optimizers
val sgd :
lr:Schedule.t ->
?momentum:float ->
?nesterov:bool ->
unit ->
algorithm

sgd ~lr ?momentum ?nesterov () is stochastic gradient descent.
momentum defaults to 0. (no momentum); nesterov defaults to false. Nesterov mode is ignored when momentum = 0..
Raises Invalid_argument if momentum is not in 0.0 <= momentum < 1.0.
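A minimal construction sketch; note that `Schedule.constant` is an assumed schedule constructor, not confirmed by this page:

```ocaml
(* SGD with Nesterov momentum. [Schedule.constant] is hypothetical;
   substitute whatever constructors [Kaun.Optim.Schedule] provides. *)
let algo =
  Kaun.Optim.sgd
    ~lr:(Kaun.Optim.Schedule.constant 0.1)
    ~momentum:0.9
    ~nesterov:true
    ()
```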
val adam :
lr:Schedule.t ->
?b1:float ->
?b2:float ->
?eps:float ->
unit ->
algorithm

adam ~lr ?b1 ?b2 ?eps () is Adam with bias correction.
b1 defaults to 0.9. b2 defaults to 0.999. eps defaults to 1e-8.
Raises Invalid_argument if b1 or b2 is not in 0.0 <= b < 1.0, or if eps <= 0.0.
val adamw :
lr:Schedule.t ->
?b1:float ->
?b2:float ->
?eps:float ->
?weight_decay:float ->
unit ->
algorithm

adamw ~lr ?b1 ?b2 ?eps ?weight_decay () is AdamW.
b1 defaults to 0.9. b2 defaults to 0.999. eps defaults to 1e-8. weight_decay defaults to 0.01.
Weight decay is decoupled from the Adam moment estimates.
Raises Invalid_argument if b1 or b2 is not in 0.0 <= b < 1.0, if eps <= 0.0, or if weight_decay < 0.0.
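A construction sketch with the defaults written out explicitly; as above, `Schedule.constant` is an assumed constructor:

```ocaml
(* AdamW with its documented defaults made explicit. Decoupled weight
   decay means the decay term acts on the parameters directly rather
   than entering the first/second moment estimates. *)
let adamw_algo =
  Kaun.Optim.adamw
    ~lr:(Kaun.Optim.Schedule.constant 1e-3)
    ~b1:0.9 ~b2:0.999 ~eps:1e-8
    ~weight_decay:0.01
    ()
```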
val rmsprop :
lr:Schedule.t ->
?decay:float ->
?eps:float ->
?momentum:float ->
unit ->
algorithm

rmsprop ~lr ?decay ?eps ?momentum () is RMSprop.
decay defaults to 0.9. eps defaults to 1e-8. momentum defaults to 0. (no momentum).
Raises Invalid_argument if decay or momentum is not in 0.0 <= x < 1.0, or if eps <= 0.0.
val adagrad : lr:Schedule.t -> ?eps:float -> unit -> algorithm

adagrad ~lr ?eps () is Adagrad.
eps defaults to 1e-8.
Raises Invalid_argument if eps <= 0.0.
Serialization
state_to_trees st is (count, trees) where count is the optimizer step count and trees are the internal state as parameter trees.
SGD without momentum returns an empty list. Adam returns [mu; nu].
state_of_trees algo ~count trees reconstructs optimizer state from an algorithm, step count, and serialized trees.
Raises Invalid_argument if the number of trees does not match the algorithm's expectation.
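A round-trip sketch for checkpointing, using only the two functions documented here; how `count` and the trees are persisted to disk is out of scope and left elided:

```ocaml
(* Serialize optimizer state to parameter trees and rebuild it.
   In a real checkpointing path you would write [count] and [trees]
   to storage between the two calls. *)
let roundtrip algo st =
  let count, trees = Kaun.Optim.state_to_trees st in
  Kaun.Optim.state_of_trees algo ~count trees
```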
Gradient Utilities
clip_by_global_norm max_norm grads rescales grads so their global L2 norm does not exceed max_norm. Returns grads unchanged if the norm is already within bounds.
Raises Invalid_argument if a leaf tensor is not floating point.
val global_norm : Ptree.t -> float

global_norm t is the L2 norm across all leaf tensors of t.
Raises Invalid_argument if a leaf tensor is not floating point.
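Gradient clipping typically happens between gradient computation and the optimizer step; a hedged sketch combining the utilities on this page:

```ocaml
(* Clip gradients to a global L2 norm of 1.0 before applying the
   optimizer. [clip_by_global_norm] returns the gradients unchanged
   when their norm is already within the bound. *)
let clipped_update algo st params grads =
  let grads = Kaun.Optim.clip_by_global_norm 1.0 grads in
  Kaun.Optim.update algo st params grads
```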