Kaun vs. PyTorch / Flax -- A Practical Comparison
This guide explains how Kaun relates to PyTorch and Flax, focusing on:
- How core concepts map (modules/layers, parameters, training loops)
- Where the APIs feel similar vs. deliberately different
- How to translate common patterns between frameworks
Kaun's design is closer to Flax than to PyTorch: layers are pure data, parameters are explicit trees, and forward passes are functions rather than method calls. If you know Flax, Kaun will feel familiar. If you know only PyTorch, the main shift is from mutable objects to immutable records.
1. Big-Picture Differences
| Aspect | PyTorch | Flax (Linen) | Kaun (OCaml) |
|---|---|---|---|
| Language | Python, dynamic | Python (JAX), dynamic | OCaml, statically typed |
| Model definition | nn.Module class with forward |
nn.Module class with __call__ |
Layer.t record with init and apply |
| Parameter storage | Mutable attributes on module | Frozen dict returned by init |
Ptree.t tree returned by Layer.init |
| Forward pass | model(x) (stateful method) |
model.apply(params, x) |
Layer.apply model vars ~training x |
| Mutation | Modules are mutable objects | Params are immutable dicts | Layer.vars and Ptree.t are immutable |
| Autograd | Dynamic tape (autograd) |
Functional transforms (jax.grad) |
Rune effect-based autodiff |
| Optimizer | torch.optim.Adam(model.parameters(), lr=...) |
optax.adam(lr) |
Optim.adam ~lr:(Schedule.constant lr) () |
| Training loop | Manual (or Lightning/etc.) | Manual (or Orbax/etc.) | Train.fit or manual Train.step |
| Data loading | DataLoader |
tf.data or manual |
Data.t lazy pipeline |
| Checkpointing | torch.save / torch.load (pickle) |
Orbax / msgpack | SafeTensors via Checkpoint.save / Checkpoint.load |
| RNG | Global torch.manual_seed |
Explicit PRNGKey threading | Implicit scope via Nx.Rng.run ~seed |
| Device management | model.to("cuda"), tensor.cuda() |
jax.device_put |
CPU by default; JIT manages devices internally |
2. Defining Models
PyTorch
import torch
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, in_features, hidden, out_features):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden)
self.fc2 = nn.Linear(hidden, out_features)
def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)
model = MLP(784, 128, 10)
Flax
import flax.linen as nn
import jax
class MLP(nn.Module):
hidden: int
out_features: int
@nn.compact
def __call__(self, x):
x = nn.relu(nn.Dense(self.hidden)(x))
return nn.Dense(self.out_features)(x)
model = MLP(hidden=128, out_features=10)
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 784]))
Kaun
open Kaun
let model =
Layer.sequential
[
Layer.linear ~in_features:784 ~out_features:128 ();
Layer.relu ();
Layer.linear ~in_features:128 ~out_features:10 ();
]
let vars =
Nx.Rng.run ~seed:0 @@ fun () ->
Layer.init model ~dtype:Nx.Float32
Key differences:
- PyTorch defines models as classes. Flax defines models as dataclasses
with
__call__. Kaun usesLayer.trecords -- plain data, not classes. Layer.sequentialreplaces class-based composition for homogeneous float pipelines.Layer.composehandles heterogeneous types (e.g. embedding into dense).- Activation functions are layers (
Layer.relu ()) rather than free functions called insideforward. This keeps the composition uniform.
3. Parameters
PyTorch
# Parameters live inside the module
for name, param in model.named_parameters():
print(name, param.shape)
# state_dict is an OrderedDict
sd = model.state_dict()
model.load_state_dict(sd)
Flax
# Params are a frozen dict returned by init
params = model.init(key, x)["params"]
jax.tree_util.tree_map(lambda p: p.shape, params)
Kaun
(* vars bundles params, state, and dtype *)
let params = Layer.params vars (* Ptree.t *)
let state = Layer.state vars (* Ptree.t *)
let dt = Layer.dtype vars (* (float, 'layout) Nx.dtype *)
(* Inspect parameter shapes *)
let paths = Ptree.flatten_with_paths params
(* [("0.weight", P tensor); ("0.bias", P tensor); ...] *)
(* Count total parameters *)
let n = Ptree.count_parameters params
(* Replace parameters *)
let vars' = Layer.with_params vars new_params
Key differences:
- PyTorch stores parameters as mutable module attributes. Flax returns
frozen dicts. Kaun returns
Ptree.t-- a tree withTensorleaves,Dictnodes, andListnodes. Ptree.tis plain immutable data. You can map, fold, flatten, and serialize it without going through the model.Layer.varsalso carries non-trainable state (e.g. batch norm running statistics), separate from trainable parameters.
4. Forward Pass
PyTorch
model.train()
output = model(x) # stateful: dropout active, batchnorm updates
model.eval()
with torch.no_grad():
output = model(x) # no dropout, batchnorm uses running stats
Flax
output = model.apply(params, x)
output = model.apply(params, x, train=True, rngs={"dropout": key})
Kaun
(* Training: dropout active, batchnorm updates running stats *)
let output, vars' = Layer.apply model vars ~training:true x
(* Evaluation: no dropout, batchnorm uses running stats *)
let output, vars' = Layer.apply model vars ~training:false x
(* Or through the trainer *)
let logits = Train.predict trainer st x
Key differences:
- PyTorch uses
model.train()/model.eval()to switch mode globally. Kaun passes~trainingas an argument on each call. Layer.applyreturns(output, updated_vars). The updated vars carry new state (e.g. batch norm statistics). Parameters are unchanged.Train.predictis a shortcut for evaluation mode with no state updates.
5. Optimizers and LR Schedules
PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10000)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
Flax / Optax
import optax
tx = optax.adam(learning_rate=optax.cosine_decay_schedule(1e-3, 10000))
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
Kaun
(* Schedule is a function: step -> lr *)
let schedule = Optim.Schedule.cosine_decay ~init_value:1e-3 ~decay_steps:10000 ()
(* Optimizer takes the schedule directly *)
let optimizer = Optim.adam ~lr:schedule ()
(* Manual update *)
let st = Optim.init optimizer params in
let updates, st' = Optim.step optimizer st params grads in
let params' = Optim.apply_updates params updates
(* Or use the convenience function *)
let params', st' = Optim.update optimizer st params grads
Available optimizers:
Optim.sgd ~lr:schedule ~momentum:0.9 ~nesterov:true ()
Optim.adam ~lr:schedule ~b1:0.9 ~b2:0.999 ~eps:1e-8 ()
Optim.adamw ~lr:schedule ~weight_decay:0.01 ()
Optim.rmsprop ~lr:schedule ~decay:0.9 ~momentum:0.0 ()
Optim.adagrad ~lr:schedule ~eps:1e-8 ()
Available schedules:
Optim.Schedule.constant 1e-3
Optim.Schedule.cosine_decay ~init_value:1e-3 ~decay_steps:10000 ()
Optim.Schedule.warmup_cosine ~init_value:0. ~peak_value:1e-3 ~warmup_steps:1000
Optim.Schedule.warmup_linear ~init_value:0. ~peak_value:1e-3 ~warmup_steps:1000
Optim.Schedule.exponential_decay ~init_value:1e-3 ~decay_rate:0.96 ~decay_steps:1000
Key differences:
- PyTorch couples the optimizer to the model via
model.parameters(). Kaun and Optax are decoupled -- they operate on parameter trees. - PyTorch separates scheduler from optimizer. Kaun (like Optax) bakes
the schedule into the optimizer via the
~lrargument. - A Kaun schedule is just
int -> float. Compose them by writing a plain OCaml function.
6. Loss Functions
PyTorch
loss = nn.functional.cross_entropy(logits, labels)
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
loss = nn.functional.mse_loss(pred, target)
Kaun
(* Multi-class with one-hot labels *)
Loss.cross_entropy logits one_hot_labels
(* Multi-class with integer labels *)
Loss.cross_entropy_sparse logits class_indices
(* Binary classification (raw logits, not sigmoid) *)
Loss.binary_cross_entropy logits labels
(* Regression *)
Loss.mse predictions targets
Loss.mae predictions targets
Key differences:
- PyTorch's
cross_entropyexpects integer labels (likecross_entropy_sparse). Kaun offers both one-hot and integer variants. - All Kaun losses return scalar means and are differentiable through Rune's autodiff.
- Kaun losses are plain functions, not module methods. There is no
nn.CrossEntropyLoss()class.
7. Training Loops
PyTorch (manual loop)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
for x_batch, y_batch in dataloader:
optimizer.zero_grad()
logits = model(x_batch)
loss = nn.functional.cross_entropy(logits, y_batch)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
Kaun with Train.fit
let trainer =
Train.make ~model
~optimizer:(Optim.adam ~lr:(Optim.Schedule.constant 1e-3) ())
let st = Nx.Rng.run ~seed:42 @@ fun () ->
Train.init trainer ~dtype:Nx.Float32
(* Train over a data pipeline *)
let st =
Train.fit trainer st
~report:(fun ~step ~loss _st ->
Printf.printf "step %d loss %.4f\n" step loss)
data
Train.fit takes a Data.t where each element is (input, loss_fn).
The loss function receives the model output and returns a scalar loss.
Gradient computation, optimizer step, and state threading are handled
internally.
Kaun with Train.step (manual loop)
For fine-grained control, use Train.step directly:
let st = ref (Nx.Rng.run ~seed:42 @@ fun () ->
Train.init trainer ~dtype:Nx.Float32)
let () =
Data.iter
(fun (x, y) ->
let loss, st' =
Train.step trainer !st ~training:true
~loss:(fun logits -> Loss.cross_entropy_sparse logits y)
x
in
st := st';
Printf.printf "loss: %.4f\n" (Nx.item [] loss))
data
Early stopping
Raise Train.Early_stop inside the ~report callback:
let st =
Train.fit trainer st
~report:(fun ~step:_ ~loss _st ->
if loss < 0.001 then raise Train.Early_stop)
data
Key differences:
- PyTorch training loops are fully manual: zero gradients, forward,
backward, step. Kaun's
Train.fithandles the entire loop. Train.stepis the escape hatch for custom loops, but you never callbackwardorzero_grad-- differentiation is implicit.- State threading replaces mutation.
Train.fitreturns the final state;Train.stepreturns(loss, new_state).
8. Data Loading
PyTorch
from torch.utils.data import DataLoader, TensorDataset
dataset = TensorDataset(x_train, y_train)
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
for x_batch, y_batch in loader:
...
Kaun
(* From tensor pairs -- the common case *)
let data =
Data.prepare ~shuffle:true ~batch_size:64 (x_train, y_train)
|> Data.map (fun (x, y) ->
(x, fun logits -> Loss.cross_entropy_sparse logits y))
(* From arrays *)
let data = Data.of_array examples |> Data.shuffle |> Data.batch 32
(* From a generator function *)
let data = Data.of_fn 10000 generate_example
(* Repeat a fixed example (useful for toy problems) *)
let data = Data.repeat 1000 (x, loss_fn)
(* Consumers *)
Data.iter process data
Data.fold accumulate init data
let arr = Data.to_array data
Key differences:
- PyTorch uses
Dataset+DataLoaderclasses with worker processes for parallel loading. Kaun usesData.t, a lazy composable iterator. Data.prepareis the standard shortcut: it slices tensors, optionally shuffles, and batches in one call.~drop_lastdefaults totrue.- Pipelines are single-pass. Call
Data.resetbefore iterating again (e.g. between epochs). Data.mapattaches the loss function to each batch, producing the(input, loss_fn)pairs thatTrain.fitexpects.
9. Checkpointing
PyTorch
# Save
torch.save(model.state_dict(), "model.pt")
# Load
model.load_state_dict(torch.load("model.pt"))
Kaun
(* Save parameters *)
let vars = Train.vars st in
Checkpoint.save "model.safetensors" (Layer.params vars)
(* Load parameters *)
let vars = Layer.init model ~dtype:Nx.Float32 in
let params = Checkpoint.load "model.safetensors" ~like:(Layer.params vars) in
let vars = Layer.with_params vars params
(* Save both params and state (e.g. batch norm stats) *)
Checkpoint.save "params.safetensors" (Layer.params vars);
Checkpoint.save "state.safetensors" (Layer.state vars)
(* Resume training from loaded weights *)
let st = Train.make_state trainer vars
Key differences:
- PyTorch uses Python pickle by default (arbitrary code execution risk). Kaun uses SafeTensors -- a flat, memory-mappable format with no code execution.
Checkpoint.loadrequires a~liketemplate defining the expected tree structure and dtypes. Extra keys in the file are ignored, and tensors are cast to the template's dtype if needed.- Pretrained weights from HuggingFace Hub are available via
Kaun_hf.load_weights.
10. Quick Cheat Sheet
| Task | PyTorch | Kaun |
|---|---|---|
| Define a model | class M(nn.Module): ... |
Layer.sequential [Layer.linear ...; Layer.relu (); ...] |
| Initialize parameters | model = M() (implicit) |
Layer.init model ~dtype:Nx.Float32 |
| Forward pass (training) | model.train(); y = model(x) |
Layer.apply model vars ~training:true x |
| Forward pass (eval) | model.eval(); y = model(x) |
Train.predict trainer st x |
| Count parameters | sum(p.numel() for p in model.parameters()) |
Ptree.count_parameters (Layer.params vars) |
| Create optimizer | Adam(model.parameters(), lr=1e-3) |
Optim.adam ~lr:(Optim.Schedule.constant 1e-3) () |
| Cosine decay schedule | CosineAnnealingLR(opt, T_max=N) |
Optim.Schedule.cosine_decay ~init_value:lr ~decay_steps:N () |
| Compute loss | F.cross_entropy(logits, labels) |
Loss.cross_entropy_sparse logits labels |
| Training step | zero_grad(); loss.backward(); opt.step() |
Train.step trainer st ~training:true ~loss x |
| Full training loop | Manual for loop |
Train.fit trainer st data |
| Early stopping | Manual condition check | raise Train.Early_stop inside ~report |
| Gradient clipping | clip_grad_norm_(model.parameters(), max_norm) |
Optim.clip_by_global_norm max_norm grads |
| Data loading | DataLoader(dataset, batch_size=64, shuffle=True) |
Data.prepare ~shuffle:true ~batch_size:64 (x, y) |
| Save checkpoint | torch.save(model.state_dict(), path) |
Checkpoint.save path (Layer.params vars) |
| Load checkpoint | model.load_state_dict(torch.load(path)) |
Checkpoint.load path ~like:(Layer.params vars) |
| Compose heterogeneous layers | Define inside forward |
Layer.compose embedding_layer dense_layer |
| Dropout | nn.Dropout(p=0.1) |
Layer.dropout ~rate:0.1 () |
| Batch normalization | nn.BatchNorm2d(32) |
Layer.batch_norm ~num_features:32 () |
| Layer normalization | nn.LayerNorm(128) |
Layer.layer_norm ~dim:128 () |
| Set RNG seed | torch.manual_seed(42) |
Nx.Rng.run ~seed:42 @@ fun () -> ... |