kaun

PyTorch's ease. Flax's modularity. OCaml's type safety.


why kaun?

functional models

Models are immutable records. Parameters are data, not hidden state. Everything composes.

type-safe training

Catch shape mismatches at compile time. Never debug another runtime dimension error.

built on rune

Automatic differentiation built in. Your loss function is just a function.

pure optimizers

Optimizers are functions, not stateful objects. Perfect for distributed training.


show me the code

PyTorch

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = MLP()
optimizer = torch.optim.Adam(model.parameters())

kaun

open Kaun

(* Model is a record *)
type model = {
  fc1: Linear.t;
  fc2: Linear.t;
}

(* Forward is a function *)
let forward model x =
  x
  |> Linear.forward model.fc1
  |> Activation.relu
  |> Linear.forward model.fc2

(* Initialize *)
let rng = Rng.make 42
let model = {
  fc1 = Linear.create rng ~input_dim:784 ~output_dim:128;
  fc2 = Linear.create rng ~input_dim:128 ~output_dim:10;
}

training loop

(* Loss function *)
let loss_fn model x y =
  let logits = forward model x in
  Loss.sigmoid_binary_cross_entropy ~targets:y logits

(* Get gradients using Rune *)
let loss, grads = value_and_grad loss_fn model x y

(* Update with optimizer *)
let optimizer = Optimizer.adam ~lr:0.001 ()
(* opt_state is the optimizer's running state (e.g. Adam's moment estimates),
   threaded through each update rather than hidden inside an object *)
let model', opt_state' = Optimizer.update optimizer opt_state model grads

(* Pure functional - old model unchanged *)
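Putting these pieces together, a full epoch is just a pure step folded over the data. The sketch below is illustrative rather than the final API: it reuses loss_fn and value_and_grad from above, and assumes an initial opt_state from the optimizer and a batches list of (x, y) pairs.

(* One pure step: gradients in, new model and optimizer state out *)
let train_step optimizer (model, opt_state) (x, y) =
  let _loss, grads = value_and_grad loss_fn model x y in
  Optimizer.update optimizer opt_state model grads

(* An epoch is a fold over the batches - no mutation, no hidden state *)
let train_epoch optimizer model opt_state batches =
  List.fold_left (train_step optimizer) (model, opt_state) batches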

what's implemented

Kaun is in early development. Here's what works today:

layers

  • ✓ Linear (dense/fully-connected)
  • ✓ Parameter trees for composition
  • ⏳ Conv2d, BatchNorm (coming for alpha)
  • ⏳ Dropout, LayerNorm (coming for alpha)

training

  • ✓ SGD and Adam optimizers
  • ✓ Binary cross-entropy loss
  • ✓ Activation functions (relu, sigmoid, tanh)
  • ⏳ More losses and metrics (coming for alpha)

design principles

Models are data. No classes, no inheritance. A model is just a record containing parameters. This makes serialization, inspection, and manipulation trivial.
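Because a model is an ordinary record, adapting it is plain OCaml. For example, a functional record update swaps in a new output head without touching the original (a sketch reusing Linear.create from above; rng is assumed to still be in scope):

(* Replace the 10-way head with a 2-way head; [model] itself is unchanged *)
let two_class_model =
  { model with fc2 = Linear.create rng ~input_dim:128 ~output_dim:2 }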

Training is functional. Optimizers don't mutate state; they return new parameters and new optimizer state. This enables techniques like checkpointing and distributed training without special frameworks.
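Concretely, a checkpoint is just a value you keep around, and "restoring" it is just using it again. A small sketch on top of the train_epoch fold above, where evaluate is a hypothetical validation function returning a loss:

(* A checkpoint is nothing more than the old value of the parameters *)
let checkpoint = model
let model', opt_state' = train_epoch optimizer model opt_state batches
(* Keep whichever parameters validate better; nothing to snapshot or restore *)
let best = if evaluate model' <= evaluate checkpoint then model' else checkpoint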

Leverage Rune. We don't reimplement autodiff or device management. Kaun is a thin layer of neural network abstractions over Rune's primitives.


get started

Kaun isn't released yet. For now, check out the documentation to learn more.

When it's ready:

# Install
opam install kaun

# Try it
open Kaun

(* XOR problem *)
let x = Tensor.of_float_list [|4; 2|] [0.; 0.; 0.; 1.; 1.; 0.; 1.; 1.]
let y = Tensor.of_float_list [|4; 1|] [0.; 1.; 1.; 0.]

(* Train a model *)
let model = train_xor x y ~epochs:1000
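Note that train_xor is not a library function; it stands in for a few lines you would write yourself. Here is a rough sketch built from the pieces on this page, with illustrative layer sizes and learning rate, and an assumed Optimizer.init for creating the initial optimizer state:

(* Hypothetical sketch of train_xor: a small 2-8-1 MLP trained with the
   pure step shown in the training loop section. Sizes, learning rate, and
   Optimizer.init are assumptions, not the final API. *)
let train_xor x y ~epochs =
  let rng = Rng.make 0 in
  let model = {
    fc1 = Linear.create rng ~input_dim:2 ~output_dim:8;
    fc2 = Linear.create rng ~input_dim:8 ~output_dim:1;
  } in
  let optimizer = Optimizer.adam ~lr:0.01 () in
  let opt_state = Optimizer.init optimizer model in
  let step (model, opt_state) =
    let _loss, grads = value_and_grad loss_fn model x y in
    Optimizer.update optimizer opt_state model grads
  in
  let rec loop n state = if n = 0 then state else loop (n - 1) (step state) in
  fst (loop epochs (model, opt_state))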