kaun ᚲ
PyTorch's ease. Flax's modularity. OCaml's type safety.
why kaun?
functional models
Models are immutable records. Parameters are data, not hidden state. Everything composes.
type-safe training
Catch shape mismatches at compile time. Never debug another runtime dimension error.
built on rune
Automatic differentiation built in. Your loss function is just a function.
pure optimizers
Optimizers are functions, not stateful objects. Perfect for distributed training.
show me the code
PyTorch
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = MLP()
optimizer = torch.optim.Adam(model.parameters())
KAUN
open Kaun

(* Model is a record *)
type model = {
  fc1 : Linear.t;
  fc2 : Linear.t;
}

(* Forward is a function *)
let forward model x =
  x
  |> Linear.forward model.fc1
  |> Activation.relu
  |> Linear.forward model.fc2

(* Initialize *)
let rng = Rng.make 42

let model = {
  fc1 = Linear.create rng ~input_dim:784 ~output_dim:128;
  fc2 = Linear.create rng ~input_dim:128 ~output_dim:10;
}
training loop
(* Loss function *)
let loss_fn model x y =
  let logits = forward model x in
  Loss.sigmoid_binary_cross_entropy ~targets:y logits

(* Get gradients using Rune *)
let loss, grads = value_and_grad loss_fn model x y

(* Update with optimizer *)
let optimizer = Optimizer.adam ~lr:0.001 ()
let model', opt_state' = Optimizer.update optimizer opt_state model grads

(* Pure functional - old model unchanged *)
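The snippet above is a single update step. Because everything is pure, a full training loop is just a fold over the data. A minimal sketch, reusing loss_fn, optimizer and Optimizer.update from the example, and assuming batches is a list of (x, y) pairs:

(* Sketch: fold one update step over a list of (x, y) batches.          *)
(* Reuses [loss_fn], [optimizer] and [Optimizer.update] from above;     *)
(* [batches] and the initial [opt_state] are assumed inputs.            *)
let train model opt_state batches =
  List.fold_left
    (fun (model, opt_state) (x, y) ->
      let _loss, grads = value_and_grad loss_fn model x y in
      Optimizer.update optimizer opt_state model grads)
    (model, opt_state) batches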
what's implemented
Kaun is in early development. Here's what works today:
layers
- ✓ Linear (dense/fully-connected)
- ✓ Parameter trees for composition
- ⏳ Conv2d, BatchNorm (coming for alpha)
- ⏳ Dropout, LayerNorm (coming for alpha)
training
- ✓ SGD and Adam optimizers
- ✓ Binary cross-entropy loss
- ✓ Activation functions (relu, sigmoid, tanh)
- ⏳ More losses and metrics (coming for alpha)
design principles
Models are data. No classes, no inheritance. A model is just a record containing parameters. This makes serialization, inspection, and manipulation trivial.
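For instance, swapping out a layer is a plain functional record update. A minimal sketch, reusing the model type and Linear.create from the example above:

(* Sketch: re-initialize the output layer of a [model] record.        *)
(* The original record is untouched; only the returned value differs. *)
let reset_head rng m =
  { m with fc2 = Linear.create rng ~input_dim:128 ~output_dim:10 }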
Training is functional. Optimizers don't mutate state - they return new parameters. This enables techniques like checkpointing and distributed training without special frameworks.
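Checkpointing falls out of this directly: the pre-update model is still a valid value after the update, so keeping snapshots is just a matter of holding on to them. A minimal sketch, reusing the training-step pieces above:

(* Sketch: one update step that also stashes the pre-update model.             *)
(* Reuses [loss_fn], [optimizer] and [Optimizer.update] from the training loop. *)
let step_with_checkpoint (model, opt_state, checkpoints) (x, y) =
  let _loss, grads = value_and_grad loss_fn model x y in
  let model', opt_state' = Optimizer.update optimizer opt_state model grads in
  (model', opt_state', model :: checkpoints)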
Leverage Rune. We don't reimplement autodiff or device management. Kaun is a thin layer of neural network abstractions over Rune's primitives.
get started
Kaun isn't released yet. For now, check out the documentation to learn more.
When it's ready:
# Install
opam install kaun

# Try it
open Kaun

(* XOR problem *)
let x = Tensor.of_float_list [| 4; 2 |] [ 0.; 0.; 0.; 1.; 1.; 0.; 1.; 1. ]
let y = Tensor.of_float_list [| 4; 1 |] [ 0.; 1.; 1.; 0. ]

(* Train a model *)
let model = train_xor x y ~epochs:1000