Module Kaun.Loss
Loss functions.
Losses are differentiable through Rune's autodiff and return scalar means. Invalid_argument messages are prefixed with Loss.<function>:.
Classification
cross_entropy logits one_hot_labels is softmax cross-entropy.
logits has shape [...; num_classes] and must be rank >= 1. one_hot_labels must have the same shape.
Uses the log-sum-exp trick for numerical stability.
Raises Invalid_argument if ranks or shapes differ, or if num_classes is not positive.
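The log-sum-exp trick mentioned above can be sketched in plain OCaml on float arrays. This is a reference illustration only: it does not call Kaun, Rune, or Nx, the helper names (log_sum_exp, cross_entropy_row, cross_entropy_mean) are hypothetical, and it assumes the scalar mean is taken over examples.

```ocaml
(* Reference sketch in plain OCaml; does not use Kaun or Nx. *)

(* log (sum_i exp x_i), shifted by the maximum so that exp cannot overflow. *)
let log_sum_exp (xs : float array) : float =
  let m = Array.fold_left max neg_infinity xs in
  let s = Array.fold_left (fun acc x -> acc +. exp (x -. m)) 0.0 xs in
  m +. log s

(* Loss of one example: - sum_c y_c * (x_c - log_sum_exp x). *)
let cross_entropy_row (logits : float array) (one_hot : float array) : float =
  if Array.length logits <> Array.length one_hot then
    invalid_arg "cross_entropy_row: shapes differ";
  let lse = log_sum_exp logits in
  let acc = ref 0.0 in
  Array.iteri (fun c y -> acc := !acc -. (y *. (logits.(c) -. lse))) one_hot;
  !acc

(* Mean over a batch of rows (assumed reduction). *)
let cross_entropy_mean (logits : float array array)
    (labels : float array array) : float =
  let n = Array.length logits in
  if n = 0 || n <> Array.length labels then
    invalid_arg "cross_entropy_mean: empty batch or batch sizes differ";
  let total = ref 0.0 in
  Array.iteri
    (fun i row -> total := !total +. cross_entropy_row row labels.(i))
    logits;
  !total /. float_of_int n
```

Subtracting the maximum before exponentiating keeps every exponent at or below zero, so logits as large as 1000 still give a finite loss.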
cross_entropy_sparse logits class_indices is cross_entropy with integer labels.
class_indices has shape [...], which must equal the shape of logits with its last dimension removed. The class dimension is the last axis of logits.
Raises Invalid_argument if labels are non-integer, ranks mismatch, non-class dimensions differ, or the class dimension is non-positive.
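With an integer label, the one-hot sum collapses to a single term, so the per-example loss is log_sum_exp logits minus the logit of the true class. A minimal plain-OCaml sketch follows. It does not call Kaun, the helper names are hypothetical, and the out-of-range check is an assumption of this sketch rather than documented behavior.

```ocaml
(* Reference sketch in plain OCaml; does not use Kaun or Nx. *)

(* log (sum_i exp x_i), shifted by the maximum for numerical stability. *)
let log_sum_exp (xs : float array) : float =
  let m = Array.fold_left max neg_infinity xs in
  let s = Array.fold_left (fun acc x -> acc +. exp (x -. m)) 0.0 xs in
  m +. log s

(* Loss of one example given the index of its true class. *)
let cross_entropy_sparse_row (logits : float array) (cls : int) : float =
  let n = Array.length logits in
  if n = 0 then invalid_arg "cross_entropy_sparse_row: no classes";
  (* Assumption of this sketch: reject indices outside [0, n). *)
  if cls < 0 || cls >= n then
    invalid_arg "cross_entropy_sparse_row: class index out of range";
  log_sum_exp logits -. logits.(cls)
```

For four equal logits, any valid class index gives a loss of log 4, the same value the one-hot form would produce.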
binary_cross_entropy logits labels is sigmoid binary cross-entropy.
logits are raw (not sigmoid-normalized). labels are typically in [0;1]. Uses log-sigmoid for numerical stability.
Raises Invalid_argument if logits and labels shapes differ.
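One common numerically stable way to evaluate sigmoid binary cross-entropy directly from logits is max(x, 0) - x * y + log(1 + exp(-|x|)). The plain-OCaml sketch below uses that form. Whether Kaun uses exactly this expression is an assumption, and the helper names are hypothetical.

```ocaml
(* Reference sketch in plain OCaml; does not use Kaun or Nx. *)

(* Per-element loss for logit x and label y. The exponent is never
   positive, so exp cannot overflow for large |x|. *)
let bce_with_logits (x : float) (y : float) : float =
  max x 0.0 -. (x *. y) +. log1p (exp (-. abs_float x))

(* Mean over all elements of two equally sized arrays. *)
let binary_cross_entropy_mean (logits : float array) (labels : float array) :
    float =
  let n = Array.length logits in
  if n = 0 || n <> Array.length labels then
    invalid_arg "binary_cross_entropy_mean: empty input or shapes differ";
  let total = ref 0.0 in
  Array.iteri
    (fun i x -> total := !total +. bce_with_logits x labels.(i))
    logits;
  !total /. float_of_int n
```

A naive evaluation through log (1 - sigmoid x) would return infinity for a logit of 1000 with label 0, whereas this form returns 1000.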
Regression
mse predictions targets is mean ((predictions - targets)^2).
Shape compatibility follows Nx broadcasting semantics.
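To illustrate how broadcasting interacts with the mean, the plain-OCaml sketch below broadcasts a targets vector of shape [m] across every row of predictions of shape [n; m]. It assumes NumPy-style alignment of trailing dimensions, does not call Kaun or Nx, and the helper name mse_broadcast_rows is hypothetical.

```ocaml
(* Reference sketch in plain OCaml; does not use Kaun or Nx. *)

(* Mean squared error where targets (shape [m]) is reused for every row
   of predictions (shape [n; m]). The mean is over all n * m elements. *)
let mse_broadcast_rows (predictions : float array array)
    (targets : float array) : float =
  let total = ref 0.0 and count = ref 0 in
  Array.iter
    (fun row ->
      if Array.length row <> Array.length targets then
        invalid_arg "mse_broadcast_rows: trailing dimensions differ";
      Array.iteri
        (fun j p ->
          let d = p -. targets.(j) in
          total := !total +. (d *. d);
          incr count)
        row)
    predictions;
  if !count = 0 then invalid_arg "mse_broadcast_rows: empty input";
  !total /. float_of_int !count
```

With predictions [[1; 2]; [3; 4]] and targets [1; 2], the squared differences are 0, 0, 4, 4, so the result is 2.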