Module Kaun.Ptree

Heterogeneous parameter trees.

A parameter tree is a finite tree with tensor leaves and container nodes. Leaves are packed tensors of type tensor, and containers are either ordered lists (List) or string-keyed dicts (Dict).

type tensor =
  | P : ('a, 'layout) Nx.t -> tensor
      (* A packed tensor. The wrapper hides the dtype and layout parameters. *)

type t =
  | Tensor of tensor  (* A tensor leaf. *)
  | List of t list  (* An ordered list branch. *)
  | Dict of (string * t) list  (* A dict branch. Keys are strings. *)
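To make the shape of t concrete, here is a minimal sketch that builds a small tree. It assumes a hypothetical stand-in module Toy_nx in place of Nx (a flat element array with a phantom 'layout parameter); only the tensor and t declarations mirror the ones above. Because P hides the element type, a float tensor and an int tensor can sit in the same tree.

```ocaml
(* Hypothetical stand-in for Nx.t: a flat element array. The 'layout
   parameter is phantom here; this is an assumption, not Nx's real type. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
end

(* Mirrors the declarations above, with Toy_nx.t in place of Nx.t. *)
type tensor = P : ('a, 'layout) Toy_nx.t -> tensor

type t =
  | Tensor of tensor
  | List of t list
  | Dict of (string * t) list

(* P erases the element type, so float and int leaves can share a tree. *)
let weight : (float, unit) Toy_nx.t = { Toy_nx.data = [| 0.1; 0.2; 0.3 |] }
let step : (int, unit) Toy_nx.t = { Toy_nx.data = [| 0 |] }

let model =
  Dict
    [
      ("layers", List [ Dict [ ("weight", Tensor (P weight)) ] ]);
      ("step", Tensor (P step));
    ]

(* Structural recursion over the three constructors: count tensor leaves. *)
let rec num_leaves = function
  | Tensor _ -> 1
  | List xs -> List.fold_left (fun n x -> n + num_leaves x) 0 xs
  | Dict kvs -> List.fold_left (fun n (_, x) -> n + num_leaves x) 0 kvs
```

The constructor List and the standard library module List coexist: List xs is a constructor application, while List.fold_left is a module path.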

Constructors

val tensor : ('a, 'layout) Nx.t -> t

tensor x is Tensor (P x).

val list : t list -> t

list xs is List xs.

val dict : (string * t) list -> t

dict kvs is Dict kvs with key validation.

Raises Invalid_argument if a key is empty, duplicated, or contains '.', '[', or ']'.
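The documented key checks can be sketched as follows. This is a hypothetical re-implementation over the same Toy_nx stand-in, not Kaun's source. The message texts are assumptions; only the exception (Invalid_argument) and the rejected cases come from the description above.

```ocaml
(* Hypothetical stand-in for Nx.t, as an assumption for this sketch. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Characters that would make flattened paths ambiguous. *)
let has_reserved_char k =
  String.contains k '.' || String.contains k '[' || String.contains k ']'

(* Sketch of dict: validate every key, then wrap the pairs unchanged. *)
let dict (kvs : (string * t) list) : t =
  let seen = Hashtbl.create 8 in
  List.iter
    (fun (k, _) ->
      if k = "" then invalid_arg "dict: empty key";
      if has_reserved_char k then
        invalid_arg ("dict: reserved character in key " ^ k);
      if Hashtbl.mem seen k then invalid_arg ("dict: duplicate key " ^ k);
      Hashtbl.replace seen k ())
    kvs;
  Dict kvs
```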

val empty : t

empty is List []. It is the canonical value for "no parameters" or "no state".

Tensor Inspection

module Tensor : sig ... end

Dict Access

module Dict : sig ... end

List Access

module List : sig ... end

Leaf Access

type 'r tensor_handler = {
  run : 'a 'layout. ('a, 'layout) Nx.t -> 'r;
}

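Rank-2 handler for unpacking a packed tensor. Because P hides the element and layout types, run must be polymorphic in both ('a 'layout. ...) so that it accepts whatever tensor is found inside.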

val with_tensor : tensor -> 'a tensor_handler -> 'a

with_tensor t h applies h.run to the unpacked tensor in t.
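The reason for a record with an explicitly polymorphic field, rather than a plain function: after matching on P, the element and layout types are fresh existentials, so the callback must work for every 'a and 'layout. The following is a minimal sketch using the hypothetical Toy_nx stand-in (its numel is an assumption for illustration):

```ocaml
(* Hypothetical stand-in for Nx.t, as an assumption for this sketch. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
  let numel (x : ('a, 'layout) t) = Array.length x.data
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor

(* The 'a 'layout. annotation makes run polymorphic (rank-2). *)
type 'r tensor_handler = { run : 'a 'layout. ('a, 'layout) Toy_nx.t -> 'r }

(* Unpacking P introduces existential types; h.run accepts them because
   it is polymorphic, and 'r does not mention them, so nothing escapes. *)
let with_tensor (t : tensor) (h : 'r tensor_handler) : 'r =
  match t with P x -> h.run x

(* A handler whose result type (int) does not depend on the dtype. *)
let numel_handler = { run = (fun x -> Toy_nx.numel x) }
```

A plain argument of type ('a, 'layout) Toy_nx.t -> 'r would not type-check in place of h, because its 'a would be fixed by the caller rather than by the tensor hidden inside P.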

val as_tensor_exn : ?ctx:string -> t -> tensor

as_tensor_exn ?ctx t is t's packed tensor.

Raises Invalid_argument if t is not Tensor _. The optional ctx is prefixed to the error message.
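Below is a sketch of the documented behaviour, using the hypothetical Toy_nx stand-in. The message wording and the ": " separator after ctx are assumptions; the description above only says that ctx is prefixed.

```ocaml
(* Hypothetical stand-in for Nx.t, as an assumption for this sketch. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Return the packed leaf, or fail with an optionally prefixed message. *)
let as_tensor_exn ?ctx (tree : t) : tensor =
  match tree with
  | Tensor x -> x
  | List _ | Dict _ ->
      let msg = "expected a tensor leaf, got a branch" in
      invalid_arg (match ctx with None -> msg | Some c -> c ^ ": " ^ msg)
```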

Functional Operations

type map_handler = {
  run : 'a 'layout. ('a, 'layout) Nx.t -> ('a, 'layout) Nx.t;
}

Rank-2 tensor mapper used by map.

val map : map_handler -> t -> t

map f t maps f.run over tensor leaves and preserves tree structure.
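map can be written by structural recursion. The sketch below uses the hypothetical Toy_nx stand-in. Its copy handler is purely illustrative (for example, snapshotting parameters) and is not part of the documented API. Because the handler must return a tensor with the same 'a and 'layout it received, map cannot change a leaf's dtype.

```ocaml
(* Hypothetical stand-in for Nx.t, as an assumption for this sketch. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
  let copy (x : ('a, 'layout) t) : ('a, 'layout) t = { data = Array.copy x.data }
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

type map_handler = {
  run : 'a 'layout. ('a, 'layout) Toy_nx.t -> ('a, 'layout) Toy_nx.t;
}

(* Rebuild every node, applying f.run to leaves and keeping keys and order. *)
let rec map (f : map_handler) (tree : t) : t =
  match tree with
  | Tensor (P x) -> Tensor (P (f.run x))
  | List xs -> List (List.map (map f) xs)
  | Dict kvs -> Dict (List.map (fun (k, v) -> (k, map f v)) kvs)

(* Illustrative handler: copy each leaf's storage. *)
let copy = { run = Toy_nx.copy }
```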

type map2_handler = {
  run : 'a 'layout. ('a, 'layout) Nx.t -> ('a, 'layout) Nx.t -> ('a, 'layout) Nx.t;
}

Rank-2 tensor zipper used by map2.

val map2 : map2_handler -> t -> t -> t

map2 f a b zips a and b and applies f.run to paired tensor leaves.

Lists are matched by position. Dict nodes are matched by key using a's key order.

Raises Invalid_argument on a structure mismatch, a list or dict size mismatch, a key of a that is missing from b, or a dtype mismatch between paired leaves.
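Pairing two leaves needs more than the rank-2 handler. After unpacking, the two tensors have unrelated existential types, so a runtime dtype comparison must produce a type-equality witness before f.run can be applied. The sketch below illustrates that pattern with a hypothetical Toy_nx stand-in that carries a dtype witness. The names Float32, Int32, equal_dtype, eq2 and Refl2 are assumptions for illustration, not Nx's API, and add is only an example handler.

```ocaml
(* Type-equality witness for two pairs of type parameters. *)
type (_, _, _, _) eq2 = Refl2 : ('a, 'l, 'a, 'l) eq2

(* Hypothetical stand-in for Nx with a runtime dtype witness. *)
module Toy_nx = struct
  type float32_tag
  type int32_tag

  type (_, _) dtype =
    | Float32 : (float, float32_tag) dtype
    | Int32 : (int32, int32_tag) dtype

  type ('a, 'layout) t = { dtype : ('a, 'layout) dtype; data : 'a array }

  let equal_dtype : type a l b m. (a, l) dtype -> (b, m) dtype -> (a, l, b, m) eq2 option =
   fun d1 d2 ->
    match (d1, d2) with
    | Float32, Float32 -> Some Refl2
    | Int32, Int32 -> Some Refl2
    | _, _ -> None
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

type map2_handler = {
  run : 'a 'layout. ('a, 'layout) Toy_nx.t -> ('a, 'layout) Toy_nx.t -> ('a, 'layout) Toy_nx.t;
}

let rec map2 (f : map2_handler) (a : t) (b : t) : t =
  match (a, b) with
  | Tensor (P x), Tensor (P y) -> (
      (* x and y have unrelated existential types until the dtype check
         yields Refl2, which tells the type checker they coincide. *)
      match Toy_nx.equal_dtype x.Toy_nx.dtype y.Toy_nx.dtype with
      | Some Refl2 -> Tensor (P (f.run x y))
      | None -> invalid_arg "map2: dtype mismatch")
  | List xs, List ys ->
      if List.length xs <> List.length ys then invalid_arg "map2: list size mismatch";
      List (List.map2 (map2 f) xs ys)
  | Dict kas, Dict kbs ->
      if List.length kas <> List.length kbs then invalid_arg "map2: dict size mismatch";
      (* Follow a's key order; look each key up in b. *)
      Dict
        (List.map
           (fun (k, va) ->
             match List.assoc_opt k kbs with
             | Some vb -> (k, map2 f va vb)
             | None -> invalid_arg ("map2: missing key " ^ k))
           kas)
  | _ -> invalid_arg "map2: structure mismatch"

(* Example handler: elementwise addition, dispatching on the dtype witness. *)
let add_impl : type a l. (a, l) Toy_nx.t -> (a, l) Toy_nx.t -> (a, l) Toy_nx.t =
 fun x y ->
  match x.Toy_nx.dtype with
  | Toy_nx.Float32 ->
      { x with Toy_nx.data = Array.map2 ( +. ) x.Toy_nx.data y.Toy_nx.data }
  | Toy_nx.Int32 ->
      { x with Toy_nx.data = Array.map2 Int32.add x.Toy_nx.data y.Toy_nx.data }

let add = { run = add_impl }
```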

val iter : (tensor -> unit) -> t -> unit

iter f t applies f to each leaf tensor in depth-first order.

Leaves are visited left-to-right in list order and dict field order.

val fold : ('acc -> tensor -> 'acc) -> 'acc -> t -> 'acc

fold f acc t folds leaf tensors in the same traversal order as iter.
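Both traversals can be sketched as a depth-first walk that visits list elements and dict fields left to right. The following is a hypothetical implementation over the Toy_nx stand-in, meant only to pin down the visiting order.

```ocaml
(* Hypothetical stand-in for Nx.t, as an assumption for this sketch. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Depth-first, left to right; dict keys are ignored, only order matters. *)
let rec iter (f : tensor -> unit) (tree : t) : unit =
  match tree with
  | Tensor x -> f x
  | List xs -> List.iter (iter f) xs
  | Dict kvs -> List.iter (fun (_, v) -> iter f v) kvs

(* Same order as iter, threading an accumulator. *)
let rec fold (f : 'acc -> tensor -> 'acc) (acc : 'acc) (tree : t) : 'acc =
  match tree with
  | Tensor x -> f acc x
  | List xs -> List.fold_left (fold f) acc xs
  | Dict kvs -> List.fold_left (fun acc (_, v) -> fold f acc v) acc kvs

(* Example: the element count of each leaf, in traversal order. *)
let sizes tree =
  List.rev
    (fold (fun acc leaf -> match leaf with P x -> Array.length x.Toy_nx.data :: acc) [] tree)
```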

Flatten and Rebuild

val flatten : t -> tensor list * (tensor list -> t)

flatten t is (leaves, rebuild) where:

  • leaves are t's leaf tensors in depth-first order;
  • rebuild new_leaves rebuilds t's structure with new_leaves.

rebuild raises Invalid_argument if new_leaves has a different length than leaves.
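One way to implement flatten is to collect the leaves with a depth-first walk and have rebuild re-walk the original structure, consuming replacement leaves in the same order. The sketch below assumes the hypothetical Toy_nx stand-in and uses List.fold_left_map (OCaml 4.11+) so that the consumption order is explicit. The error messages are illustrative.

```ocaml
(* Hypothetical stand-in for Nx.t, as an assumption for this sketch. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Leaves in depth-first order, accumulated in reverse. *)
let rec leaves_rev acc = function
  | Tensor x -> x :: acc
  | List xs -> List.fold_left leaves_rev acc xs
  | Dict kvs -> List.fold_left (fun acc (_, v) -> leaves_rev acc v) acc kvs

(* Replace each leaf of node with the next element of supply, left to right,
   returning the unused suffix of supply alongside the rebuilt node. *)
let rec refill supply node =
  match node with
  | Tensor _ -> (
      match supply with
      | x :: rest -> (rest, Tensor x)
      | [] -> invalid_arg "rebuild: too few leaves")
  | List xs ->
      let supply, ys = List.fold_left_map refill supply xs in
      (supply, List ys)
  | Dict kvs ->
      let supply, kvs =
        List.fold_left_map
          (fun supply (k, v) ->
            let supply, v = refill supply v in
            (supply, (k, v)))
          supply kvs
      in
      (supply, Dict kvs)

let flatten (tree : t) : tensor list * (tensor list -> t) =
  let leaves = List.rev (leaves_rev [] tree) in
  let n = List.length leaves in
  let rebuild new_leaves =
    if List.length new_leaves <> n then
      invalid_arg "rebuild: wrong number of leaves";
    snd (refill new_leaves tree)
  in
  (leaves, rebuild)
```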

val flatten_with_paths : t -> (string * tensor) list

flatten_with_paths t returns (path, tensor) pairs where paths are dot-separated strings. Dict keys become path segments; list indices become decimal segments (e.g. "layers.0.weight").

If t is a tensor leaf, its path is the empty string.

The path encoding is injective for trees whose dict keys were validated (for example, trees built with dict), because validation rejects keys containing '.', '[', or ']'.
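The path construction can be sketched as follows, assuming the hypothetical Toy_nx stand-in. The pairs are returned in depth-first order, matching flatten; that order is an assumption, since the description above does not state it. Segments are joined with '.', and a root leaf gets the empty path.

```ocaml
(* Hypothetical stand-in for Nx.t, as an assumption for this sketch. *)
module Toy_nx = struct
  type ('a, 'layout) t = { data : 'a array }
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

let flatten_with_paths (tree : t) : (string * tensor) list =
  (* The root contributes no segment, so avoid a leading '.'. *)
  let join prefix seg = if prefix = "" then seg else prefix ^ "." ^ seg in
  let rec go prefix acc node =
    match node with
    | Tensor x -> (prefix, x) :: acc
    | List xs ->
        (* List indices become decimal segments. *)
        let _, acc =
          List.fold_left
            (fun (i, acc) x -> (i + 1, go (join prefix (string_of_int i)) acc x))
            (0, acc) xs
        in
        acc
    | Dict kvs ->
        (* Dict keys become segments verbatim. *)
        List.fold_left (fun acc (k, v) -> go (join prefix k) acc v) acc kvs
  in
  List.rev (go "" [] tree)
```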

Utilities

val zeros_like : t -> t

zeros_like t has the same structure as t, with each leaf tensor x replaced by Nx.zeros_like x.
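Producing a zero requires the leaf's dtype, because "zero" is 0.0 for floats and 0l for int32. The sketch below assumes a hypothetical Toy_nx stand-in whose tensors carry a dtype witness (Float32 and Int32 are illustrative names). Its zeros_like plays the role of Nx.zeros_like, and the tree-level function is then a structure-preserving map.

```ocaml
(* Hypothetical stand-in for Nx with a runtime dtype witness. *)
module Toy_nx = struct
  type float32_tag
  type int32_tag

  type (_, _) dtype =
    | Float32 : (float, float32_tag) dtype
    | Int32 : (int32, int32_tag) dtype

  type ('a, 'layout) t = { dtype : ('a, 'layout) dtype; data : 'a array }

  (* Matching on the witness refines 'a, so each branch returns the right zero. *)
  let zero : type a l. (a, l) dtype -> a = function
    | Float32 -> 0.0
    | Int32 -> 0l

  let zeros_like (x : ('a, 'layout) t) : ('a, 'layout) t =
    { x with data = Array.make (Array.length x.data) (zero x.dtype) }
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(* Same structure, keys and order; every leaf replaced by zeros of its dtype. *)
let rec zeros_like (tree : t) : t =
  match tree with
  | Tensor (P x) -> Tensor (P (Toy_nx.zeros_like x))
  | List xs -> List (List.map zeros_like xs)
  | Dict kvs -> Dict (List.map (fun (k, v) -> (k, zeros_like v)) kvs)
```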

val count_parameters : t -> int

count_parameters t is the sum of Tensor.numel over all leaf tensors.
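count_parameters amounts to a fold over the leaves that sums their element counts. Here is a minimal sketch with the hypothetical Toy_nx stand-in, whose numel (the product of an assumed shape field) stands in for Tensor.numel:

```ocaml
(* Hypothetical stand-in for Nx.t with an explicit shape. *)
module Toy_nx = struct
  type ('a, 'layout) t = { shape : int array; data : 'a array }

  (* Element count is the product of the dimensions; a scalar ([||]) has 1. *)
  let numel (x : ('a, 'layout) t) = Array.fold_left ( * ) 1 x.shape
end

type tensor = P : ('a, 'layout) Toy_nx.t -> tensor
type t = Tensor of tensor | List of t list | Dict of (string * t) list

let rec count_parameters (tree : t) : int =
  match tree with
  | Tensor (P x) -> Toy_nx.numel x
  | List xs -> List.fold_left (fun n v -> n + count_parameters v) 0 xs
  | Dict kvs -> List.fold_left (fun n (_, v) -> n + count_parameters v) 0 kvs
```

For instance, a 2x3 weight, a length-3 bias and a scalar step counter give 6 + 3 + 1 = 10 parameters, and empty (List []) gives 0.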

val pp : Stdlib.Format.formatter -> t -> unit

pp formats trees for debugging.