Getting Started

This guide covers installation, array creation, data types, array properties, element-wise operations, reductions, slicing and indexing, broadcasting, and matrix multiplication.

Installation

opam install nx

Or build from source:

git clone https://github.com/raven-ml/raven
cd raven && dune build packages/nx

Add to your dune file:

; Build main.ml as an executable and link it against the nx library.
(executable
 (name main)
 (libraries nx))

Creating Arrays

open Nx

let () =
  (* Explicit values: dtype, then shape, then the elements in row-major
     order, giving the 2x3 matrix with rows [1 2 3] and [4 5 6]. *)
  let from_values = create Float32 [|2; 3|] [|1.; 2.; 3.; 4.; 5.; 6.|] in
  print_data from_values;

  (* Constant-filled arrays *)
  let all_zeros = zeros Float32 [|3; 3|] in
  let all_ones = ones Int32 [|5|] in
  let constant = full Float64 [|2; 2|] 3.14 in
  ignore (all_zeros, all_ones, constant);

  (* Ranges and sequences *)
  let counting = arange Int32 0 10 1 in   (* start 0, stop 10 (excluded), step 1 *)
  let points = linspace Float32 0. 1. 5 in   (* 5 points spanning [0, 1] *)
  ignore (counting, points);

  (* Random arrays *)
  let uniform = rand Float32 [|3; 4|] in
  let normal = randn Float32 [|3; 4|] in
  ignore (uniform, normal);

  (* Special matrices *)
  let identity = eye Float32 3 in   (* 3×3 identity *)
  print_data identity

Data Types

Every array has a dtype that determines its element type. Common dtypes:

| Dtype | OCaml type | Typical use |
|---|---|---|
| Float32 | float | Neural networks, images |
| Float64 | float | Scientific computing |
| Int32 | int32 | Integer data, indices |
| Int64 | int64 | Large integers |
| Bool | bool | Masks, conditions |
| Complex128 | Complex.t | Signal processing |

Nx does not automatically cast between types. Convert explicitly with astype:

open Nx

let () =
  (* int32 literals carry the [l] suffix; astype produces a new float32 array *)
  create Int32 [|3|] [|1l; 2l; 3l|]
  |> astype Float32
  |> print_data   (* [1. 2. 3.] as float32 *)

Array Properties

open Nx

let () =
  let arr = rand Float32 [|2; 3; 4|] in
  (* Render the int array returned by [shape] as "2; 3; 4" *)
  let dims =
    shape arr |> Array.map string_of_int |> Array.to_list |> String.concat "; "
  in
  Printf.printf "shape: [|%s|]\n" dims;
  Printf.printf "ndim: %d\n" (ndim arr);   (* 3 *)
  Printf.printf "size: %d\n" (size arr);   (* 24 = 2 * 3 * 4 *)
  Printf.printf "dtype: %s\n" (Dtype.to_string (dtype arr))

Element-wise Operations

Binary operations work element-wise and support broadcasting:

open Nx

let () =
  let a = create Float32 [|3|] [|1.; 2.; 3.|] in
  let b = create Float32 [|3|] [|4.; 5.; 6.|] in

  (* Print every result so the values noted in the comments can be checked.
     Passing each result to [print_data] also makes the type checker reject
     a mistaken partial application (e.g. [add a]). *)
  print_data (add a b);   (* [5. 7. 9.] *)
  print_data (mul a b);   (* [4. 10. 18.] *)
  print_data (sub a b);   (* [-3. -3. -3.] *)
  print_data (div a b);   (* [0.25 0.4 0.5] *)

  (* Scalar operations: [scalar] wraps 10. in a Float32 array that broadcasts
     against [a] *)
  print_data (add a (scalar Float32 10.));   (* [11. 12. 13.] *)

  (* Math functions (Nx versions, brought into scope by [open Nx]) *)
  print_data (sin a);
  print_data (exp a);
  print_data (sqrt (abs a))

Reductions

open Nx

let () =
  (* Row-major 2x3 matrix: rows [1 2 3] and [4 5 6] *)
  let m = create Float32 [|2; 3|] [|1.; 2.; 3.; 4.; 5.; 6.|] in

  (* Without ~axes every element is reduced; [item []] reads the single
     resulting value back as an OCaml float. *)
  Printf.printf "sum = %.1f\n" (sum m |> item []);     (* 21.0 *)
  Printf.printf "mean = %.1f\n" (mean m |> item []);   (* 3.5 *)

  (* ~axes names the axes that are collapsed *)
  sum ~axes:[0] m |> print_data;   (* collapse axis 0 -> column sums [5. 7. 9.] *)
  sum ~axes:[1] m |> print_data    (* collapse axis 1 -> row sums [6. 15.] *)

Slicing and Indexing

Basic indexing

open Nx

let () =
  let grid =
    create Int32 [|3; 3|] [|1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l|]
  in

  (* A partial index with [get] yields a sub-array: row 1 holds 4, 5, 6 *)
  print_data (get [1] grid);

  (* A full index with [item] yields the element itself, an int32 (6l) *)
  let at_1_2 = item [1; 2] grid in
  Printf.printf "x[1,2] = %ld\n" at_1_2

Advanced slicing

open Nx

let () =
  (* 4x4 matrix holding 1..16 in row-major order *)
  let m =
    create Int32 [|4; 4|]
      [| 1l;  2l;  3l;  4l;  5l;  6l;  7l;  8l;
         9l; 10l; 11l; 12l; 13l; 14l; 15l; 16l |]
  in

  (* R (0, 2) is half-open, so rows 0 and 1 are kept; A keeps every column *)
  let first_two_rows = slice [R (0, 2); A] m in
  print_data first_two_rows;

  (* I 1 selects row 1; R (0, 3) keeps columns 0, 1, 2 (values 5, 6, 7) *)
  let row1_prefix = slice [I 1; R (0, 3)] m in
  print_data row1_prefix;

  (* L gathers the listed indices along its axis *)
  let gathered = slice [L [0; 3]; L [1; 2]] m in
  print_data gathered

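Index constructors accepted by slice:

- I i: a single index
- R (start, stop): half-open range, from start up to but not including stop
- Rs (start, stop, step): strided range
- L indices: gather the listed indices
- A: the whole axis
- N: insert a new axis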

Broadcasting

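Operations automatically broadcast arrays with compatible shapes. Shapes are aligned from the rightmost dimension, and each aligned pair must be equal or one of them must be 1. An array with fewer dimensions is treated as if its shape were padded with leading 1s. For example, a [|1; 4|] row broadcasts against a [|3; 4|] matrix: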

open Nx

let () =
  let matrix = ones Float32 [|3; 4|] in
  (* Shape [|1; 4|]: its size-1 leading axis is stretched across all 3 rows *)
  let offsets = create Float32 [|1; 4|] [|10.; 20.; 30.; 40.|] in
  (* Every row of the sum is [11. 21. 31. 41.] *)
  print_data (add matrix offsets)

Matrix Multiplication

open Nx

let () =
  (* (3x4) @ (4x2): the inner dimensions must agree *)
  let lhs = rand Float32 [|3; 4|] in
  let rhs = rand Float32 [|4; 2|] in
  let product = matmul lhs rhs in
  let rows t = dim 0 t and cols t = dim 1 t in
  Printf.printf "(%d×%d) × (%d×%d) = (%d×%d)\n"
    (rows lhs) (cols lhs) (rows rhs) (cols rhs) (rows product) (cols product)

Next Steps