Module Kaun.Init
Weight initialization strategies.
Initializers map a shape and float dtype to tensors. Random keys are obtained implicitly via Nx.Rng.next_key. Named families (Glorot, He, LeCun) are defined in terms of variance_scaling.
Types
t is the type for initializers.
i.f shape dtype is an initialized tensor for shape and dtype. Random keys are drawn from the implicit RNG scope.
Constant
val zeros : t
zeros is the initializer that fills with 0.0.
val ones : t
ones is the initializer that fills with 1.0.
val constant : float -> t
constant v is the initializer that fills with v.
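The sketch below models the calling convention i.f shape dtype and the three constant initializers. It makes some assumptions. Tensors are flat `float array`s. The `dtype` variant and the `numel` helper are hypothetical stand-ins for Nx's real types. No random key is involved, because constant initializers are deterministic. It is not Kaun's definition of t.

```ocaml
(* Hypothetical stand-in for Nx float dtypes; not the real Nx type. *)
type dtype = Float32 | Float64

(* Simplified model of an initializer: a record whose field [f] maps a
   shape and a dtype to a tensor, flattened here into a float array. *)
type t = { f : int array -> dtype -> float array }

(* Number of elements in a tensor of the given shape (1 for a scalar). *)
let numel shape = Array.fold_left ( * ) 1 shape

(* [constant v] fills every element with [v]; the dtype is ignored because
   a float array has a single precision. *)
let constant v = { f = (fun shape _dtype -> Array.make (numel shape) v) }

let zeros = constant 0.0
let ones = constant 1.0

let () =
  let w = ones.f [| 2; 3 |] Float32 in
  Printf.printf "%d elements, first = %g\n" (Array.length w) w.(0)
```

Defining zeros and ones via constant keeps a single code path for every fill value.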
Random
val uniform : ?scale:float -> unit -> t
uniform ?scale () is the initializer that samples from U(0, scale).
scale defaults to 0.01.
Raises Invalid_argument if scale is negative.
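Here is a minimal sketch of the uniform contract: the 0.01 default, sampling from U(0, scale), and rejecting a negative scale. It assumes an explicit `Random.State.t` in place of Kaun's implicit Nx.Rng scope and returns a flat array of n samples. The `uniform` function here is a hypothetical illustration, not Kaun's implementation.

```ocaml
(* Hypothetical sketch: [n] samples from U(0, scale) using an explicit
   PRNG state instead of Kaun's implicit RNG scope. *)
let uniform ?(scale = 0.01) st n =
  if scale < 0.0 then invalid_arg "uniform: scale must be non-negative";
  (* Random.State.float returns a value between 0 and [scale]. *)
  Array.init n (fun _ -> Random.State.float st scale)

let () =
  let st = Random.State.make [| 42 |] in
  Array.iter (Printf.printf "%.5f ") (uniform st 5);
  print_newline ()
```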
val normal : ?stddev:float -> unit -> t
normal ?stddev () is the initializer that samples from a normal distribution with mean 0 and standard deviation stddev, i.e. N(0, stddev^2).
stddev defaults to 0.01.
Raises Invalid_argument if stddev is negative.
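The following sketch illustrates the normal initializer's contract. It draws standard normal values with the Box-Muller transform, which is an illustrative choice and not necessarily what Nx uses, and then scales them by stddev. As before, the explicit `Random.State.t` and the names `std_normal` and `normal` are hypothetical.

```ocaml
(* Standard normal sample via Box-Muller; u1 is kept strictly positive so
   that [log u1] is finite. *)
let rec std_normal st =
  let u1 = Random.State.float st 1.0 in
  if u1 <= 0.0 then std_normal st
  else
    let u2 = Random.State.float st 1.0 in
    sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

(* Hypothetical sketch: [n] samples from N(0, stddev^2). *)
let normal ?(stddev = 0.01) st n =
  if stddev < 0.0 then invalid_arg "normal: stddev must be non-negative";
  Array.init n (fun _ -> stddev *. std_normal st)

let () =
  let st = Random.State.make [| 42 |] in
  Array.iter (Printf.printf "%.5f ") (normal st 5);
  print_newline ()
```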
Variance Scaling
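mode is the type for fan modes (`Fan_in, `Fan_out, `Fan_avg), and distribution is the type for variance-scaling distribution families (`Normal, `Uniform, `Truncated_normal).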
val variance_scaling :
  scale:float ->
  mode:mode ->
  distribution:distribution ->
  ?in_axis:int ->
  ?out_axis:int ->
  unit ->
  t
variance_scaling ~scale ~mode ~distribution ?in_axis ?out_axis () is the variance-scaling initializer.
in_axis defaults to -2 and out_axis defaults to -1. Negative axes are interpreted from the end.
The target variance is scale / n, where:
- n = fan_in for `Fan_in
- n = fan_out for `Fan_out
- n = (fan_in + fan_out) / 2 for `Fan_avg
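The sketch below shows how the mode selects n and how n then gives the target variance scale / n. It mirrors the polymorphic variants named in the text, but `effective_fan` and `target_variance` are hypothetical helpers. Fan values are passed in directly rather than derived from a shape.

```ocaml
(* Fan modes, mirroring the polymorphic variants named in the text. *)
type mode = [ `Fan_in | `Fan_out | `Fan_avg ]

(* Effective fan [n] used in the denominator of the target variance. *)
let effective_fan (mode : mode) ~fan_in ~fan_out =
  match mode with
  | `Fan_in -> float_of_int fan_in
  | `Fan_out -> float_of_int fan_out
  | `Fan_avg -> float_of_int (fan_in + fan_out) /. 2.0

(* Target variance scale / n. *)
let target_variance ~scale mode ~fan_in ~fan_out =
  scale /. effective_fan mode ~fan_in ~fan_out

let () =
  List.iter
    (fun (name, mode) ->
      Printf.printf "%s: %g\n" name
        (target_variance ~scale:2.0 mode ~fan_in:256 ~fan_out:128))
    [ ("Fan_in", `Fan_in); ("Fan_out", `Fan_out); ("Fan_avg", `Fan_avg) ]
```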
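Distributions are:
- `Normal: N(0, scale / n), i.e. variance scale / n.
- `Uniform: U(-limit, limit) with limit = sqrt (3 * scale / n), which also has variance limit^2 / 3 = scale / n.
- `Truncated_normal: normal samples truncated to [-2, 2] and rescaled so that the variance matches scale / n.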
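Here is a hedged single-sample sketch of the three families, given a target variance scale / n. It makes several illustrative choices that are not necessarily Kaun's:
- an explicit `Random.State.t` instead of the implicit RNG scope;
- Box-Muller for normal draws;
- rejection sampling for the truncation.

It also borrows the correction constant 0.87962566103423978, which is the standard deviation of a standard normal truncated to [-2, 2] and is the constant JAX's `variance_scaling` divides by. `sample`, `std_normal` and `truncated_std_normal` are hypothetical names.

```ocaml
(* Distribution families, mirroring the polymorphic variants named above. *)
type distribution = [ `Normal | `Uniform | `Truncated_normal ]

(* Standard normal sample via Box-Muller; u1 is kept strictly positive. *)
let rec std_normal st =
  let u1 = Random.State.float st 1.0 in
  if u1 <= 0.0 then std_normal st
  else
    let u2 = Random.State.float st 1.0 in
    sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

(* Standard deviation of a standard normal truncated to [-2, 2]. *)
let truncated_std = 0.87962566103423978

(* Rejection sampling: redraw until the sample lies in [-2, 2]. *)
let rec truncated_std_normal st =
  let z = std_normal st in
  if Float.abs z <= 2.0 then z else truncated_std_normal st

(* One sample whose variance is [variance] (i.e. scale / n). *)
let sample (dist : distribution) st ~variance =
  match dist with
  | `Normal -> sqrt variance *. std_normal st
  | `Uniform ->
      (* U(-limit, limit) has variance limit^2 / 3. *)
      let limit = sqrt (3.0 *. variance) in
      Random.State.float st (2.0 *. limit) -. limit
  | `Truncated_normal ->
      (* Divide by the truncated std so the final variance is [variance]. *)
      sqrt variance /. truncated_std *. truncated_std_normal st

let () =
  let st = Random.State.make [| 0 |] in
  List.iter
    (fun (name, d) -> Printf.printf "%s: %f\n" name (sample d st ~variance:0.01))
    [ ("Normal", `Normal); ("Uniform", `Uniform); ("Truncated_normal", `Truncated_normal) ]
```

Without the division by truncated_std, truncation would shrink the variance to roughly 0.77 times the target.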
Raises Invalid_argument if:
- scale is negative.
- in_axis or out_axis is out of bounds for rank > 1.
- the computed fan is non-positive.
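The validation rules and the in_axis/out_axis defaults can be sketched as follows. The text above does not spell out how Kaun derives fan_in and fan_out from a shape. This sketch therefore assumes the convention of JAX's `variance_scaling`: each fan is the size of its axis times the product of all remaining dimensions (the receptive field). For brevity it also rejects shapes of rank below 2, which is an assumption and not documented Kaun behaviour. `resolve_axis`, `compute_fans` and `check_scale` are hypothetical names.

```ocaml
(* Map a possibly negative axis to an index; -1 is the last axis. *)
let resolve_axis rank axis =
  let a = if axis < 0 then rank + axis else axis in
  if a < 0 || a >= rank then
    invalid_arg (Printf.sprintf "axis %d out of bounds for rank %d" axis rank);
  a

(* Assumed JAX-style fans: the receptive field is the product of every
   dimension other than the in and out axes. *)
let compute_fans ?(in_axis = -2) ?(out_axis = -1) shape =
  let rank = Array.length shape in
  if rank < 2 then invalid_arg "compute_fans: sketch requires rank >= 2";
  let i = resolve_axis rank in_axis and o = resolve_axis rank out_axis in
  let receptive = ref 1 in
  Array.iteri (fun k d -> if k <> i && k <> o then receptive := !receptive * d) shape;
  let fan_in = shape.(i) * !receptive and fan_out = shape.(o) * !receptive in
  if fan_in <= 0 || fan_out <= 0 then invalid_arg "compute_fans: non-positive fan";
  (fan_in, fan_out)

let check_scale scale =
  if scale < 0.0 then invalid_arg "variance_scaling: scale must be non-negative"

let () =
  (* A 3x3 convolution kernel with 16 input and 32 output channels. *)
  let fan_in, fan_out = compute_fans [| 3; 3; 16; 32 |] in
  Printf.printf "fan_in = %d, fan_out = %d\n" fan_in fan_out
```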
Glorot/Xavier
val glorot_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t
glorot_uniform ?in_axis ?out_axis () is Glorot/Xavier uniform initialization.
It samples from U(-limit, limit) with limit = sqrt (6 / (fan_in + fan_out)).
This is the Xavier/Glorot scheme of Glorot and Bengio (2010). It is implemented via variance_scaling with fan-average mode.
Raises Invalid_argument under the same conditions as variance_scaling.
val glorot_normal : ?in_axis:int -> ?out_axis:int -> unit -> t
glorot_normal ?in_axis ?out_axis () is Glorot/Xavier normal initialization.
It uses truncated normal sampling with fan-average target variance 2 / (fan_in + fan_out).
This is the Xavier/Glorot family of Glorot and Bengio (2010). It is implemented via variance_scaling.
Raises Invalid_argument under the same conditions as variance_scaling.
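Both Glorot formulas follow from the variance_scaling definitions with mode `Fan_avg and scale = 1. Scale = 1 is the value consistent with the stated formulas. The derivation, with a worked value for a 784 x 256 dense layer, is:

```latex
n = \frac{\mathrm{fan\_in} + \mathrm{fan\_out}}{2}, \qquad
\sigma^2 = \frac{\mathrm{scale}}{n} = \frac{2}{\mathrm{fan\_in} + \mathrm{fan\_out}}, \qquad
\mathrm{limit} = \sqrt{\frac{3\,\mathrm{scale}}{n}} = \sqrt{\frac{6}{\mathrm{fan\_in} + \mathrm{fan\_out}}}

\mathrm{fan\_in} = 784,\ \mathrm{fan\_out} = 256:\quad
\mathrm{limit} = \sqrt{\tfrac{6}{1040}} \approx 0.0760
```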
He/Kaiming
val he_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t
he_uniform ?in_axis ?out_axis () is He/Kaiming uniform initialization.
It samples from U(-limit, limit) with limit = sqrt (6 / fan_in).
This is the Kaiming/He scheme of He et al. (2015), commonly used for ReLU-like activations. It is implemented via variance_scaling in fan-in mode.
Raises Invalid_argument under the same conditions as variance_scaling.
val he_normal : ?in_axis:int -> ?out_axis:int -> unit -> t
he_normal ?in_axis ?out_axis () is He/Kaiming normal initialization.
It uses truncated normal sampling with fan-in target variance 2 / fan_in.
This is the Kaiming/He family of He et al. (2015). It is implemented via variance_scaling.
Raises Invalid_argument under the same conditions as variance_scaling.
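The He closed forms are the fan-in variance-scaling formulas with scale = 2. The small check below compares them numerically. The function names are hypothetical and the code does not call Kaun.

```ocaml
(* Closed forms quoted for He/Kaiming initialization. *)
let he_normal_variance fan_in = 2.0 /. float_of_int fan_in
let he_uniform_limit fan_in = sqrt (6.0 /. float_of_int fan_in)

(* Generic variance-scaling formulas in fan-in mode (n = fan_in). *)
let vs_variance ~scale fan_in = scale /. float_of_int fan_in
let vs_uniform_limit ~scale fan_in = sqrt (3.0 *. scale /. float_of_int fan_in)

let () =
  let fan_in = 512 in
  Printf.printf "he_normal variance: %g (closed form) vs %g (scale = 2)\n"
    (he_normal_variance fan_in) (vs_variance ~scale:2.0 fan_in);
  Printf.printf "he_uniform limit: %g (closed form) vs %g (scale = 2)\n"
    (he_uniform_limit fan_in) (vs_uniform_limit ~scale:2.0 fan_in)
```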
LeCun
val lecun_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t
lecun_uniform ?in_axis ?out_axis () is LeCun uniform initialization.
It samples from U(-limit, limit) with limit = sqrt (3 / fan_in).
This is the LeCun fan-in family (Efficient BackProp, LeCun et al., 1998). It is implemented via variance_scaling.
Raises Invalid_argument under the same conditions as variance_scaling.
val lecun_normal : ?in_axis:int -> ?out_axis:int -> unit -> t
lecun_normal ?in_axis ?out_axis () is LeCun normal initialization.
It uses truncated normal sampling with fan-in target variance 1 / fan_in.
This is the LeCun fan-in family (Efficient BackProp, LeCun et al., 1998). It is implemented via variance_scaling.
Raises Invalid_argument under the same conditions as variance_scaling.
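Together with the Glorot and He entries, the LeCun initializers complete a small table of variance_scaling presets. The (scale, mode, distribution) triples below are inferred from the formulas quoted in this module: scale 1 or 2, with the normal variants using truncated normal sampling. They are not read from Kaun's source, and `presets` and `effective_fan` are hypothetical names.

```ocaml
type mode = [ `Fan_in | `Fan_out | `Fan_avg ]
type distribution = [ `Normal | `Uniform | `Truncated_normal ]

(* (name, scale, mode, distribution) implied by the documented formulas. *)
let presets : (string * float * mode * distribution) list =
  [ ("glorot_uniform", 1.0, `Fan_avg, `Uniform);
    ("glorot_normal", 1.0, `Fan_avg, `Truncated_normal);
    ("he_uniform", 2.0, `Fan_in, `Uniform);
    ("he_normal", 2.0, `Fan_in, `Truncated_normal);
    ("lecun_uniform", 1.0, `Fan_in, `Uniform);
    ("lecun_normal", 1.0, `Fan_in, `Truncated_normal) ]

let effective_fan (mode : mode) ~fan_in ~fan_out =
  match mode with
  | `Fan_in -> float_of_int fan_in
  | `Fan_out -> float_of_int fan_out
  | `Fan_avg -> float_of_int (fan_in + fan_out) /. 2.0

let () =
  let fan_in = 784 and fan_out = 256 in
  List.iter
    (fun (name, scale, mode, dist) ->
      let variance = scale /. effective_fan mode ~fan_in ~fan_out in
      match (dist : distribution) with
      | `Uniform -> Printf.printf "%-15s limit = %.5f\n" name (sqrt (3.0 *. variance))
      | `Normal | `Truncated_normal -> Printf.printf "%-15s variance = %.6f\n" name variance)
    presets
```

LeCun and He share fan-in mode and differ only in scale. He's variance is therefore twice LeCun's, and its uniform limit is sqrt 2 times larger.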