Module Kaun.Ptree
Heterogeneous parameter trees.
A parameter tree is a finite tree with tensor leaves and container nodes. Leaves are packed tensors (type tensor), and containers are either ordered lists (List) or string-keyed dicts (Dict).
Constructors
dict kvs is Dict kvs with key validation.
Raises Invalid_argument if a key is empty, duplicated, or contains '.', '[', or ']'.
val empty : t
empty is List [], the canonical value for "no parameters" or "no state".
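A minimal OCaml sketch of the tree shape and the two constructors, assuming a simplified leaf type. The record tensor with shape and data fields is a hypothetical stand-in, not Kaun's packed tensor, and the error message strings are illustrative. It shows the key checks that dict performs before building Dict kvs.

```ocaml
(* Hypothetical stand-in for a packed tensor: a shape plus flat float data. *)
type tensor = { shape : int array; data : float array }

(* Tree shape as documented: tensor leaves, ordered lists, string-keyed dicts. *)
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* The canonical "no parameters" value. *)
let empty = List []

(* Sketch of dict: validate every key, then wrap the pairs in Dict. *)
let dict (kvs : (string * t) list) : t =
  let seen = Hashtbl.create 16 in
  let reserved c = c = '.' || c = '[' || c = ']' in
  List.iter
    (fun (k, _) ->
      if k = "" then invalid_arg "dict: empty key";
      if String.exists reserved k then
        invalid_arg ("dict: reserved character in key " ^ k);
      if Hashtbl.mem seen k then invalid_arg ("dict: duplicate key " ^ k);
      Hashtbl.replace seen k ())
    kvs;
  Dict kvs
```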
Tensor Inspection
module Tensor : sig ... end
Dict Access
module Dict : sig ... end
List Access
module List : sig ... end
Leaf Access
Rank-2 handler for unpacking a packed tensor.
val with_tensor : tensor -> 'a tensor_handler -> 'a
with_tensor t h applies h.run to the unpacked tensor in t.
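The rank-2 handler pattern behind with_tensor can be sketched in plain OCaml: a GADT constructor hides the element type, and a record with a polymorphic field lets one handler run on any element type. The dtype, typed and P definitions, and the example handlers numel and describe, are hypothetical stand-ins; only the shape of h.run follows the documentation.

```ocaml
(* Hypothetical dtype witnesses; a real tensor library carries more information. *)
type _ dtype = Float64 : float dtype | Int32 : int32 dtype

(* Hypothetical typed tensor whose element type is fixed by its dtype. *)
type 'a typed = { dtype : 'a dtype; shape : int array; data : 'a array }

(* Packing hides the element type behind an existential. *)
type tensor = P : 'a typed -> tensor

(* Rank-2 handler: run must accept a tensor of any element type 'b. *)
type 'a tensor_handler = { run : 'b. 'b typed -> 'a }

let with_tensor (t : tensor) (h : 'a tensor_handler) : 'a =
  match t with P x -> h.run x

(* Example handlers: one ignores the element type, one inspects it. *)
let dtype_name : type b. b dtype -> string = function
  | Float64 -> "float64"
  | Int32 -> "int32"

let numel = { run = (fun x -> Array.fold_left ( * ) 1 x.shape) }
let describe = { run = (fun x -> dtype_name x.dtype) }
```

Keeping run polymorphic inside the record is what makes the handler rank-2: a plain function argument of type 'b typed -> 'a would have 'b fixed by the caller, which cannot name the existential element type hidden inside P.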
as_tensor_exn ?ctx t is t's packed tensor.
Raises Invalid_argument if t is not Tensor _. The optional ctx is prefixed to the error message.
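A sketch of the as_tensor_exn contract over a simplified, hypothetical leaf type. The exact error wording is an assumption; only the behaviour of prefixing the optional ctx follows the documentation.

```ocaml
(* Hypothetical stand-in for a packed tensor. *)
type tensor = { shape : int array; data : float array }
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Return the leaf's tensor, or fail with a message optionally prefixed by ctx. *)
let as_tensor_exn ?ctx (t : t) : tensor =
  match t with
  | Tensor x -> x
  | List _ | Dict _ ->
      let msg = "expected a tensor leaf" in
      invalid_arg (match ctx with Some c -> c ^ ": " ^ msg | None -> msg)
```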
Functional Operations
Rank-2 tensor mapper used by map.
val map : map_handler -> t -> t
map f t maps f.run over tensor leaves and preserves tree structure.
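A sketch of map under a hypothetical GADT encoding of packed tensors. The mapper's polymorphic run field is what lets a single traversal transform float and integer leaves while the tree shape is rebuilt unchanged. The names dtype, typed, P and double are illustrative assumptions, not Kaun's definitions.

```ocaml
(* Hypothetical dtype witnesses and typed tensors. *)
type _ dtype = Float64 : float dtype | Int32 : int32 dtype
type 'a typed = { dtype : 'a dtype; shape : int array; data : 'a array }
type tensor = P : 'a typed -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Rank-2 mapper: run must preserve the element type it is given. *)
type map_handler = { run : 'a. 'a typed -> 'a typed }

(* Rebuild the same tree shape, transforming only the leaves. *)
let rec map (f : map_handler) (t : t) : t =
  match t with
  | Tensor (P x) -> Tensor (P (f.run x))
  | List ts -> List (List.map (map f) ts)
  | Dict kvs -> Dict (List.map (fun (k, v) -> (k, map f v)) kvs)

(* Hypothetical example mapper: double every element, dispatching on dtype. *)
let double : type a. a typed -> a typed = fun x ->
  match x.dtype with
  | Float64 -> { x with data = Array.map (fun v -> v *. 2.0) x.data }
  | Int32 -> { x with data = Array.map (fun v -> Int32.mul v 2l) x.data }
```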
Rank-2 tensor zipper used by map2.
val map2 : map2_handler -> t -> t -> t
map2 f a b zips a and b and applies f.run to paired tensor leaves.
Lists are matched by position. Dict nodes are matched by key using a's key order.
Raises Invalid_argument on structure mismatch, list or dict size mismatch, missing keys in b, or paired dtype mismatch.
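The matching rules for map2 can be sketched with the same hypothetical stand-in types. A type-equality witness (eq) is one way, assumed here, to prove that paired leaves share a dtype before calling the rank-2 run field; dicts are walked in a's key order with each key looked up in b. The error strings and the add zipper are illustrative.

```ocaml
(* Hypothetical dtype witnesses and typed tensors. *)
type _ dtype = Float64 : float dtype | Int32 : int32 dtype
type 'a typed = { dtype : 'a dtype; shape : int array; data : 'a array }
type tensor = P : 'a typed -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Rank-2 zipper: both inputs and the output share one element type. *)
type map2_handler = { run : 'a. 'a typed -> 'a typed -> 'a typed }

(* Type-equality witness, used to prove two packed tensors share a dtype. *)
type (_, _) eq = Refl : ('a, 'a) eq

let dtype_eq : type a b. a dtype -> b dtype -> (a, b) eq option =
 fun x y ->
  match (x, y) with
  | Float64, Float64 -> Some Refl
  | Int32, Int32 -> Some Refl
  | _ -> None

let rec map2 (f : map2_handler) (a : t) (b : t) : t =
  match (a, b) with
  | Tensor (P x), Tensor (P y) -> (
      match dtype_eq x.dtype y.dtype with
      | Some Refl -> Tensor (P (f.run x y))
      | None -> invalid_arg "map2: dtype mismatch")
  | List xs, List ys ->
      if List.length xs <> List.length ys then invalid_arg "map2: list size mismatch";
      List (List.map2 (map2 f) xs ys)
  | Dict xs, Dict ys ->
      if List.length xs <> List.length ys then invalid_arg "map2: dict size mismatch";
      (* Follow a's key order and look each key up in b. *)
      Dict
        (List.map
           (fun (k, v) ->
             match List.assoc_opt k ys with
             | Some w -> (k, map2 f v w)
             | None -> invalid_arg ("map2: missing key " ^ k))
           xs)
  | _ -> invalid_arg "map2: structure mismatch"

(* Hypothetical example zipper: elementwise addition. *)
let add : type a. a typed -> a typed -> a typed = fun x y ->
  match x.dtype with
  | Float64 -> { x with data = Array.map2 ( +. ) x.data y.data }
  | Int32 -> { x with data = Array.map2 Int32.add x.data y.data }
```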
iter f t applies f to each leaf tensor in depth-first order.
Leaves are visited left-to-right in list order and dict field order.
fold f acc t folds leaf tensors in the same traversal order as iter.
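A sketch of the traversal order shared by iter and fold, assuming a simplified leaf type (a plain float array) and assuming fold passes the accumulator before the leaf; the documentation does not spell out the argument order of f.

```ocaml
(* Hypothetical stand-in: a flat float array plays the role of a tensor. *)
type tensor = float array
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Depth-first, left to right: list order, then dict field order. *)
let rec iter f = function
  | Tensor x -> f x
  | List ts -> List.iter (iter f) ts
  | Dict kvs -> List.iter (fun (_, v) -> iter f v) kvs

(* Same order as iter; f receives the accumulator first (an assumption). *)
let rec fold f acc = function
  | Tensor x -> f acc x
  | List ts -> List.fold_left (fold f) acc ts
  | Dict kvs -> List.fold_left (fun acc (_, v) -> fold f acc v) acc kvs
```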
Flatten and Rebuild
flatten t is (leaves, rebuild) where:
- leaves are t's leaf tensors in depth-first order;
- rebuild new_leaves rebuilds t's structure with new_leaves.
rebuild raises Invalid_argument if new_leaves has a different length than leaves.
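One way to implement flatten, sketched with a hypothetical float-array leaf type: collect the leaves depth-first, then have rebuild walk the same tree again and consume one new leaf per Tensor node. This sketch checks only the number of new leaves, matching the documented error; it does not check shapes or dtypes.

```ocaml
(* Hypothetical stand-in: a flat float array plays the role of a tensor. *)
type tensor = float array
type t = Tensor of tensor | List of t list | Dict of (string * t) list

let flatten (t : t) : tensor list * (tensor list -> t) =
  (* Collect leaves depth-first, left to right (accumulated in reverse). *)
  let rec collect acc = function
    | Tensor x -> x :: acc
    | List ts -> List.fold_left collect acc ts
    | Dict kvs -> List.fold_left (fun acc (_, v) -> collect acc v) acc kvs
  in
  let leaves = List.rev (collect [] t) in
  let expected = List.length leaves in
  let rebuild new_leaves =
    if List.length new_leaves <> expected then
      invalid_arg "rebuild: wrong number of leaves";
    (* Walk t in the same order, consuming one new leaf per Tensor node;
       returns the rebuilt node and the leaves not yet consumed. *)
    let rec go rest = function
      | Tensor _ -> (
          match rest with
          | x :: tl -> (Tensor x, tl)
          | [] -> assert false (* unreachable: lengths were checked *))
      | List ts ->
          let rev_children, rest = List.fold_left step ([], rest) ts in
          (List (List.rev rev_children), rest)
      | Dict kvs ->
          let rev_fields, rest =
            List.fold_left
              (fun (acc, rest) (k, v) ->
                let v', rest = go rest v in
                ((k, v') :: acc, rest))
              ([], rest) kvs
          in
          (Dict (List.rev rev_fields), rest)
    and step (acc, rest) child =
      let child', rest = go rest child in
      (child' :: acc, rest)
    in
    fst (go new_leaves t)
  in
  (leaves, rebuild)
```

Returning rebuild as a closure keeps the original tree captured, so a caller can transform the flat leaf list (for example, elementwise) and restore the structure afterwards.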
flatten_with_paths t returns (path, tensor) pairs where paths are dot-separated strings. Dict keys become path segments; list indices become decimal segments (e.g. "layers.0.weight").
If t is a tensor leaf, its path is the empty string.
The path encoding is injective for trees built with dict, because dict rejects keys containing '.', '[', or ']'.
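A sketch of the path scheme described above, using a hypothetical float-array leaf type. Starting from an empty root prefix is why a bare tensor leaf gets the path "", and joining with '.' only after a non-empty prefix avoids a leading dot.

```ocaml
(* Hypothetical stand-in: a flat float array plays the role of a tensor. *)
type tensor = float array
type t = Tensor of tensor | List of t list | Dict of (string * t) list

let flatten_with_paths (t : t) : (string * tensor) list =
  (* The root prefix is ""; deeper segments are joined with '.'. *)
  let join prefix seg = if prefix = "" then seg else prefix ^ "." ^ seg in
  (* Accumulate (path, leaf) pairs in reverse, depth-first. *)
  let rec go prefix acc = function
    | Tensor x -> (prefix, x) :: acc
    | List ts ->
        let _, acc =
          List.fold_left
            (fun (i, acc) child -> (i + 1, go (join prefix (string_of_int i)) acc child))
            (0, acc) ts
        in
        acc
    | Dict kvs -> List.fold_left (fun acc (k, child) -> go (join prefix k) acc child) acc kvs
  in
  List.rev (go "" [] t)
```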
Utilities
zeros_like t has the same structure as t, with each leaf tensor replaced by Nx.zeros_like of that tensor (zeros with the same shape and dtype).
val count_parameters : t -> int
count_parameters t is the sum of Tensor.numel over all leaf tensors.
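A sketch of zeros_like and count_parameters over a hypothetical leaf record with shape and data fields. The helpers zeros_like_tensor and numel stand in for Nx.zeros_like and Tensor.numel; they are assumptions for illustration, not Kaun's implementations.

```ocaml
(* Hypothetical leaf: a shape plus flat float data, standing in for an Nx tensor. *)
type tensor = { shape : int array; data : float array }
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Hypothetical stand-in for Nx.zeros_like: same shape, every element 0. *)
let zeros_like_tensor x = { x with data = Array.make (Array.length x.data) 0.0 }

let rec zeros_like = function
  | Tensor x -> Tensor (zeros_like_tensor x)
  | List ts -> List (List.map zeros_like ts)
  | Dict kvs -> Dict (List.map (fun (k, v) -> (k, zeros_like v)) kvs)

(* Hypothetical stand-in for Tensor.numel: product of the dimensions. *)
let numel x = Array.fold_left ( * ) 1 x.shape

let rec count_parameters = function
  | Tensor x -> numel x
  | List ts -> List.fold_left (fun n c -> n + count_parameters c) 0 ts
  | Dict kvs -> List.fold_left (fun n (_, c) -> n + count_parameters c) 0 kvs
```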
val pp : Stdlib.Format.formatter -> t -> unit
pp formats trees for debugging.