Module Kaun.Context
Per-call auxiliary data for layers.
A t carries read-only tensors (attention masks, position ids, encoder memory) that specific layers consume during a forward pass. Most layers ignore the context; transformer layers read from it by well-known key names.
let ctx =
Context.empty
|> Context.set ~name:"attention_mask" (Ptree.P mask)
|> Context.set ~name:"token_type_ids" (Ptree.P ids)
in
Layer.apply model vars ~training:false ~ctx input_ids
Types
Constructors
val empty : t
empty is the empty context.
val set : name:string -> Ptree.tensor -> t -> t
set ~name tensor ctx is ctx with name bound to tensor. Shadows any previous binding for name.
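To illustrate the shadowing rule of set, here is a self-contained toy model. It assumes only that a context behaves like a string-keyed map; the toy_tensor type and the Toy_context module are hypothetical stand-ins written for this sketch, not Kaun's implementation (the real context stores Ptree.tensor values).

```ocaml
(* Hypothetical stand-in for Ptree.tensor: just enough variants to demo lookups. *)
type toy_tensor =
  | Float of float array
  | Int32 of int32 array
  | Bool of bool array

(* Hypothetical toy model of a context as a string-keyed map. *)
module Toy_context = struct
  module M = Map.Make (String)

  type t = toy_tensor M.t

  let empty : t = M.empty

  (* Like Context.set: a later binding for the same name replaces the earlier one. *)
  let set ~name tensor (ctx : t) : t = M.add name tensor ctx

  (* Like Context.find: None when the name is unbound. *)
  let find (ctx : t) ~name = M.find_opt name ctx
end

let () =
  let ctx =
    Toy_context.empty
    |> Toy_context.set ~name:"attention_mask" (Bool [| true; false |])
    |> Toy_context.set ~name:"attention_mask" (Bool [| true; true |])
  in
  (* The second set shadows the first, so this prints "mask = [true; true]". *)
  match Toy_context.find ctx ~name:"attention_mask" with
  | Some (Bool m) -> Printf.printf "mask = [%b; %b]\n" m.(0) m.(1)
  | _ -> print_endline "missing"
```

Because set returns a new context rather than mutating its argument, the pipeline style of the module overview composes naturally.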
Lookup
val find : t -> name:string -> Ptree.tensor option
find ctx ~name is the tensor bound to name in ctx, if any.
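Since find returns an option, a layer can treat a context entry as optional input. The sketch below assumes a hypothetical attention helper that attends to every position when no "attention_mask" is bound; the toy_tensor type, the list-based ctx, and effective_mask are illustrative inventions, not part of Kaun.

```ocaml
(* Hypothetical toy tensor and context, standing in for Ptree.tensor and Context.t. *)
type toy_tensor = Bool of bool array | Float of float array

type ctx = (string * toy_tensor) list

(* Like Context.find: None when the name is unbound. *)
let find (ctx : ctx) ~name = List.assoc_opt name ctx

(* Hypothetical layer helper: with no "attention_mask" in the context,
   every one of the [seq_len] positions is attended to. *)
let effective_mask (ctx : ctx) ~seq_len =
  match find ctx ~name:"attention_mask" with
  | Some (Bool m) -> m
  | Some (Float _) -> invalid_arg "attention_mask: expected a bool tensor"
  | None -> Array.make seq_len true
```

Falling back on None keeps a layer usable with Context.empty; when an entry is mandatory, the get_*_exn functions below are the stricter choice.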
val get_float_exn :
  ctx:string ->
  t ->
  name:string ->
  dtype:(float, 'l) Nx.dtype ->
  (float, 'l) Nx.t
get_float_exn ~ctx t ~name ~dtype is the float tensor bound to name, checked against dtype.
Raises Invalid_argument if name is missing or its tensor's dtype differs from dtype. ctx is used in error messages.
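The checked getters turn a missing or mistyped entry into an Invalid_argument that names the caller. The toy version below mirrors that contract under stated assumptions: toy_tensor and the list-based context are hypothetical, "Dense" is an arbitrary caller label, and the exact message text is invented for this sketch; Kaun's real messages may differ.

```ocaml
(* Hypothetical toy tensor standing in for a dtype-tagged Ptree.tensor. *)
type toy_tensor = Float32 of float array | Int32 of int32 array

let find ctx ~name = List.assoc_opt name ctx

(* Hypothetical analogue of get_float_exn: [ctx] labels the caller in the
   error message; the message wording is an assumption of this sketch. *)
let get_float_exn ~ctx t ~name =
  match find t ~name with
  | Some (Float32 a) -> a
  | Some _ -> invalid_arg (Printf.sprintf "%s: %S has a different dtype" ctx name)
  | None -> invalid_arg (Printf.sprintf "%s: missing %S" ctx name)
```

Passing the layer's own name as ctx means a failure inside a deep model points directly at the layer that expected the entry.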
get_int32_exn ~ctx t ~name is the int32 tensor bound to name.
Raises Invalid_argument if name is missing or has a different dtype.
val get_bool_exn : ctx:string -> t -> name:string -> (bool, Nx.bool_elt) Nx.t
get_bool_exn ~ctx t ~name is the bool tensor bound to name.
Raises Invalid_argument if name is missing or has a different dtype.