Module Kaun.Layer
Composable neural network layers.
A t pairs parameter/state initialization with a forward computation. Layers compose with compose for heterogeneous pipelines (for example embeddings to dense layers) and with sequential for homogeneous float pipelines.
Types
The type for model variables.
params are trainable variables consumed by Optim. state is non-trainable mutable state updated by forward passes (for example running statistics in batch_norm).
with_params v params is v with replaced trainable parameters.
with_state v state is v with replaced non-trainable state.
make_vars ~params ~state ~dtype builds model variables.
This is mainly useful for layer constructors implemented outside the Layer module.
type ('input, 'output) t = {
  init : 'layout. dtype:(float, 'layout) Nx.dtype -> 'layout vars;
  apply :
    'layout 'in_elt.
    params:Ptree.t ->
    state:Ptree.t ->
    dtype:(float, 'layout) Nx.dtype ->
    training:bool ->
    ?ctx:Context.t ->
    ('input, 'in_elt) Nx.t ->
    ('output, 'layout) Nx.t * Ptree.t;
}

The type for layers.
init creates fresh params and state. apply computes a forward pass and returns updated state. Random operations (weight initialization, dropout) use the implicit RNG scope established by Nx.Rng.run or Nx.Rng.with_key.
The input tensor's dtype witness 'in_elt is independent of the model's float dtype witness 'layout. This allows layers like embedding to accept int32_elt indices while the model parameters use float32_elt. Float-consuming layers (e.g. linear) require the input dtype to match the model dtype exactly and raise Invalid_argument on mismatch.
ctx carries per-call auxiliary data (attention masks, position ids, encoder memory). Most layers ignore it; transformer layers read from it using well-known key names. See Context.
init m ~dtype is m's fresh variables.
Composite layers (compose, sequential) isolate sub-network RNG streams via Nx.Rng.with_key.
val apply :
('a, 'b) t ->
'layout vars ->
training:bool ->
?ctx:Context.t ->
('a, 'in_elt) Nx.t ->
('b, 'layout) Nx.t * 'layout vars

apply m vars ~training ?ctx x is the forward pass of m.
Returns (y, vars') where params vars' = params vars and state vars' is the updated state from the forward pass.
The input tensor's dtype witness 'in_elt is independent of the model's float dtype witness 'layout. For float-consuming layers, the input must have the same dtype as the model; a mismatch raises Invalid_argument.
training controls stochastic/stateful behavior. For example, dropout uses dropout masks only when training = true, and batch_norm updates running statistics only when training = true.
ctx carries per-call auxiliary data such as attention masks. See Context.
Composition
compose left right applies left then right.
Parameters and state are stored as Ptree.t.Dict nodes with keys "left" and "right". The RNG key is split between both layers during initialization and forward pass.
sequential layers applies layers in order.
Parameters and state are stored as Ptree.t.List nodes with one entry per layer. The RNG key is split per layer during initialization and forward pass.
Raises Invalid_argument if runtime parameter/state list lengths do not match layers.
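To see how composition threads data and state, here is a simplified pure-OCaml model of compose: plain functions over float array, with a single int standing in for the Ptree.t state. This is a sketch of the pattern, not the Kaun t type or its "left"/"right" Ptree layout.

```ocaml
(* A toy layer: forward takes an input and the current state,
   and returns the output together with the updated state.
   This models the shape of Kaun's apply, not its actual API. *)
type toy = { forward : float array -> int -> float array * int }

(* Scale every element; bump a step counter held in state. *)
let scale k = { forward = (fun x steps -> (Array.map (( *. ) k) x, steps + 1)) }

(* compose left right: run left, feed its output to right,
   threading the state through both, as Kaun's compose does. *)
let compose left right =
  { forward =
      (fun x steps ->
        let y, steps = left.forward x steps in
        right.forward y steps) }

let model = compose (scale 2.0) (scale 3.0)
let y, steps = model.forward [| 1.0; 2.0 |] 0
(* y = [| 6.0; 12.0 |], steps = 2 *)
```

The key property mirrored here is that left's output type must match right's input type, which is what the ('a, 'b) t -> ('b, 'c) t -> ('a, 'c) t shape of compose enforces statically.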
Dense
val linear :
in_features:int ->
out_features:int ->
?weight_init:Init.t ->
?bias_init:Init.t ->
unit ->
(float, float) t

linear ~in_features ~out_features ?weight_init ?bias_init () is the fully connected map xW + b.
weight_init defaults to Init.glorot_uniform (). bias_init defaults to Init.zeros.
Parameters:
- weight with shape [in_features; out_features].
- bias with shape [out_features].
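Because weight has shape [in_features; out_features], the input multiplies the weight on the left (xW + b). A pure-OCaml sketch of that shape convention on a single input row (not the Kaun implementation, which operates on Nx tensors):

```ocaml
(* y.(j) = sum_i x.(i) *. w.(i).(j) +. b.(j), with
   w : in_features x out_features and b : out_features. *)
let linear_row w b x =
  let out_features = Array.length b in
  let in_features = Array.length x in
  Array.init out_features (fun j ->
      let acc = ref b.(j) in
      for i = 0 to in_features - 1 do
        acc := !acc +. (x.(i) *. w.(i).(j))
      done;
      !acc)

(* in_features = 2, out_features = 3. *)
let w = [| [| 1.0; 0.0; 1.0 |]; [| 0.0; 1.0; 1.0 |] |]
let b = [| 0.5; 0.5; 0.5 |]
let y = linear_row w b [| 2.0; 3.0 |]
(* y = [| 2.5; 3.5; 5.5 |] *)
```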
Convolution
val conv1d :
in_channels:int ->
out_channels:int ->
?kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding:[ `Same | `Valid | `Causal ] ->
unit ->
(float, float) t

conv1d ~in_channels ~out_channels ?kernel_size ?stride ?dilation ?padding () is 1D convolution over inputs shaped [batch; in_channels; length].
kernel_size defaults to 3. stride defaults to 1. dilation defaults to 1. padding defaults to `Same.
Parameters:
- weight with shape [out_channels; in_channels; kernel_size].
- bias with shape [out_channels].
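The padding mode determines the output length. The sketch below uses the usual conventions (`Same and `Causal preserve ceil(length / stride); `Valid drops positions where the dilated kernel would overrun); the exact formulas are an assumption about Kaun's behavior, not taken from this documentation:

```ocaml
(* Extent of the kernel once dilation spreads its taps out. *)
let effective_kernel ~kernel_size ~dilation = dilation * (kernel_size - 1) + 1

let output_length ~length ~kernel_size ~stride ~dilation padding =
  match padding with
  | `Same | `Causal ->
      (* Padded so every stride-th input position yields an output;
         `Causal pads only on the left, which preserves the same count. *)
      (length + stride - 1) / stride
  | `Valid ->
      let k = effective_kernel ~kernel_size ~dilation in
      if length < k then 0 else (length - k) / stride + 1

let () =
  (* With the defaults (kernel_size = 3, stride = 1, dilation = 1): *)
  assert (output_length ~length:10 ~kernel_size:3 ~stride:1 ~dilation:1 `Same = 10);
  assert (output_length ~length:10 ~kernel_size:3 ~stride:1 ~dilation:1 `Valid = 8)
```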
val conv2d :
in_channels:int ->
out_channels:int ->
?kernel_size:(int * int) ->
unit ->
(float, float) t

conv2d ~in_channels ~out_channels ?kernel_size () is 2D convolution over inputs shaped [batch; in_channels; height; width].
kernel_size defaults to (3, 3). Stride is fixed to (1, 1) and padding mode is `Same.
Parameters:
- weight with shape [out_channels; in_channels; kh; kw].
- bias with shape [out_channels].
Normalization
val layer_norm : dim:int -> ?eps:float -> unit -> (float, float) t

layer_norm ~dim ?eps () is layer normalization with learnable affine parameters.
eps defaults to 1e-5.
Parameters:
- gamma with shape [dim].
- beta with shape [dim].
val rms_norm : dim:int -> ?eps:float -> unit -> (float, float) t

rms_norm ~dim ?eps () is RMS normalization with learnable scale.
eps defaults to 1e-6.
Parameters:
- scale with shape [dim].
val batch_norm : num_features:int -> unit -> (float, float) t

batch_norm ~num_features () is stateful batch normalization.
During training, batch statistics are used and running statistics are updated. During evaluation, running statistics are used and preserved.
Normalization axes are inferred from rank:
- rank 2 uses [0].
- rank 3 uses [0; 2].
- rank 4 uses [0; 2; 3].
- other ranks use [0].
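The rank-to-axes rule can be written down directly; this is a restatement of the documented inference, not Kaun's code:

```ocaml
(* Axes reduced over when computing batch statistics: everything
   except the feature axis (axis 1), following the rule above. *)
let norm_axes ~rank =
  match rank with
  | 3 -> [ 0; 2 ]
  | 4 -> [ 0; 2; 3 ]
  | _ -> [ 0 ]  (* rank 2 and all other ranks *)

let () =
  assert (norm_axes ~rank:2 = [ 0 ]);
  assert (norm_axes ~rank:4 = [ 0; 2; 3 ])
```

The common case is rank 4 ([batch; channels; height; width]), where statistics are taken over batch and both spatial axes, leaving one mean and variance per channel.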
Parameters:
- scale with shape [num_features].
- bias with shape [num_features].
State:
- running_mean with shape [num_features].
- running_var with shape [num_features].
Embedding
val embedding :
vocab_size:int ->
embed_dim:int ->
?scale:bool ->
unit ->
(int32, float) t

embedding ~vocab_size ~embed_dim ?scale () is an embedding lookup layer.
Inputs are int32 token indices. Output shape is indices_shape ++ [embed_dim].
scale defaults to true. When true, output vectors are multiplied by sqrt embed_dim.
Parameters:
- embedding with shape [vocab_size; embed_dim].
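The output-shape rule indices_shape ++ [embed_dim] and the optional sqrt embed_dim scaling can be checked as plain arithmetic; a restatement of the documented behavior, not the implementation:

```ocaml
(* Looking up [batch; seq] int32 indices in a [vocab_size; embed_dim]
   table yields a [batch; seq; embed_dim] float tensor. *)
let embedding_output_shape ~indices_shape ~embed_dim =
  indices_shape @ [ embed_dim ]

(* With scale = true (the default), each output vector is
   multiplied by sqrt embed_dim. *)
let embedding_scale ~embed_dim = sqrt (float_of_int embed_dim)

let () =
  assert (embedding_output_shape ~indices_shape:[ 4; 16 ] ~embed_dim:64
          = [ 4; 16; 64 ]);
  assert (embedding_scale ~embed_dim:64 = 8.0)
```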
Regularization
val dropout : rate:float -> unit -> (float, float) t

dropout ~rate () is elementwise dropout.
When training = false, it is identity. When training = true, dropout masks are generated using keys from the implicit RNG scope.
Raises Invalid_argument unless 0.0 <= rate < 1.0.
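A pure-OCaml sketch of elementwise dropout over a float array, assuming the common inverted-dropout convention (surviving activations are rescaled by 1 /. (1 -. rate) so the expected value is unchanged); whether Kaun rescales this way is an assumption, and Kaun draws its masks from the implicit Nx.Rng scope rather than Stdlib.Random:

```ocaml
let dropout ~rate ~training x =
  if not (rate >= 0.0 && rate < 1.0) then
    invalid_arg "dropout: rate must satisfy 0.0 <= rate < 1.0";
  if not training then x (* identity at evaluation time *)
  else
    let keep = 1.0 -. rate in
    (* Drop each element independently with probability rate;
       rescale survivors so the expectation matches eval mode. *)
    Array.map (fun v -> if Random.float 1.0 >= rate then v /. keep else 0.0) x

let () =
  (* Evaluation mode is exactly the identity. *)
  assert (dropout ~rate:0.5 ~training:false [| 1.0; 2.0 |] = [| 1.0; 2.0 |]);
  (* rate = 0.0 keeps every element unchanged. *)
  assert (dropout ~rate:0.0 ~training:true [| 1.0; 2.0 |] = [| 1.0; 2.0 |])
```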
Activation Layers
val relu : unit -> (float, float) t

relu () is max(0, x). No parameters.

val gelu : unit -> (float, float) t

gelu () is the Gaussian error linear unit. No parameters.

val silu : unit -> (float, float) t

silu () is x * sigmoid(x). No parameters.

val tanh : unit -> (float, float) t

tanh () is hyperbolic tangent. No parameters.

val sigmoid : unit -> (float, float) t

sigmoid () is the logistic function. No parameters.
Pooling
val max_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
unit ->
(float, float) t

max_pool2d ~kernel_size ?stride () is 2D max pooling.
stride defaults to kernel_size. No parameters.
val avg_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
unit ->
(float, float) t

avg_pool2d ~kernel_size ?stride () is 2D average pooling.
stride defaults to kernel_size. No parameters.
Reshape
val flatten : unit -> (float, float) t

flatten () flattens all dimensions after the batch dimension.
[batch; d1; ...; dn] becomes [batch; d1 * ... * dn].
Raises Invalid_argument if the input rank is 0.
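The shape rule can be checked with plain list arithmetic; a restatement of the documented behavior, where the rank-1 case ([batch] becoming [batch; 1], the empty product) is an assumption:

```ocaml
(* [batch; d1; ...; dn] becomes [batch; d1 * ... * dn];
   rank 0 is rejected, matching the documented Invalid_argument. *)
let flatten_shape = function
  | [] -> invalid_arg "flatten: input rank must be at least 1"
  | batch :: rest -> [ batch; List.fold_left ( * ) 1 rest ]

let () =
  assert (flatten_shape [ 32; 3; 28; 28 ] = [ 32; 2352 ]);
  (* Assumed: a rank-1 input keeps its batch axis and gains
     a size-1 feature axis (product over no dimensions is 1). *)
  assert (flatten_shape [ 32 ] = [ 32; 1 ])
```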