Module Kaun.Data
Lazy, composable data pipelines for training.
An 'a t is a resettable iterator over elements of type 'a. Pipelines are built by composing constructors, transformers, and consumers.
Data.of_array examples
|> Data.shuffle
|> Data.map_batch 32 collate
|> Data.iter train_step
Types
type 'a t
Constructors
val of_array : 'a array -> 'a t
of_array a is a pipeline yielding the elements of a in order.
of_tensor t is a pipeline yielding slices of t along its first dimension. Each element has shape t.shape[1:], that is, the shape of t without its first dimension.
of_tensors (x, y) is a pipeline yielding pairs of corresponding slices of x and y along their first dimension.
Raises Invalid_argument if x and y differ in size along the first dimension.
val of_fn : int -> (int -> 'a) -> 'a t
of_fn n f is a pipeline yielding f 0, f 1, ..., f (n - 1).
Raises Invalid_argument if n < 0.
val repeat : int -> 'a -> 'a t
repeat n v is a pipeline that yields v exactly n times.
Raises Invalid_argument if n < 0.
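The index-based constructors can be pinned down with plain arrays. The sketch below is a model of the documented behaviour only, assuming Stdlib arrays as a stand-in for pipelines; of_fn_model and repeat_model are hypothetical helpers, not part of Kaun.Data.

```ocaml
(* Array-based model of the documented semantics of of_fn and repeat.
   of_fn_model and repeat_model are hypothetical helpers standing in
   for pipelines; they are not Kaun's implementation. *)

(* of_fn n f yields f 0, f 1, ..., f (n - 1). *)
let of_fn_model n f =
  if n < 0 then invalid_arg "of_fn_model: n < 0";
  Array.init n f

(* repeat n v yields v exactly n times, i.e. of_fn n (fun _ -> v). *)
let repeat_model n v =
  if n < 0 then invalid_arg "repeat_model: n < 0";
  of_fn_model n (fun _ -> v)

let () =
  assert (of_fn_model 4 (fun i -> i * i) = [| 0; 1; 4; 9 |]);
  assert (repeat_model 3 "x" = [| "x"; "x"; "x" |])
```

Expressing repeat through of_fn makes the shared n < 0 check and the element order explicit.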
Transformers
batch ?drop_last n t is a pipeline yielding arrays of n consecutive elements from t.
drop_last defaults to false, in which case the final batch may contain fewer than n elements. When true, that incomplete final batch is dropped.
Raises Invalid_argument if n <= 0.
map_batch ?drop_last n f t is map f (batch ?drop_last n t).
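The effect of drop_last on the final batch, and the identity map_batch ?drop_last n f t = map f (batch ?drop_last n t), can be checked on arrays. This is a minimal sketch of the documented semantics, assuming an array stands in for the input pipeline; batch_model and map_batch_model are hypothetical names.

```ocaml
(* Hypothetical array-based model of batch ?drop_last n t and
   map_batch ?drop_last n f t; not Kaun's implementation. *)
let batch_model ?(drop_last = false) n (xs : 'a array) : 'a array array =
  if n <= 0 then invalid_arg "batch_model: n <= 0";
  let len = Array.length xs in
  (* Full batches, plus one partial batch if elements remain
     and drop_last is false. *)
  let full = len / n in
  let has_partial = len mod n <> 0 && not drop_last in
  let count = full + (if has_partial then 1 else 0) in
  Array.init count (fun b ->
      let start = b * n in
      Array.sub xs start (min n (len - start)))

(* map_batch is map f applied to the batched pipeline. *)
let map_batch_model ?drop_last n f xs =
  Array.map f (batch_model ?drop_last n xs)
```

With five elements and n = 2, the default keeps a final one-element batch, while ~drop_last:true yields only the two full batches.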
shuffle t is a pipeline that yields the elements of t in a random order. The permutation is computed once when the pipeline is created.
Random keys are drawn from the implicit RNG scope.
Raises Invalid_argument if the length of t is unknown.
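Because the permutation is fixed at creation, every pass over a shuffled pipeline replays the same order. The sketch below models that with a Fisher-Yates permutation drawn once. It swaps Kaun's implicit RNG scope for an explicit Stdlib Random.State, and permutation and shuffle_model are hypothetical names; an array always has a known length, which is why the model needs no unknown-length check.

```ocaml
(* Fisher-Yates shuffle of the indices 0 .. n-1, using an explicit
   Stdlib RNG state (Kaun instead uses its implicit RNG scope). *)
let permutation rng n =
  let p = Array.init n (fun i -> i) in
  for i = n - 1 downto 1 do
    let j = Random.State.int rng (i + 1) in
    let tmp = p.(i) in
    p.(i) <- p.(j);
    p.(j) <- tmp
  done;
  p

(* Hypothetical model: "creating" the shuffled pipeline fixes the
   permutation once; the returned closure (type unit -> 'a array)
   yields the elements in that same order on every pass. *)
let shuffle_model rng xs =
  let p = permutation rng (Array.length xs) in
  fun () -> Array.map (fun i -> xs.(i)) p
```

Drawing a fresh order for each epoch would therefore require creating a new shuffled pipeline rather than only resetting the existing one.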
Consumers
val iter : ('a -> unit) -> 'a t -> unit
iter f t applies f to each element of t.
val iteri : (int -> 'a -> unit) -> 'a t -> unit
iteri f t applies f i x to each element x of t, where i is the element's 0-based index.
val fold : ('acc -> 'a -> 'acc) -> 'acc -> 'a t -> 'acc
fold f init t folds f over the elements of t from first to last, starting from init.
val to_array : 'a t -> 'a array
to_array t collects all elements of t into an array.
val to_seq : 'a t -> 'a Stdlib.Seq.t
to_seq t is a standard Seq.t view of t. It does not reset t.
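A common use of fold is accumulating a metric over a pipeline. The sketch below uses a Stdlib Seq.t in place of a pipeline, assuming Kaun's fold and iteri visit elements in the same order as to_seq; mean_of and indexed are hypothetical helpers, and Seq.iteri requires OCaml 4.14 or later.

```ocaml
(* Hypothetical sketch: a Stdlib Seq.t stands in for a pipeline. *)

(* Mean of a float sequence via a left fold over (sum, count);
   returns nan for an empty sequence. *)
let mean_of (s : float Seq.t) : float =
  let sum, count =
    Seq.fold_left (fun (sum, n) x -> (sum +. x, n + 1)) (0.0, 0) s
  in
  if count = 0 then nan else sum /. float_of_int count

(* Pair each element with its 0-based index, as iteri does. *)
let indexed (s : 'a Seq.t) : (int * 'a) list =
  let acc = ref [] in
  Seq.iteri (fun i x -> acc := (i, x) :: !acc) s;
  List.rev !acc
```

Against the documented signatures, the same accumulation would read Data.fold (fun (s, n) x -> (s +. x, n + 1)) (0.0, 0) losses, for a hypothetical losses : float Data.t.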
Properties
val reset : 'a t -> unit
reset t resets t so that iteration starts from the beginning.
val length : 'a t -> int option
length t is the number of elements in t, if known.
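A resettable iterator with an optional known length can be modelled as a cursor plus a rewind operation. This is a minimal sketch assuming a closure-based record; the pipeline type, of_array_model, and drain are hypothetical and do not describe Kaun's actual representation of 'a t.

```ocaml
(* Hypothetical model of a resettable pipeline; not Kaun's 'a t. *)
type 'a pipeline = {
  next : unit -> 'a option;  (* next element, or None when exhausted *)
  reset : unit -> unit;      (* rewind the cursor to the first element *)
  length : int option;       (* Some n when the element count is known *)
}

(* An array-backed source: its length is always known. *)
let of_array_model (a : 'a array) : 'a pipeline =
  let pos = ref 0 in
  let next () =
    if !pos < Array.length a then begin
      let x = a.(!pos) in
      incr pos;
      Some x
    end
    else None
  in
  { next; reset = (fun () -> pos := 0); length = Some (Array.length a) }

(* Collect the remaining elements into a list, advancing the cursor. *)
let drain p =
  let rec go acc =
    match p.next () with
    | Some x -> go (x :: acc)
    | None -> List.rev acc
  in
  go []
```

Once drained, the model yields nothing until reset rewinds the cursor; a source such as a generator of unknown size would report length = None.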
Utilities
stack_batch tensors stacks an array of tensors along a new first axis. Equivalent to Nx.stack (Array.to_list tensors).
val prepare :
?shuffle:bool ->
batch_size:int ->
?drop_last:bool ->
(('a, 'b) Nx.t * ('c, 'd) Nx.t) ->
(('a, 'b) Nx.t * ('c, 'd) Nx.t) t
prepare ?shuffle ~batch_size ?drop_last (x, y) is a pipeline that yields batched tensor pairs from x and y.
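In each yielded pair, both tensors have size batch_size along the first dimension; when drop_last is false, the final pair may be smaller.
shuffle defaults to false. When true, elements are yielded in a random order. drop_last defaults to true, which drops an incomplete final batch.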
Raises Invalid_argument if x and y have different first dimension sizes, or if batch_size <= 0.
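When sizing a training loop, for example for a learning-rate schedule, it helps to know how many pairs a prepared pipeline yields per epoch. The sketch below derives that count from the documented drop_last semantics for n samples; batches_per_epoch is a hypothetical helper, not part of Kaun.Data.

```ocaml
(* Hypothetical helper: number of batches that prepare-style batching
   yields for n samples. drop_last defaults to true, as for prepare. *)
let batches_per_epoch ?(drop_last = true) ~batch_size n =
  if batch_size <= 0 then invalid_arg "batches_per_epoch: batch_size <= 0";
  if n < 0 then invalid_arg "batches_per_epoch: n < 0";
  if drop_last then n / batch_size
  (* Ceiling division keeps the incomplete final batch. *)
  else (n + batch_size - 1) / batch_size
```

For 1000 samples and batch_size 32, the default drops the final 8 samples and gives 31 batches, while ~drop_last:false gives 32.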