Array Operations
This guide covers reshaping, broadcasting, joining, slicing, and the view model that underlies Nx's efficiency.
Views and Copies
Many Nx operations return views — tensors that share the underlying buffer with the original but have different shape, strides, or offset. Views are O(1) and allocate no new data.
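View-producing operations: reshape (only when the existing strides can express the new shape), transpose, slice (with ranges and single indices), squeeze, unsqueeze, flip, get, moveaxis, swapaxes, broadcast_to.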
Copy-producing operations: contiguous, copy, concatenate, stack, pad, element-wise operations.
Use is_c_contiguous to check whether elements are laid out contiguously in row-major order, and contiguous to force a copy when needed:
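let () =
  let x = Nx.zeros Nx.Float32 [|2; 3|] in
  let t = Nx.transpose x in
  Printf.printf "%b\n" (Nx.is_c_contiguous t); (* typically false for a transposed view *)
  let t' = Nx.contiguous t in (* force a contiguous copy *)
  Printf.printf "%b\n" (Nx.is_c_contiguous t') (* true after the copy *)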
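The view model can be sketched in plain OCaml without calling Nx at all: a view is a shared buffer plus a shape, per-axis strides, and an offset, and an element lives at offset plus the sum of index × stride. The record type and the helpers below are hypothetical and illustrate the general strided-array technique, not Nx's internal representation.

```ocaml
(* Hypothetical model of a strided view; not Nx's actual internals. *)
type view = {
  data : float array;   (* shared underlying buffer *)
  shape : int array;
  strides : int array;  (* buffer step per unit increase of each index *)
  offset : int;         (* buffer position of the first element *)
}

(* Row-major (C) strides for a shape: the last axis has stride 1. *)
let c_strides shape =
  let n = Array.length shape in
  let s = Array.make n 1 in
  for i = n - 2 downto 0 do
    s.(i) <- s.(i + 1) * shape.(i + 1)
  done;
  s

(* Reading an element: offset + sum of index * stride. *)
let get v idx =
  let pos = ref v.offset in
  Array.iteri (fun i k -> pos := !pos + (k * v.strides.(i))) idx;
  v.data.(!pos)

(* Transposing a 2-D view swaps shape and strides; the buffer is shared. *)
let transpose2 v =
  { v with shape = [| v.shape.(1); v.shape.(0) |];
           strides = [| v.strides.(1); v.strides.(0) |] }

(* In this simplified model, a view is C-contiguous when its strides
   equal the row-major strides of its shape. *)
let is_c_contiguous v = v.strides = c_strides v.shape

let () =
  let x = { data = [| 1.; 2.; 3.; 4.; 5.; 6. |]; shape = [| 2; 3 |];
            strides = c_strides [| 2; 3 |]; offset = 0 } in
  let t = transpose2 x in
  Printf.printf "t[2][0] = %g, contiguous = %b\n" (get t [| 2; 0 |]) (is_c_contiguous t)
```

Transposing only rewrites two small metadata arrays, which is why it is O(1); the permuted strides are exactly what makes the result non-contiguous.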
Reshaping
reshape
Change the shape without changing the data order. The total number of elements must match. Use -1 to infer one dimension:
open Nx
let () =
let x = create Int32 [|6|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
let a = reshape [|2; 3|] x in
let b = reshape [|3; -1|] x in (* -1 inferred as 2 *)
print_data a;
print_data b
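The -1 rule is plain arithmetic: divide the total element count by the product of the known dimensions, and fail if it does not divide evenly. A sketch of that rule on shapes alone; `infer_shape` is a hypothetical helper, not part of Nx.

```ocaml
(* Hypothetical helper: resolve at most one -1 entry in a target shape,
   given the total number of elements. Not Nx's implementation. *)
let infer_shape numel target =
  let known =
    Array.fold_left (fun acc d -> if d = -1 then acc else acc * d) 1 target in
  let holes = Array.fold_left (fun n d -> if d = -1 then n + 1 else n) 0 target in
  match holes with
  | 0 when known = numel -> target
  | 1 when known > 0 && numel mod known = 0 ->
    Array.map (fun d -> if d = -1 then numel / known else d) target
  | _ -> invalid_arg "infer_shape: incompatible shape"

let () =
  let s = infer_shape 6 [| 3; -1 |] in
  Printf.printf "[|%d; %d|]\n" s.(0) s.(1) (* [|3; 2|] *)
```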
flatten and unflatten
flatten collapses dimensions into one. unflatten expands a dimension back:
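let () =
  (* Print a shape as [|d0; d1; ...|] *)
  let show t = Nx.shape t |> Array.map string_of_int |> Array.to_list
               |> String.concat "; " |> Printf.printf "[|%s|]\n" in
  let x = Nx.zeros Nx.Float32 [|2; 3; 4|] in
  show (Nx.flatten x);              (* [|24|] *)
  show (Nx.flatten ~start_dim:1 x); (* [|2; 12|] *)
  let y = Nx.zeros Nx.Float32 [|2; 12|] in
  show (Nx.unflatten 1 [|3; 4|] y)  (* [|2; 3; 4|] *)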
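Both operations change only shape metadata: flatten multiplies the sizes of a contiguous range of axes, and unflatten splits one axis into sizes whose product must match it. A shape-only sketch; `flatten_shape` and `unflatten_shape` are hypothetical helpers, and collapsing through the last axis by default is an assumption of this model.

```ocaml
(* Hypothetical: flatten's shape arithmetic. Axes start_dim..end_dim
   (inclusive) are merged into one axis of size equal to their product. *)
let flatten_shape ?(start_dim = 0) ?end_dim shape =
  let n = Array.length shape in
  let end_dim = match end_dim with Some e -> e | None -> n - 1 in
  let merged = ref 1 in
  for i = start_dim to end_dim do merged := !merged * shape.(i) done;
  Array.concat [ Array.sub shape 0 start_dim; [| !merged |];
                 Array.sub shape (end_dim + 1) (n - end_dim - 1) ]

(* Hypothetical: unflatten replaces axis [dim] by [sizes]. *)
let unflatten_shape dim sizes shape =
  if Array.fold_left ( * ) 1 sizes <> shape.(dim) then
    invalid_arg "unflatten_shape: sizes do not multiply to the axis size";
  Array.concat [ Array.sub shape 0 dim; sizes;
                 Array.sub shape (dim + 1) (Array.length shape - dim - 1) ]

let () =
  let show s = String.concat "; " (Array.to_list (Array.map string_of_int s)) in
  Printf.printf "[|%s|]\n" (show (flatten_shape ~start_dim:1 [| 2; 3; 4 |]));
  Printf.printf "[|%s|]\n" (show (unflatten_shape 1 [| 3; 4 |] [| 2; 12 |]))
```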
squeeze and unsqueeze
Remove or add dimensions of size 1:
open Nx
let () =
let x = ones Float32 [|1; 3; 1; 4|] in
let a = squeeze x in (* [|3; 4|] *)
let b = squeeze ~axes:[0] x in (* [|3; 1; 4|] *)
Printf.printf "squeeze all: %dx%d\n" (dim 0 a) (dim 1 a);
Printf.printf "squeeze [0]: %dx%dx%d\n" (dim 0 b) (dim 1 b) (dim 2 b);
let y = create Float32 [|3|] [|1.; 2.; 3.|] in
let c = unsqueeze ~axes:[0; 2] y in (* [|1; 3; 1|] *)
Printf.printf "unsqueeze: %dx%dx%d\n" (dim 0 c) (dim 1 c) (dim 2 c)
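The shape bookkeeping behind squeeze and unsqueeze can be modelled on plain int arrays. In this sketch (hypothetical helpers, not Nx code), squeeze drops size-1 axes, and the axes given to unsqueeze are read as positions in the output shape, which matches the [|3|] to [|1; 3; 1|] example above.

```ocaml
(* Hypothetical shape-only model of squeeze and unsqueeze; not Nx code. *)

(* squeeze: drop every size-1 axis, or only the listed ones. *)
let squeeze_shape ?axes shape =
  let keep i d =
    match axes with
    | None -> d <> 1
    | Some ax when List.mem i ax ->
      if d <> 1 then invalid_arg "squeeze_shape: axis does not have size 1";
      false
    | Some _ -> true
  in
  Array.to_list shape
  |> List.mapi (fun i d -> (i, d))
  |> List.filter (fun (i, d) -> keep i d)
  |> List.map snd
  |> Array.of_list

(* unsqueeze: each position in [axes] gets a size-1 axis, and the input
   dimensions fill the remaining output positions in order. *)
let unsqueeze_shape axes shape =
  let rank = Array.length shape + List.length axes in
  let next = ref 0 in
  Array.init rank (fun pos ->
    if List.mem pos axes then 1
    else begin
      let d = shape.(!next) in
      incr next;
      d
    end)

let () =
  let show s = String.concat "; " (Array.to_list (Array.map string_of_int s)) in
  Printf.printf "[|%s|]\n" (show (squeeze_shape [| 1; 3; 1; 4 |]));
  Printf.printf "[|%s|]\n" (show (squeeze_shape ~axes:[ 0 ] [| 1; 3; 1; 4 |]));
  Printf.printf "[|%s|]\n" (show (unsqueeze_shape [ 0; 2 ] [| 3 |]))
```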
Broadcasting
Binary operations automatically broadcast operands. Dimensions are aligned from the right, and each pair must be equal or one must be 1:
open Nx
let () =
(* Add a row vector to every row of a matrix *)
let matrix = ones Float32 [|3; 4|] in
let row = create Float32 [|1; 4|] [|10.; 20.; 30.; 40.|] in
let result = add matrix row in
print_data result;
(* Add a column vector to every column *)
let col = create Float32 [|3; 1|] [|100.; 200.; 300.|] in
let result2 = add matrix col in
print_data result2
You can also broadcast explicitly:
let x = Nx.broadcast_to [|3; 3|] (Nx.create Nx.Float32 [|1; 3|] [|1.; 2.; 3.|])
(* Repeats the row 3 times without copying data *)
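A common way strided-array libraries broadcast without copying is to give every new or stretched axis a stride of 0, so all indices along it map to the same buffer element. The sketch below models that technique in plain OCaml; it is an assumption about the general approach, not a description of how Nx implements `broadcast_to`, and `broadcast_strides` is a hypothetical name.

```ocaml
(* Hypothetical: strides of a broadcast view. Shapes are aligned from the
   right; new leading axes and stretched size-1 axes get stride 0. *)
let broadcast_strides src_shape src_strides target =
  let n = Array.length target and m = Array.length src_shape in
  Array.init n (fun i ->
    let j = i - (n - m) in           (* matching source axis, right-aligned *)
    if j < 0 then 0                  (* new leading axis *)
    else if src_shape.(j) = target.(i) then src_strides.(j)
    else if src_shape.(j) = 1 then 0 (* stretched size-1 axis *)
    else invalid_arg "broadcast_strides: incompatible shapes")

let () =
  let data = [| 1.; 2.; 3. |] in
  (* Source shape [|1; 3|] with row-major strides [|3; 1|], broadcast to [|3; 3|] *)
  let s = broadcast_strides [| 1; 3 |] [| 3; 1 |] [| 3; 3 |] in
  for i = 0 to 2 do
    for j = 0 to 2 do
      Printf.printf "%g " data.((i * s.(0)) + (j * s.(1)))
    done;
    print_newline ()
  done
```

Every printed row reads the same three buffer slots, which is what "without copying data" means in a strided layout.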
Broadcasting rules
Shapes are compatible when, aligned from the right, every dimension pair is either equal or one of them is 1. If one shape has fewer dimensions, its missing leading dimensions are treated as 1. The result shape takes the maximum at each position.
[|   3; 4|] + [|1; 4|] → [|3; 4|]     ✓
[|2; 3; 4|] + [|   4|] → [|2; 3; 4|]  ✓
[|   3; 4|] + [|3; 1|] → [|3; 4|]     ✓
[|      3|] + [|   4|] → error        ✗
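The rule above can be written as a small function on shapes. This is a sketch of the stated rule (`broadcast_shape` is a hypothetical name, not an Nx function); it returns `None` where the table shows an error.

```ocaml
(* Hypothetical: result shape of broadcasting a and b, or None. *)
let broadcast_shape a b =
  let n = max (Array.length a) (Array.length b) in
  (* Missing leading dimensions are treated as size 1. *)
  let dim s i =
    let k = i - (n - Array.length s) in
    if k < 0 then 1 else s.(k) in
  let ok = ref true in
  let out =
    Array.init n (fun i ->
      let x = dim a i and y = dim b i in
      if x = y || y = 1 then x
      else if x = 1 then y
      else (ok := false; 0)) in
  if !ok then Some out else None

let () =
  let show = function
    | None -> "error"
    | Some s -> "[|" ^ String.concat "; " (Array.to_list (Array.map string_of_int s)) ^ "|]" in
  print_endline (show (broadcast_shape [| 2; 3; 4 |] [| 4 |]));
  print_endline (show (broadcast_shape [| 3 |] [| 4 |]))
```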
Transposing and Permuting
transpose
Reverse dimensions (no copy):
open Nx
let () =
let x = create Int32 [|2; 3|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
let t = transpose x in
print_data t
(* [[1, 4],
[2, 5],
[3, 6]] *)
Specify a permutation for higher-rank tensors:
(* Permute [batch; height; width; channels] to [batch; channels; height; width] *)
let nhwc_to_nchw x = Nx.transpose ~axes:[0; 3; 1; 2] x
moveaxis and swapaxes
Move or swap individual dimensions:
Nx.moveaxis 0 2 x (* move axis 0 to position 2 *)
Nx.swapaxes 1 2 x (* swap axes 1 and 2 *)
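Both are special cases of transpose with a particular axis permutation. The sketch computes those permutations on plain lists; the helper names are hypothetical, and it assumes moveaxis follows NumPy's convention (remove the source axis, then reinsert it at the destination).

```ocaml
(* Hypothetical: permutation equivalent to swapaxes a b at the given rank. *)
let swapaxes_perm rank a b =
  List.init rank (fun i -> if i = a then b else if i = b then a else i)

(* Hypothetical: permutation equivalent to moveaxis src dst, assuming
   NumPy semantics: remove axis src, then reinsert it at position dst. *)
let moveaxis_perm rank src dst =
  let others = List.filter (fun i -> i <> src) (List.init rank (fun i -> i)) in
  let rec insert k l =
    match l with
    | _ when k = 0 -> src :: l
    | [] -> [ src ]
    | x :: rest -> x :: insert (k - 1) rest
  in
  insert dst others

(* Apply a permutation to a shape: output axis i takes input axis p_i. *)
let permute_shape perm shape = Array.of_list (List.map (fun p -> shape.(p)) perm)

let () =
  let show l = String.concat "; " (List.map string_of_int l) in
  Printf.printf "moveaxis 0 2 -> perm [%s]\n" (show (moveaxis_perm 3 0 2));
  Printf.printf "swapaxes 1 2 -> perm [%s]\n" (show (swapaxes_perm 3 1 2))
```

Under this reading, `Nx.moveaxis 0 2 x` on a rank-3 tensor corresponds to `Nx.transpose ~axes:[1; 2; 0] x`, using the same permutation convention as the NHWC-to-NCHW example.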
flip
Reverse elements along axes:
Nx.flip ~axes:[1] x (* mirror columns *)
Nx.flip x (* reverse all dimensions *)
Indexing and Slicing
get
Index from the outermost dimension inward. Returns a sub-tensor (view):
open Nx
let () =
let x = create Int32 [|2; 3|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
let row = get [1] x in (* second row: [4, 5, 6] *)
print_data row
item
Extract a scalar value:
let v = Nx.item [1; 2] matrix (* element at row 1, column 2 *)
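For a contiguous row-major tensor, a full zero-based index list maps to one flat buffer position. The sketch shows that arithmetic in Horner form (`flat_index` is a hypothetical helper; a non-contiguous view would use its own strides instead).

```ocaml
(* Hypothetical: row-major flat position of a zero-based multi-index
   in a contiguous tensor of the given shape. *)
let flat_index shape idx =
  List.fold_left2
    (fun acc d i ->
      if i < 0 || i >= d then invalid_arg "flat_index: index out of bounds";
      (acc * d) + i)
    0 (Array.to_list shape) idx

let () =
  (* The 3x3 matrix [[1,2,3],[4,5,6],[7,8,9]] stored row by row *)
  let data = [| 1; 2; 3; 4; 5; 6; 7; 8; 9 |] in
  Printf.printf "%d\n" data.(flat_index [| 3; 3 |] [ 1; 2 ]) (* row 1, column 2 *)
```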
slice
Advanced indexing with range and index specifications:
open Nx
let () =
let x = create Int32 [|3; 3|] [|1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l|] in
(* R (start, stop): half-open range *)
let rows_0_1 = slice [R (0, 2); A] x in
print_data rows_0_1;
(* I i: single index (reduces dimension) *)
let col_1 = slice [A; I 1] x in
print_data col_1;
(* L [indices]: gather specific indices *)
let corners = slice [L [0; 2]; L [0; 2]] x in
print_data corners
Index types:
- I i: single index (reduces dimension)
- R (start, stop): half-open range
- Rs (start, stop, step): strided range
- L indices: gather listed indices
- A: all elements (default for trailing axes)
- N: insert new axis of size 1
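To see how each index type affects the result shape, the sketch models them as an OCaml variant and computes the output shape of a slice. The `spec` type only mirrors the constructor names listed above and is a hypothetical stand-in, not Nx's definition. The model assumes non-negative in-range bounds and positive steps, treats unmentioned trailing axes as A, and, as the corners example suggests, treats each L as gathering independently along its own axis.

```ocaml
(* Hypothetical stand-in for the index specifications listed above. *)
type spec =
  | I of int               (* single index: axis removed *)
  | R of int * int         (* half-open range [start, stop) *)
  | Rs of int * int * int  (* strided range, positive step assumed *)
  | L of int list          (* gather listed indices *)
  | A                      (* whole axis *)
  | N                      (* new axis of size 1 *)

(* Count of start, start+step, ... strictly below stop (step > 0). *)
let range_len start stop step =
  if stop <= start then 0 else (stop - start + step - 1) / step

(* Specs consume input axes left to right, except N, which consumes none. *)
let slice_shape specs shape =
  let rec go specs axis =
    match specs with
    | [] -> Array.to_list (Array.sub shape axis (Array.length shape - axis))
    | N :: rest -> 1 :: go rest axis
    | I _ :: rest -> go rest (axis + 1)
    | R (a, b) :: rest -> range_len a b 1 :: go rest (axis + 1)
    | Rs (a, b, s) :: rest -> range_len a b s :: go rest (axis + 1)
    | L idx :: rest -> List.length idx :: go rest (axis + 1)
    | A :: rest -> shape.(axis) :: go rest (axis + 1)
  in
  Array.of_list (go specs 0)

let () =
  let show s = String.concat "; " (Array.to_list (Array.map string_of_int s)) in
  List.iter
    (fun specs -> Printf.printf "[|%s|]\n" (show (slice_shape specs [| 3; 3 |])))
    [ [ R (0, 2); A ]; [ A; I 1 ]; [ L [ 0; 2 ]; L [ 0; 2 ] ]; [ N; Rs (0, 3, 2) ] ]
```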
Joining and Splitting
concatenate
Join tensors along an existing axis:
open Nx
let () =
let a = ones Float32 [|2; 3|] in
let b = zeros Float32 [|2; 3|] in
let c = concatenate ~axis:0 [a; b] in (* [|4; 3|] *)
Printf.printf "concat axis 0: %dx%d\n" (dim 0 c) (dim 1 c);
let d = concatenate ~axis:1 [a; b] in (* [|2; 6|] *)
Printf.printf "concat axis 1: %dx%d\n" (dim 0 d) (dim 1 d)
Shorthands: vstack (axis 0), hstack (axis 1), dstack (axis 2).
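The shape rule for concatenate is that all inputs agree on every axis except the join axis, whose sizes add up. A shape-only sketch (`concat_shape` is a hypothetical helper):

```ocaml
(* Hypothetical: result shape of concatenating [shapes] along [axis]. *)
let concat_shape axis shapes =
  match shapes with
  | [] -> invalid_arg "concat_shape: empty list"
  | first :: _ ->
    List.iter
      (fun s ->
        if Array.length s <> Array.length first then
          invalid_arg "concat_shape: rank mismatch";
        Array.iteri
          (fun i d ->
            if i <> axis && d <> first.(i) then
              invalid_arg "concat_shape: size mismatch off the join axis")
          s)
      shapes;
    let out = Array.copy first in
    out.(axis) <- List.fold_left (fun acc s -> acc + s.(axis)) 0 shapes;
    out

let () =
  let show s = String.concat "; " (Array.to_list (Array.map string_of_int s)) in
  Printf.printf "[|%s|]\n" (show (concat_shape 0 [ [| 2; 3 |]; [| 2; 3 |] ]));
  Printf.printf "[|%s|]\n" (show (concat_shape 1 [ [| 2; 3 |]; [| 2; 3 |] ]))
```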
stack
Join tensors along a new axis:
open Nx
let () =
let a = create Float32 [|3|] [|1.; 2.; 3.|] in
let b = create Float32 [|3|] [|4.; 5.; 6.|] in
let c = stack ~axis:0 [a; b] in (* [|2; 3|] *)
print_data c
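Unlike concatenate, stack requires identical input shapes and inserts a new axis whose length is the number of inputs. A shape-only sketch (`stack_shape` is a hypothetical helper):

```ocaml
(* Hypothetical: result shape of stacking [shapes] along a new [axis]. *)
let stack_shape axis shapes =
  match shapes with
  | [] -> invalid_arg "stack_shape: empty list"
  | first :: _ ->
    if List.exists (fun s -> s <> first) shapes then
      invalid_arg "stack_shape: shapes differ";
    let n = Array.length first in
    Array.concat [ Array.sub first 0 axis; [| List.length shapes |];
                   Array.sub first axis (n - axis) ]

let () =
  let show s = String.concat "; " (Array.to_list (Array.map string_of_int s)) in
  Printf.printf "[|%s|]\n" (show (stack_shape 0 [ [| 3 |]; [| 3 |] ]));
  Printf.printf "[|%s|]\n" (show (stack_shape 1 [ [| 3 |]; [| 3 |] ]))
```

Equivalently, stacking is unsqueezing every input at `axis` and then concatenating along that axis.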
split
Split a tensor into equal parts along an axis:
let parts = Nx.split ~axis:0 2 x (* split into 2 along axis 0 *)
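Assuming NumPy-style semantics, splitting into n equal parts requires the axis length to be divisible by n, and part k covers the half-open range from k·size to (k+1)·size. The sketch computes those ranges; `split_ranges` is a hypothetical helper, and each range corresponds to an R (start, stop) spec on the split axis.

```ocaml
(* Hypothetical: [start, stop) ranges produced by splitting an axis of
   length [len] into [parts] equal pieces. *)
let split_ranges len parts =
  if parts <= 0 || len mod parts <> 0 then
    invalid_arg "split_ranges: axis length not divisible by parts";
  let size = len / parts in
  List.init parts (fun k -> (k * size, (k + 1) * size))

let () =
  List.iter (fun (a, b) -> Printf.printf "R (%d, %d)\n" a b) (split_ranges 6 2)
```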
Tiling and Repeating
tile
Replicate the tensor according to a repeat pattern:
(* Tile a [2; 3] tensor 2x along rows, 3x along columns → [4; 9] *)
Nx.tile [|2; 3|] x
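In one dimension, tile repeats the whole sequence end to end, and when the repeat pattern has the same rank as the tensor, each dimension is multiplied by its repeat count. A sketch of both (hypothetical helpers; rank-mismatch handling, which NumPy resolves by padding with leading 1s, is not modelled):

```ocaml
(* Hypothetical 1-D model of tile: the whole sequence is repeated reps times. *)
let tile1 reps xs = Array.concat (List.init reps (fun _ -> xs))

(* Hypothetical: tile's output shape when reps has the same rank as shape. *)
let tile_shape reps shape = Array.map2 ( * ) reps shape

let () =
  let show s = String.concat "; " (Array.to_list (Array.map string_of_int s)) in
  Printf.printf "tile1 2 [|1; 2; 3|] = [|%s|]\n" (show (tile1 2 [| 1; 2; 3 |]));
  Printf.printf "tile_shape = [|%s|]\n" (show (tile_shape [| 2; 3 |] [| 2; 3 |]))
```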
repeat
Repeat elements along a single axis:
(* Repeat each element 3 times along axis 0 *)
Nx.repeat ~axis:0 3 x
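The contrast with tile is clearest in one dimension: assuming NumPy-style semantics, tile yields 1 2 3 1 2 3, while repeat copies each element in place. A 1-D sketch (`repeat1` is a hypothetical helper):

```ocaml
(* Hypothetical 1-D model of repeat: each element is repeated k times in place. *)
let repeat1 k xs =
  Array.concat (List.map (fun x -> Array.make k x) (Array.to_list xs))

let () =
  let show s = String.concat "; " (Array.to_list (Array.map string_of_int s)) in
  Printf.printf "[|%s|]\n" (show (repeat1 2 [| 1; 2; 3 |]))
```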
pad
Pad with a constant value:
(* Pad axis 0 with 1 before and 2 after, axis 1 with 0 before and 1 after: a [2; 3] tensor becomes [5; 4] *)
Nx.pad [|(1, 2); (0, 1)|] 0. x
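Each (before, after) pair adds before + after to its axis. A sketch of that shape arithmetic plus a 1-D model of constant padding (both helpers are hypothetical):

```ocaml
(* Hypothetical: output shape of padding, where widths.(i) = (before, after). *)
let pad_shape widths shape =
  Array.map2 (fun (before, after) d -> before + d + after) widths shape

(* Hypothetical 1-D model: surround xs with before and after copies of value. *)
let pad1 (before, after) value xs =
  Array.concat [ Array.make before value; xs; Array.make after value ]

let () =
  let s = pad_shape [| (1, 2); (0, 1) |] [| 2; 3 |] in
  Printf.printf "[|%d; %d|]\n" s.(0) s.(1);
  Array.iter (Printf.printf "%g ") (pad1 (1, 2) 0. [| 1.; 2. |]);
  print_newline ()
```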
Next Steps
- Linear Algebra — matrix operations, decompositions, FFT
- Input/Output — reading and writing images, npy, npz files
- NumPy Comparison — side-by-side reference