Module Kaun.Metric
Training metrics.
Metric provides running scalar tracking and dataset evaluation.
A tracker accumulates named running means during training. For dataset evaluation, eval and eval_many fold user-supplied functions over a Data.t pipeline and return averaged results.
Metric functions such as accuracy are plain tensor-to-scalar functions that compose freely with eval.
Running Tracker
val tracker : unit -> tracker
tracker () is a fresh tracker with no observations.
val observe : tracker -> string -> float -> unit
observe t name value records value as one observation under name.
val mean : tracker -> string -> float
mean t name is the running mean of observations under name.
Raises Not_found if name was never observed.
val count : tracker -> string -> int
count t name is the number of observations under name.
Raises Not_found if name was never observed.
val reset : tracker -> unit
reset t clears all observations.
val to_list : tracker -> (string * float) list
to_list t is the current means as (name, mean) pairs, sorted by name.
val summary : tracker -> string
summary t is a human-readable one-line string of all current means, e.g. "accuracy: 0.9150 loss: 0.4231".
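The running-mean bookkeeping described above can be modelled with the standard library alone. The following is a minimal sketch of the documented semantics, not Kaun's implementation: it assumes each name keeps a running sum and count in a Hashtbl, and the entry record is hypothetical. The summary format imitates the example string above.

```ocaml
(* Sketch of the tracker semantics using only Stdlib; NOT Kaun's code.
   The [entry] record is a hypothetical internal representation. *)
type entry = { mutable sum : float; mutable n : int }

type tracker = (string, entry) Hashtbl.t

let tracker () : tracker = Hashtbl.create 8

(* Add one observation: update the running sum and count for [name]. *)
let observe (t : tracker) name value =
  match Hashtbl.find_opt t name with
  | Some e ->
      e.sum <- e.sum +. value;
      e.n <- e.n + 1
  | None -> Hashtbl.add t name { sum = value; n = 1 }

(* Hashtbl.find raises Not_found for unknown names, as documented. *)
let mean (t : tracker) name =
  let e = Hashtbl.find t name in
  e.sum /. float_of_int e.n

let count (t : tracker) name = (Hashtbl.find t name).n

let reset (t : tracker) = Hashtbl.reset t

(* Current means as (name, mean) pairs, sorted by name. *)
let to_list (t : tracker) =
  Hashtbl.fold (fun name e acc -> (name, e.sum /. float_of_int e.n) :: acc) t []
  |> List.sort (fun (a, _) (b, _) -> String.compare a b)

(* One-line summary with four decimals per mean. *)
let summary (t : tracker) =
  to_list t
  |> List.map (fun (name, m) -> Printf.sprintf "%s: %.4f" name m)
  |> String.concat " "

let () =
  let t = tracker () in
  observe t "loss" 0.5;
  observe t "loss" 0.3;
  observe t "accuracy" 0.9;
  print_endline (summary t) (* prints "accuracy: 0.9000 loss: 0.4000" *)
```

Storing a sum and a count rather than the mean itself keeps observe constant-time and makes count available for free.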
Dataset Evaluation
val eval : ('a -> float) -> 'a Data.t -> float
eval f data is the mean of f batch over all elements of data.
Raises Invalid_argument if data yields no elements.
val eval_many :
  ('a -> (string * float) list) ->
  'a Data.t ->
  (string * float) list
eval_many f data is the per-name mean of f batch over all elements of data. Returns (name, mean) pairs sorted by name.
Raises Invalid_argument if data yields no elements.
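To make the evaluation semantics concrete, here is a minimal sketch in which a plain OCaml list of batches stands in for Data.t (an assumption; the real pipeline type is not modelled). batch_accuracy is a hypothetical batch-level metric, included only to show how a plain function composes with eval. For eval_many, the sketch averages each name over the batches that report it; how Kaun handles a name missing from some batches is not stated above.

```ocaml
(* Sketch of eval / eval_many; a list of batches stands in for Data.t.
   NOT Kaun's implementation. *)

(* Unweighted mean of [f batch] over all batches. *)
let eval (f : 'a -> float) (batches : 'a list) : float =
  match batches with
  | [] -> invalid_arg "eval: dataset yielded no elements"
  | _ ->
      let sum, n =
        List.fold_left (fun (s, n) b -> (s +. f b, n + 1)) (0.0, 0) batches
      in
      sum /. float_of_int n

(* Per-name mean; each name is averaged over the batches that report it. *)
let eval_many (f : 'a -> (string * float) list) (batches : 'a list) :
    (string * float) list =
  match batches with
  | [] -> invalid_arg "eval_many: dataset yielded no elements"
  | _ ->
      let acc : (string, float * int) Hashtbl.t = Hashtbl.create 8 in
      List.iter
        (fun b ->
          List.iter
            (fun (name, v) ->
              let s, n =
                Option.value (Hashtbl.find_opt acc name) ~default:(0.0, 0)
              in
              Hashtbl.replace acc name (s +. v, n + 1))
            (f b))
        batches;
      Hashtbl.fold (fun name (s, n) l -> (name, s /. float_of_int n) :: l) acc []
      |> List.sort (fun (a, _) (b, _) -> String.compare a b)

(* Hypothetical batch metric: fraction of positions where labels agree. *)
let batch_accuracy ((preds : int array), (targets : int array)) : float =
  let correct = ref 0 in
  Array.iteri (fun i p -> if p = targets.(i) then incr correct) preds;
  float_of_int !correct /. float_of_int (Array.length preds)

let () =
  let data = [ ([| 1; 0; 1; 1 |], [| 1; 0; 0; 1 |]); ([| 0; 0 |], [| 0; 0 |]) ] in
  Printf.printf "%.3f\n" (eval batch_accuracy data) (* prints 0.875 *)
```

Read literally, "mean over all elements of data" is a mean over batches, so a short final batch carries the same weight as a full one: the example yields 0.875 (the mean of 0.75 and 1.0), whereas the per-example fraction is 5/6, about 0.833.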
Averaging
Common Metric Functions
accuracy predictions targets is the fraction of correct predictions.
Multi-class: predictions has shape [batch; num_classes] (logits or probabilities) and targets has shape [batch] (integer class indices). The predicted class is the argmax along the last axis.
Binary: both tensors have shape [batch] or [batch; 1]. Predictions above 0.5 count as class 1.
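The multi-class rule can be sketched on plain float arrays in place of tensors. The names argmax and multiclass_accuracy are hypothetical, and the first-maximum tie-breaking is this sketch's own choice; how Kaun breaks ties is not documented here.

```ocaml
(* Sketch of multi-class accuracy on arrays instead of tensors.
   Function names are hypothetical, not Kaun's API. *)

(* Index of the largest entry in a row; the first maximum wins on ties. *)
let argmax (row : float array) : int =
  let best = ref 0 in
  Array.iteri (fun i x -> if x > row.(!best) then best := i) row;
  !best

(* predictions: [batch][num_classes] scores; targets: class indices. *)
let multiclass_accuracy (predictions : float array array) (targets : int array) :
    float =
  let correct = ref 0 in
  Array.iteri
    (fun i row -> if argmax row = targets.(i) then incr correct)
    predictions;
  float_of_int !correct /. float_of_int (Array.length targets)

let () =
  let logits =
    [| [| 0.1; 2.0; 0.3 |]; [| 1.5; 0.2; 0.1 |]; [| 0.0; 0.1; 0.9 |]; [| 0.3; 0.2; 0.1 |] |]
  in
  let targets = [| 1; 0; 2; 1 |] in
  (* argmax per row is 1, 0, 2, 0; only the last row is wrong *)
  Printf.printf "%.2f\n" (multiclass_accuracy logits targets) (* prints 0.75 *)
```

Because softmax is monotonic, the argmax of the logits and the argmax of the probabilities select the same class, which is why either form is accepted.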
binary_accuracy ?threshold predictions targets is the fraction of correct binary predictions.
threshold defaults to 0.5. Predictions strictly above threshold count as class 1; targets are expected to be 0 or 1.
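The threshold rule can be sketched on float arrays as follows. It mirrors the documented behaviour (a prediction strictly above threshold is class 1) under the assumption that targets hold exactly 0.0 or 1.0; the empty/length check and its message are additions of this sketch, not documented Kaun behaviour.

```ocaml
(* Sketch of binary_accuracy's rule on float arrays; NOT Kaun's code.
   Assumes targets are exactly 0.0 or 1.0. *)
let binary_accuracy ?(threshold = 0.5) (predictions : float array)
    (targets : float array) : float =
  let n = Array.length predictions in
  if n = 0 || n <> Array.length targets then
    invalid_arg "binary_accuracy: empty or mismatched arrays";
  let correct = ref 0 in
  for i = 0 to n - 1 do
    (* strictly above the threshold means class 1 *)
    let predicted = if predictions.(i) > threshold then 1.0 else 0.0 in
    if predicted = targets.(i) then incr correct
  done;
  float_of_int !correct /. float_of_int n

let () =
  let preds = [| 0.9; 0.4; 0.6; 0.5 |] and targets = [| 1.0; 0.0; 0.0; 1.0 |] in
  (* default 0.5: 0.5 itself is not above the threshold, so it counts as 0 *)
  Printf.printf "%.2f\n" (binary_accuracy preds targets); (* prints 0.50 *)
  Printf.printf "%.2f\n" (binary_accuracy ~threshold:0.7 preds targets) (* prints 0.75 *)
```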
Classification
precision avg predictions targets is the precision score.
predictions has shape [batch; num_classes] (logits or probabilities). targets has shape [batch] (integer class indices). Predicted class is argmax along the last axis.
When a class has no predicted instances, its precision is 0.0.
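A sketch of per-class precision with the zero-prediction rule follows. For brevity it takes predicted class indices (the argmax of each row) rather than logits. macro_average is a hypothetical helper showing one common way to combine per-class scores; it is not a claim about which modes avg supports.

```ocaml
(* Sketch of per-class precision; NOT Kaun's implementation.
   [predicted] holds argmax class indices, [targets] true class indices. *)
let per_class_precision ~num_classes (predicted : int array) (targets : int array) =
  let tp = Array.make num_classes 0 and pred_count = Array.make num_classes 0 in
  Array.iteri
    (fun i p ->
      pred_count.(p) <- pred_count.(p) + 1;
      if p = targets.(i) then tp.(p) <- tp.(p) + 1)
    predicted;
  (* a class that is never predicted gets precision 0.0 *)
  Array.init num_classes (fun c ->
      if pred_count.(c) = 0 then 0.0
      else float_of_int tp.(c) /. float_of_int pred_count.(c))

(* Hypothetical macro average: unweighted mean over classes. *)
let macro_average scores =
  Array.fold_left ( +. ) 0.0 scores /. float_of_int (Array.length scores)

let () =
  let predicted = [| 0; 0; 1; 1; 0 |] and targets = [| 0; 1; 1; 2; 0 |] in
  let p = per_class_precision ~num_classes:3 predicted targets in
  (* class 0: 2/3, class 1: 1/2, class 2: never predicted, so 0.0 *)
  Array.iteri (fun c v -> Printf.printf "class %d: %.3f\n" c v) p;
  Printf.printf "macro: %.3f\n" (macro_average p) (* prints macro: 0.389 *)
```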
recall avg predictions targets is the recall score.
The input convention is the same as for precision.
When a class has no true instances, its recall is 0.0.
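Recall can be sketched the same way, counting true instances per class instead of predicted ones. As before, the inputs are already-argmaxed class indices and the function name is hypothetical. With num_classes = 4, class 3 has no true instances and therefore scores 0.0 under the rule above.

```ocaml
(* Sketch of per-class recall; NOT Kaun's implementation.
   [predicted] holds argmax class indices, [targets] true class indices. *)
let per_class_recall ~num_classes (predicted : int array) (targets : int array) =
  let tp = Array.make num_classes 0 and true_count = Array.make num_classes 0 in
  Array.iteri
    (fun i t ->
      true_count.(t) <- true_count.(t) + 1;
      if predicted.(i) = t then tp.(t) <- tp.(t) + 1)
    targets;
  (* a class with no true instances gets recall 0.0 *)
  Array.init num_classes (fun c ->
      if true_count.(c) = 0 then 0.0
      else float_of_int tp.(c) /. float_of_int true_count.(c))

let () =
  let predicted = [| 0; 0; 1; 1; 0 |] and targets = [| 0; 1; 1; 2; 0 |] in
  let r = per_class_recall ~num_classes:4 predicted targets in
  (* class 0: 2/2, class 1: 1/2, class 2: 0/1, class 3: no true instances *)
  Array.iteri (fun c v -> Printf.printf "class %d: %.2f\n" c v) r
```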