rune

JAX's transformations. OCaml's guarantees. Differentiable everything.

ᚠ ᚢ ᚦ ᚨ ᚱ ᚲ

why rune?

composable transformations

grad, jit, and vmap take functions and return functions, so they compose like any other functions. Build complex ML systems from simple parts.
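To make "compose like functions" concrete without relying on Rune's API, here is a plain-OCaml sketch in which a transformation is simply a value of type `(float -> float) -> (float -> float)`. The `transform` type, the `deriv` function and the `>>` operator are hypothetical names for this illustration, and `deriv` is a central finite difference, not exact autodiff like Rune's `grad`. The point is only the shape: applying the same transformation twice yields a second derivative with no special machinery.

```ocaml
(* A "transformation" maps a function to a function. *)
type transform = (float -> float) -> (float -> float)

(* Hypothetical stand-in for grad: a central finite difference.
   Rune's grad is exact autodiff; this only approximates it. *)
let deriv : transform =
 fun f x ->
  let h = 1e-4 in
  (f (x +. h) -. f (x -. h)) /. (2.0 *. h)

(* Compose two transformations: apply t1 first, then t2. *)
let ( >> ) (t1 : transform) (t2 : transform) : transform = fun f -> t2 (t1 f)

let cube x = x *. x *. x

(* d/dx x^3 = 3x^2, which is 27 at x = 3;
   d2/dx2 x^3 = 6x, which is 18 at x = 3. *)
let d_cube = deriv cube
let d2_cube = (deriv >> deriv) cube

let () =
  Printf.printf "first derivative at 3  ~ %.4f\n" (d_cube 3.0);
  Printf.printf "second derivative at 3 ~ %.4f\n" (d2_cube 3.0)
```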

effect-based autodiff

OCaml 5's effect handlers make automatic differentiation elegant: no tape, no graph, just functions.

multi-device support

Same code runs on CPU or Metal. Device placement is explicit and type-safe.

type-safe gradients

Shape errors at compile time. Device mismatches impossible. Your gradients always match your parameters.


show me the code

JAX

import jax
import jax.numpy as jnp
from jax import grad, jit

# Define function
def f(x):
    return jnp.sum(x ** 2)

# Transform it
grad_f = grad(f)
fast_grad_f = jit(grad_f)

RUNE

open Rune

(* Define function *)
let f x = 
  sum (mul x x)

(* Transform it *)
let grad_f = grad f
let fast_grad_f = jit grad_f

automatic differentiation

Rune uses OCaml's effect system to implement autodiff. Write normal functions, get derivatives for free:

(* Any function works *)
let my_function x =
  let y = sin x in
  let z = mul x y in
  sum z

(* Get gradient function *)
let df_dx = grad my_function

(* Compute value and gradient together at a sample input *)
let x = rand cpu Float32 [|10|]
let value, gradient = value_and_grad my_function x
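For intuition about how effects can replace an explicit tape, here is a minimal scalar reverse-mode sketch in plain OCaml 5, in the style of the well-known effect-handler AD examples. It is not Rune's implementation: the `num` record, the `Add`/`Mult` effects, the `+!`/`*!` operators and this `grad` helper are all hypothetical. Each arithmetic operation is performed as an effect; the handler computes the forward value, resumes the rest of the program with it, and once that continuation returns it pushes the adjoint back to the operands. The stack of suspended handlers does the bookkeeping a tape would otherwise do.

```ocaml
(* Requires OCaml >= 5.0. A toy scalar reverse-mode AD, not Rune's code. *)
open Effect
open Effect.Deep

(* A value together with its adjoint (accumulated gradient). *)
type num = { v : float; mutable d : float }

(* Arithmetic is performed as effects so a handler can intercept it. *)
type _ Effect.t += Add : num * num -> num Effect.t | Mult : num * num -> num Effect.t

let mk v = { v; d = 0.0 }
let ( +! ) a b = perform (Add (a, b))
let ( *! ) a b = perform (Mult (a, b))

(* Forward: compute the result and resume the rest of the program.
   Backward: once the continuation returns, push the adjoint to the inputs. *)
let run (f : unit -> num) : unit =
  match_with f ()
    {
      retc = (fun r -> r.d <- 1.0) (* seed: d out / d out = 1 *);
      exnc = raise;
      effc =
        (fun (type a) (eff : a Effect.t) ->
          match eff with
          | Add (a, b) ->
              Some
                (fun (k : (a, _) continuation) ->
                  let x = mk (a.v +. b.v) in
                  continue k x;
                  a.d <- a.d +. x.d;
                  b.d <- b.d +. x.d)
          | Mult (a, b) ->
              Some
                (fun (k : (a, _) continuation) ->
                  let x = mk (a.v *. b.v) in
                  continue k x;
                  a.d <- a.d +. (b.v *. x.d);
                  b.d <- b.d +. (a.v *. x.d))
          | _ -> None);
    }

(* Hypothetical helper: derivative of a scalar function at x. *)
let grad (f : num -> num) (x : float) : float =
  let x = mk x in
  run (fun () -> f x);
  x.d

let () =
  (* f x = x*x + x, so the derivative is 2x + 1; prints f'(3) = 7.0 *)
  Printf.printf "f'(3) = %.1f\n" (grad (fun x -> (x *! x) +! x) 3.0)
```

Note that the user-facing function `fun x -> (x *! x) +! x` never mentions gradients; all of the differentiation logic lives in the handler, which is the property the section above describes.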

device placement

(* CPU computation *)
let x = rand cpu Float32 [|100|]

(* Metal GPU (macOS) *)
let gpu = metal ()
let y = rand gpu Float32 [|100|]

(* Operations run on tensor's device *)
let z = add y y  (* runs on GPU *)
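One common way to make device mismatches a compile-time error is a phantom type parameter. The sketch below is plain OCaml and assumes nothing about how Rune actually encodes devices: `cpu` and `metal` here are hypothetical empty *types* (not the `cpu` value or `metal ()` function used above), and `tensor`, `on_cpu`, `on_metal` and `add` are toy definitions for illustration only.

```ocaml
(* Hypothetical device tags: empty types used only at the type level. *)
type cpu
type metal

(* A toy tensor: the device exists only in the phantom parameter 'dev. *)
type 'dev tensor = { data : float array }

(* Constructors that fix the device tag. A real library would hide the
   record behind a module signature so these are the only way in. *)
let on_cpu (data : float array) : cpu tensor = { data }
let on_metal (data : float array) : metal tensor = { data }

(* Both arguments and the result share one 'dev, so mixing devices
   fails to typecheck. *)
let add (a : 'dev tensor) (b : 'dev tensor) : 'dev tensor =
  { data = Array.map2 ( +. ) a.data b.data }

let x = on_cpu [| 1.0; 2.0 |]
let y = on_cpu [| 3.0; 4.0 |]
let z = add x y (* z : cpu tensor *)

(* Uncommenting this is a compile-time error, because
   metal tensor is not compatible with cpu tensor:
   let bad = add x (on_metal [| 1.0; 2.0 |]) *)
```

The device costs nothing at runtime here; the check happens entirely in the type checker, which is what lets a mismatch be rejected before the program runs.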

neural network example

(* Simple two-layer network *)
let mlp w1 b1 w2 b2 x =
  let h = add (matmul x w1) b1 in
  let h = maximum h (zeros_like h) in  (* ReLU *)
  add (matmul h w2) b2

(* Loss function: mean squared error *)
let loss params x y =
  match params with
  | [ w1; b1; w2; b2 ] ->
      let pred = mlp w1 b1 w2 b2 x in
      let err = sub pred y in
      mean (mul err err)
  | _ -> invalid_arg "loss: expected [w1; b1; w2; b2]"

(* Training step: close over x and y so grads differentiates the loss
   with respect to the parameter list only, then move each parameter
   against its gradient *)
let update params x y lr =
  let gs = grads (fun ps -> loss ps x y) params in
  List.map2
    (fun p g -> sub p (mul (scalar cpu Float32 lr) g))
    params gs
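The update rule itself is ordinary gradient descent, p ← p − lr · g, mapped over the parameter list. The plain-OCaml sketch below applies the same `List.map2` step to a hypothetical one-feature linear model on made-up data generated from y = 2x + 1. Here the gradient is derived by hand (`grad_by_hand` is an illustrative name, not a Rune function); that derivation is exactly the part an autodiff call would do for you.

```ocaml
(* Hypothetical toy data generated from y = 2x + 1. *)
let xs = [| 1.0; 2.0; 3.0; 4.0 |]
let ys = Array.map (fun x -> (2.0 *. x) +. 1.0) xs
let n = float_of_int (Array.length xs)

(* Residual of the linear model y = w*x + b on sample i. *)
let residual w b i = (w *. xs.(i)) +. b -. ys.(i)

(* Mean squared error; params = [w; b]. *)
let loss = function
  | [ w; b ] ->
      let s = ref 0.0 in
      Array.iteri (fun i _ -> let e = residual w b i in s := !s +. (e *. e)) xs;
      !s /. n
  | _ -> invalid_arg "loss: expected [w; b]"

(* Gradient written out by hand: dL/dw = mean (2 e x), dL/db = mean (2 e).
   This is the part an autodiff call would compute for you. *)
let grad_by_hand = function
  | [ w; b ] ->
      let gw = ref 0.0 and gb = ref 0.0 in
      Array.iteri
        (fun i x ->
          let e = residual w b i in
          gw := !gw +. (2.0 *. e *. x);
          gb := !gb +. (2.0 *. e))
        xs;
      [ !gw /. n; !gb /. n ]
  | _ -> invalid_arg "grad_by_hand: expected [w; b]"

(* One gradient-descent step, p - lr * g, mapped over the parameter list. *)
let update params lr =
  List.map2 (fun p g -> p -. (lr *. g)) params (grad_by_hand params)

(* Start from zeros and take 1000 steps with learning rate 0.05. *)
let trained =
  let params = ref [ 0.0; 0.0 ] in
  for _step = 1 to 1000 do
    params := update !params 0.05
  done;
  !params

let () =
  match trained with
  | [ w; b ] -> Printf.printf "w = %.3f, b = %.3f, loss = %.2e\n" w b (loss trained)
  | _ -> ()
```

The learning rate matters: for this data the loss curvature is large enough that steps much above 0.1 diverge, while 0.05 converges to w ≈ 2, b ≈ 1.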

what's coming

Rune works today for automatic differentiation. Post-v1, we're adding:


ecosystem

kaun - neural networks

High-level neural network library built on Rune. Layers, optimizers, and training loops that just work.


sowilo - computer vision

Differentiable image processing. Every operation supports autodiff.



get started

Rune isn't released yet. For now, check out the documentation to learn more.

When it's ready:

# Install
opam install rune

# Try it
open Rune

let () = 
  (* Define a function *)
  let f x = sum (mul x x) in
  
  (* Get its gradient *)
  let grad_f = grad f in
  
  (* Test it *)
  let x = scalar cpu Float32 3.0 in
  let g = grad_f x in
  Printf.printf "f(3) = %.1f\n" (unsafe_get (f x) [||]);
  Printf.printf "f'(3) = %.1f\n" (unsafe_get g [||])
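Assuming `grad` behaves as described, the program above should report f(3) = 9.0 and f'(3) = 6.0, since f(x) = x² and f'(x) = 2x. To check those numbers without Rune installed, here is a forward-mode sketch using dual numbers in plain OCaml. Forward mode is chosen only because it is the shortest exact AD to write down; it says nothing about how Rune computes gradients, and the `dual` type, `const`, `mul` and `grad` below are hypothetical.

```ocaml
(* A dual number re + du * eps with eps * eps = 0: du carries the derivative. *)
type dual = { re : float; du : float }

let const c = { re = c; du = 0.0 }

(* Product rule: (a + a' eps)(b + b' eps) = ab + (a' b + a b') eps *)
let mul a b = { re = a.re *. b.re; du = (a.du *. b.re) +. (a.re *. b.du) }

(* Seed the input with derivative 1 and read the derivative off the output. *)
let grad (f : dual -> dual) (x : float) = (f { re = x; du = 1.0 }).du

let f x = mul x x

let () =
  Printf.printf "f(3) = %.1f\n" (f (const 3.0)).re;
  Printf.printf "f'(3) = %.1f\n" (grad f 3.0)
```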