Checkpoints and Pretrained Models
This guide covers saving and loading model parameters with SafeTensors, and downloading pretrained weights from the HuggingFace Hub.
SafeTensors Checkpointing
Kaun serializes parameter trees to the
SafeTensors format. Tensor
paths from the tree structure become file keys (e.g. layers.0.weight).
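Here is a standalone sketch of how nested paths can become flat keys such as layers.0.weight. The `tree` type below is a hypothetical stand-in, not Kaun's own parameter-tree type. Record fields contribute their names, list items contribute their indices, and the segments are joined with ".".

```ocaml
(* Hypothetical toy parameter tree; a stand-in, not Kaun's API *)
type tree =
  | Leaf of float array              (* a tensor, flattened to raw values *)
  | Fields of (string * tree) list   (* named children, e.g. "weight" *)
  | Items of tree list               (* indexed children, e.g. a layer stack *)

(* Depth-first walk emitting (dotted key, leaf) pairs in tree order *)
let flatten (t : tree) : (string * float array) list =
  let join prefix seg = if prefix = "" then seg else prefix ^ "." ^ seg in
  let rec go prefix t acc =
    match t with
    | Leaf data -> (prefix, data) :: acc
    | Fields fields ->
        List.fold_left
          (fun acc (name, child) -> go (join prefix name) child acc)
          acc fields
    | Items items ->
        let acc, _ =
          List.fold_left
            (fun (acc, i) child ->
              (go (join prefix (string_of_int i)) child acc, i + 1))
            (acc, 0) items
        in
        acc
  in
  (* [go] prepends, so reverse once to restore depth-first order *)
  List.rev (go "" t [])

let () =
  let model =
    Fields
      [ ("layers",
         Items [ Fields [ ("weight", Leaf [| 1.; 2. |]); ("bias", Leaf [| 0. |]) ] ]) ]
  in
  (* prints "layers.0.weight" then "layers.0.bias" *)
  List.iter (fun (key, _) -> print_endline key) (flatten model)
```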
Saving
let vars = Train.vars st in
Checkpoint.save "model.safetensors" (Layer.params vars)
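For orientation, the SafeTensors layout (per the public format specification) is an 8-byte little-endian header length, a JSON header mapping each key to its dtype, shape, and byte offsets, and then the raw tensor bytes. The sketch below encodes one float32 tensor without using Kaun. `encode_f32` and `read_header` are hypothetical helpers, and the sketch skips things a real writer handles: multiple tensors, other dtypes, optional header padding, and escaping of key names.

```ocaml
(* Encode one float32 tensor as SafeTensors bytes:
   [u64 LE header length][JSON header][little-endian float32 data] *)
let encode_f32 ~(name : string) (data : float array) : Bytes.t =
  let n = Array.length data in
  (* [name] is assumed not to need JSON escaping *)
  let header =
    Printf.sprintf {|{"%s":{"dtype":"F32","shape":[%d],"data_offsets":[0,%d]}}|}
      name n (4 * n)
  in
  let hlen = String.length header in
  let buf = Bytes.create (8 + hlen + (4 * n)) in
  Bytes.set_int64_le buf 0 (Int64.of_int hlen);
  Bytes.blit_string header 0 buf 8 hlen;
  (* data_offsets are relative to the start of the data section *)
  Array.iteri
    (fun i x -> Bytes.set_int32_le buf (8 + hlen + (4 * i)) (Int32.bits_of_float x))
    data;
  buf

(* Recover the JSON header by following the length prefix *)
let read_header (buf : Bytes.t) : string =
  let hlen = Int64.to_int (Bytes.get_int64_le buf 0) in
  Bytes.sub_string buf 8 hlen

let () = print_endline (read_header (encode_f32 ~name:"layers.0.weight" [| 1.0; 2.0 |]))
```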
Loading
Checkpoint.load requires a ~like template that defines the expected
tree structure and dtypes. Tensors are cast to the template's dtype if
needed. Extra keys in the file are ignored.
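The template-driven rules above can be sketched with plain association lists. The result follows the template's keys, each value is cast to the template's dtype, and keys present only in the file are never looked at. The `dtype`/`tensor` types and `load_like` are hypothetical stand-ins, not Kaun's API, and failing on a missing key is an assumption: the text above only specifies what happens to extra keys.

```ocaml
(* Hypothetical stand-ins for dtype-tagged tensors *)
type dtype = F32 | F64
type tensor = { dtype : dtype; data : float array }

(* Cast to the target dtype; converting to F32 rounds to single precision *)
let cast (target : dtype) (t : tensor) : tensor =
  if t.dtype = target then t
  else
    let round_f32 x = Int32.float_of_bits (Int32.bits_of_float x) in
    let data =
      match target with
      | F32 -> Array.map round_f32 t.data
      | F64 -> Array.copy t.data
    in
    { dtype = target; data }

(* Walk the template, not the file: extra file keys are ignored.
   Raising on a missing key is an assumption of this sketch. *)
let load_like ~(like : (string * tensor) list) (file : (string * tensor) list) :
    (string * tensor) list =
  List.map
    (fun (key, template) ->
      match List.assoc_opt key file with
      | Some t -> (key, cast template.dtype t)
      | None -> failwith ("load_like: missing key " ^ key))
    like
```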
(* Initialize model to get the tree structure *)
let vars = Layer.init model ~dtype:Nx.Float32 in
let params = Checkpoint.load "model.safetensors" ~like:(Layer.params vars) in
let vars = Layer.with_params vars params
Saving and Loading State
To save both parameters and non-trainable state (e.g. batch norm running statistics):
(* Save *)
let vars = Train.vars st in
Checkpoint.save "params.safetensors" (Layer.params vars);
Checkpoint.save "state.safetensors" (Layer.state vars)
(* Load *)
let vars = Layer.init model ~dtype:Nx.Float32 in
let params = Checkpoint.load "params.safetensors" ~like:(Layer.params vars) in
let state = Checkpoint.load "state.safetensors" ~like:(Layer.state vars) in
let vars = Layer.with_state (Layer.with_params vars params) state
Resuming Training
Use Train.make_state to create training state from loaded weights:
let trainer = Train.make ~model ~optimizer in
let st = Train.make_state trainer vars in
(* Continue training from here *)
let st = Train.fit trainer st data
HuggingFace Hub
The kaun-hf package provides access to the HuggingFace Hub for
downloading pretrained model weights and configurations.
Downloading Files
let path = Kaun_hf.download_file ~model_id:"bert-base-uncased"
~filename:"config.json" ()
(* path : string — local filesystem path *)
Files are cached under $RAVEN_CACHE_ROOT/huggingface (or
$XDG_CACHE_HOME/raven/huggingface). Subsequent calls return the cached
path.
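One way to read the cache-location rule is the lookup order sketched below. An explicit ~cache_dir wins, then $RAVEN_CACHE_ROOT/huggingface, then $XDG_CACHE_HOME/raven/huggingface. `resolve_cache_dir` is hypothetical, and the final fallback to $HOME/.cache when XDG_CACHE_HOME is unset is an assumption taken from the XDG Base Directory convention, not from this guide. The environment lookup is passed in as a function so the logic can be tested.

```ocaml
(* Hypothetical cache-dir resolution; the $HOME/.cache fallback is assumed *)
let resolve_cache_dir ?cache_dir (getenv : string -> string option) : string =
  (* Treat empty environment variables as unset *)
  let non_empty = function Some s when s <> "" -> Some s | _ -> None in
  match cache_dir with
  | Some dir -> dir
  | None -> (
      match non_empty (getenv "RAVEN_CACHE_ROOT") with
      | Some root -> Filename.concat root "huggingface"
      | None ->
          let xdg =
            match non_empty (getenv "XDG_CACHE_HOME") with
            | Some x -> x
            | None ->
                let home = Option.value (getenv "HOME") ~default:"." in
                Filename.concat home ".cache"
          in
          Filename.concat (Filename.concat xdg "raven") "huggingface")

let () = print_endline (resolve_cache_dir Sys.getenv_opt)
```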
Options:
- ~token: HuggingFace API token for private repositories. Defaults to the HF_TOKEN environment variable.
- ~cache_dir: override the default cache directory.
- ~offline:true: only return cached files; do not download.
- ~revision:(Rev "v1.0"): download a specific tag, branch, or commit. The default is Main.
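To make the ~revision and ~token defaults concrete, the sketch below mirrors them with a local `revision` type and builds a URL using the Hub's public `resolve` endpoint pattern. Everything here is hypothetical: `revision` only mirrors the Main / Rev constructors shown above, and `resolve_url` and `effective_token` are illustrations, not how kaun-hf is known to build its requests.

```ocaml
(* Hypothetical mirror of the ~revision option *)
type revision = Main | Rev of string

let revision_to_string = function Main -> "main" | Rev r -> r

(* https://huggingface.co/<model_id>/resolve/<revision>/<filename> *)
let resolve_url ~model_id ?(revision = Main) filename =
  Printf.sprintf "https://huggingface.co/%s/resolve/%s/%s" model_id
    (revision_to_string revision) filename

(* An explicit ~token wins; otherwise fall back to HF_TOKEN *)
let effective_token ?token (getenv : string -> string option) : string option =
  match token with Some t -> Some t | None -> getenv "HF_TOKEN"

let () =
  (* prints https://huggingface.co/bert-base-uncased/resolve/v1.0/config.json *)
  print_endline
    (resolve_url ~model_id:"bert-base-uncased" ~revision:(Rev "v1.0") "config.json")
```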
Loading Configuration
let config = Kaun_hf.load_config ~model_id:"bert-base-uncased" ()
(* config : Jsont.json *)
Returns the parsed config.json from the repository.
Loading Weights
let weights = Kaun_hf.load_weights ~model_id:"bert-base-uncased" ()
(* weights : (string * Kaun.Ptree.tensor) list *)
Returns a flat list of (name, tensor) pairs from the model's
SafeTensors checkpoint. Sharded checkpoints are handled transparently:
when model.safetensors.index.json is present, all shards are
downloaded and merged.
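In the Hub's sharded layout, model.safetensors.index.json contains a "weight_map" from each tensor name to the shard file that holds it. The sketch below shows the bookkeeping that step implies: list every shard once, then concatenate the per-shard results. It assumes the JSON has already been parsed into an association list, and `shards_of_weight_map` and `merge_shards` are hypothetical helpers rather than kaun-hf functions.

```ocaml
(* Given a parsed weight_map (tensor name -> shard file),
   list each shard exactly once, in first-seen order *)
let shards_of_weight_map (weight_map : (string * string) list) : string list =
  List.fold_left
    (fun acc (_, shard) -> if List.mem shard acc then acc else shard :: acc)
    [] weight_map
  |> List.rev

(* Merge per-shard (name, tensor) lists into one flat list *)
let merge_shards (per_shard : (string * 'a) list list) : (string * 'a) list =
  List.concat per_shard

let () =
  let weight_map =
    [ ("a.weight", "model-00001-of-00002.safetensors");
      ("b.weight", "model-00002-of-00002.safetensors");
      ("a.bias", "model-00001-of-00002.safetensors") ]
  in
  List.iter print_endline (shards_of_weight_map weight_map)
```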
Tensor names are the raw keys from the SafeTensors file (e.g.
bert.encoder.layer.0.attention.self.query.weight). Your model code
maps these to its own parameter structure.
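A mapping step like that often reduces to string rewriting on the flat keys. The sketch below strips the "bert." prefix, renames "encoder.layer." to "layers.", and drops keys it does not recognize, such as "cls.*" prediction heads. The target names are hypothetical: they assume a model whose own tree uses layers.N.… paths, so adapt the rules to your model's actual structure.

```ocaml
(* Return the remainder of [s] after [prefix], if [s] starts with it *)
let strip_prefix ~prefix s =
  let n = String.length prefix in
  if String.length s >= n && String.sub s 0 n = prefix then
    Some (String.sub s n (String.length s - n))
  else None

(* Hypothetical renaming rules from Hugging Face BERT keys to model paths *)
let rename (hf_key : string) : string option =
  match strip_prefix ~prefix:"bert." hf_key with
  | None -> None (* keys without the "bert." prefix, e.g. "cls.*" heads, are dropped *)
  | Some rest -> (
      match strip_prefix ~prefix:"encoder.layer." rest with
      | Some tail -> Some ("layers." ^ tail)
      | None -> Some rest)

(* Apply the rules to a flat (name, tensor) list, dropping unmapped keys *)
let remap (weights : (string * 'a) list) : (string * 'a) list =
  List.filter_map (fun (k, v) -> Option.map (fun k' -> (k', v)) (rename k)) weights
```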
Loading a Pretrained Model
The typical pattern for loading pretrained weights:
- Build the model architecture from the config.
- Initialize to get the parameter tree structure.
- Load weights and map them to the tree.
(* 1. Build model from config (build_bert_model is your own code, not part of Kaun) *)
let config = Kaun_hf.load_config ~model_id:"bert-base-uncased" () in
let model = build_bert_model config in
(* 2. Initialize to get tree structure *)
let vars = Layer.init model ~dtype:Nx.Float32 in
(* 3. Load and map weights (map_weights_to_ptree is your own name-mapping code) *)
let weights = Kaun_hf.load_weights ~model_id:"bert-base-uncased" () in
let params = map_weights_to_ptree weights (Layer.params vars) in
let vars = Layer.with_params vars params in
(* 4. Use for inference *)
let trainer = Train.make ~model
~optimizer:(Optim.adam ~lr:(Optim.Schedule.constant 1e-5) ())
in
let st = Train.make_state trainer vars in
let logits = Train.predict trainer st input_ids
Cache Management
(* Clear all cached files *)
Kaun_hf.clear_cache ()
(* Clear a specific model's cache *)
Kaun_hf.clear_cache ~model_id:"bert-base-uncased" ()
Next Steps
- Getting Started — XOR and MNIST examples
- Layers and Models — layer catalog, composition, custom layers
- Training — optimizers, losses, data pipelines, training loops