04-reshaping-and-broadcasting
Change array shapes and let broadcasting align dimensions automatically. This example reshapes a flat signal into frames, centers data by subtracting column means, and builds an outer product — all without explicit loops.
dune exec nx/examples/04-reshaping-and-broadcasting/main.exe
What You'll Learn
- Reshaping flat arrays into multi-dimensional frames with `reshape`
- Flattening back to 1D with `flatten`
- Transposing rows and columns
- Stacking arrays vertically and horizontally: `vstack`, `hstack`
- Broadcasting: how `keepdims` enables operations on different-shaped arrays
- Building outer products via broadcasting
- Adding and removing dimensions with `expand_dims` and `squeeze`
Key Functions
| Function | Purpose |
|---|---|
| `reshape shape t` | Change array shape (total number of elements must match) |
| `flatten t` | Collapse all dimensions into 1D |
| `transpose t` | Reverse all axes (for 2D arrays: swap rows and columns) |
| `vstack ts` | Stack arrays vertically (along axis 0; 1D arrays become rows) |
| `hstack ts` | Stack arrays horizontally (along axis 1; 1D arrays are joined end to end) |
| `expand_dims axes t` | Insert size-1 dimensions at the specified positions |
| `squeeze t` | Remove all size-1 dimensions |
| `mean ~keepdims:true` | Reduce while keeping the reduced axis as size 1 (for broadcasting) |
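The shape rules in this table come down to simple arithmetic on shape arrays. The following is a minimal plain-OCaml sketch of that arithmetic. It assumes shapes are `int array` values, as in the Nx calls above. The helpers `numel`, `can_reshape`, `expand_dim` and `squeeze_shape` are hypothetical names used only for illustration, not Nx functions, and `expand_dim` handles a single axis only.

```ocaml
(* Illustrative shape bookkeeping for reshape, expand_dims and squeeze.
   Hypothetical helpers over int arrays; not Nx's implementation. *)

(* Total number of elements described by a shape. *)
let numel shape = Array.fold_left ( * ) 1 shape

(* A reshape is valid only when the element count stays the same. *)
let can_reshape ~src ~dst = numel src = numel dst

(* Insert one size-1 dimension at position [axis]. Nx's expand_dims takes
   a list of axes; this sketch handles a single axis for simplicity. *)
let expand_dim axis shape =
  let n = Array.length shape in
  if axis < 0 || axis > n then invalid_arg "expand_dim: axis out of range";
  Array.init (n + 1) (fun i ->
      if i < axis then shape.(i) else if i = axis then 1 else shape.(i - 1))

(* Drop every size-1 dimension, as squeeze does when no axes are given. *)
let squeeze_shape shape =
  Array.of_list (List.filter (fun d -> d <> 1) (Array.to_list shape))
```

Two examples:

- `expand_dim 1 [| 4 |]` returns `[| 4; 1 |]`, the column shape that `expand_dims [ 1 ] v` produces in the source below.
- `squeeze_shape [| 1; 4; 1 |]` returns `[| 4 |]`.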
Output Walkthrough
Reshape a flat 12-element signal into a 3×4 matrix of frames:
let signal = arange_f float64 0.0 12.0 1.0 in
let frames = reshape [| 3; 4 |] signal
Flat signal (12 samples):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
Reshaped into 3 frames of 4:
[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]]
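The frames are filled row by row because the flat data is read in row-major order. Element (r, c) of the [3; 4] result is flat sample `r * 4 + c`. The sketch below expresses that index mapping in plain OCaml, using arrays of arrays instead of Nx tensors to stay self-contained. `to_frames` and `of_frames` are hypothetical helper names.

```ocaml
(* Split a flat array into [rows] frames of [cols] samples each, reading
   elements in row-major order like a [| rows; cols |] reshape. *)
let to_frames rows cols flat =
  if Array.length flat <> rows * cols then
    invalid_arg "to_frames: element count must equal rows * cols";
  Array.init rows (fun r -> Array.init cols (fun c -> flat.((r * cols) + c)))

(* Flatten the frames back into one array, also in row-major order. *)
let of_frames frames = Array.concat (Array.to_list frames)

let () =
  let signal = Array.init 12 float_of_int in
  let frames = to_frames 3 4 signal in
  (* Frame 2 holds samples 8 to 11. *)
  Array.iter (fun x -> Printf.printf "%g " x) frames.(2);
  print_newline ()
```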
Broadcasting in action
Subtract the column means from the data. With `keepdims:true`, `mean` returns shape [1; 3] instead of [3]. The reduced axis stays in place, so the result broadcasts against the [4; 3] data:
let col_means = mean ~axes:[ 0 ] ~keepdims:true data in
let centered = data - col_means
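Element by element, the broadcast subtraction computes `data[i][j] - col_means[0][j]` for every row `i`. This plain-OCaml sketch makes that implicit loop explicit. It uses the same 4×3 values as the full source below, held as arrays of arrays rather than Nx tensors. `col_means` and `center` are hypothetical helper names.

```ocaml
(* The 4 samples x 3 features from the example, one inner array per row. *)
let data =
  [| [| 10.; 200.; 3000. |];
     [| 20.; 400.; 1000. |];
     [| 30.; 100.; 2000. |];
     [| 40.; 300.; 4000. |] |]

(* Mean of each column, i.e. a reduction over axis 0. *)
let col_means m =
  let rows = float_of_int (Array.length m) in
  Array.init (Array.length m.(0)) (fun j ->
      Array.fold_left (fun acc row -> acc +. row.(j)) 0. m /. rows)

(* Subtract the single row of means from every row: the loop that
   broadcasting a [1; 3] array against a [4; 3] array performs for you. *)
let center m =
  let means = col_means m in
  Array.map (fun row -> Array.mapi (fun j x -> x -. means.(j)) row) m

let () =
  Array.iter
    (fun row ->
      Array.iter (fun x -> Printf.printf "%8g" x) row;
      print_newline ())
    (center data)
```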
Outer product via broadcasting
Reshape `x` into a [4; 1] column and `y` into a [1; 3] row, then multiply. Broadcasting stretches both to [4; 3], so no loops are needed:
let outer = reshape [| 4; 1 |] x * reshape [| 1; 3 |] y
Outer product (x × y):
[[10, 20, 30],
[20, 40, 60],
[30, 60, 90],
[40, 80, 120]]
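Broadcasting the [4; 1] column against the [1; 3] row fills entry (i, j) with `x[i] * y[j]`. The equivalent nested map in plain OCaml is shown below, using arrays of arrays instead of Nx tensors. `outer` is a hypothetical helper name.

```ocaml
(* Outer product: result.(i).(j) = x.(i) *. y.(j), a (length x) by
   (length y) matrix, the same shape that [4; 1] * [1; 3] produces. *)
let outer x y = Array.map (fun xi -> Array.map (fun yj -> xi *. yj) y) x

let () =
  let x = [| 1.; 2.; 3.; 4. |] and y = [| 10.; 20.; 30. |] in
  Array.iter
    (fun row ->
      Array.iter (fun v -> Printf.printf "%5g" v) row;
      print_newline ())
    (outer x y)
```

The same function handles vectors of different lengths, which is the third exercise under Try It.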
Broadcasting Rules
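As in NumPy, shapes are aligned from the trailing (rightmost) dimension, and a shape with fewer dimensions is padded with size-1 dimensions on the left. Two aligned dimensions are compatible when they are either:
- Equal, or
- One of them is 1

When dimensions differ, the size-1 dimension is stretched to match. A reduction without `keepdims` drops the reduced axis entirely, so its result can line up with the wrong dimension. For example, row means of a [4; 3] array have shape [4], which does not broadcast against [4; 3] (4 vs 3). `mean ~axes:[1] ~keepdims:true` instead gives [4; 1], which does. This is why `keepdims:true` matters for reductions used in arithmetic.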
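The compatibility check can be sketched in plain OCaml. The sketch assumes NumPy-style right alignment, where the shorter shape is padded with 1s on the left. It illustrates the rule and is not Nx's own implementation; `broadcast_shape` is a hypothetical helper name.

```ocaml
(* Result shape of broadcasting two shapes, or None if they are
   incompatible. Shapes are right-aligned; missing leading dims count as 1. *)
let broadcast_shape a b =
  let na = Array.length a and nb = Array.length b in
  let n = max na nb in
  (* Dimension i of shape s after left-padding it to length n with 1s. *)
  let dim s ns i = if i < n - ns then 1 else s.(i - (n - ns)) in
  let result = Array.make n 0 in
  let ok = ref true in
  for i = 0 to n - 1 do
    let da = dim a na i and db = dim b nb i in
    if da = db || db = 1 then result.(i) <- da
    else if da = 1 then result.(i) <- db
    else ok := false
  done;
  if !ok then Some result else None
```

Some results:

- `broadcast_shape [| 4; 3 |] [| 1; 3 |]` is `Some [| 4; 3 |]` (the centering case).
- `broadcast_shape [| 4; 1 |] [| 1; 3 |]` is also `Some [| 4; 3 |]` (the outer product).
- `broadcast_shape [| 4; 3 |] [| 4 |]` is `None`, which is the row-means case that `keepdims:true` avoids.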
Try It
- Reshape the signal into `[4; 3]` instead of `[3; 4]` and compare the result with the transpose of the original frames.
- Stack three 1D arrays of different values with `vstack`, then compute row-wise means using `mean ~axes:[1]`.
- Compute an outer product of two vectors of different lengths (e.g., 5 and 3) using `reshape` and broadcasting.
Next Steps
Continue to 05-reductions-and-statistics to learn how to summarize data with aggregations along any axis.
(** Change array shapes and let broadcasting align dimensions automatically.
Reshape a flat signal into frames, center data by subtracting column means
(broadcasting in action), and build an outer product without any loops. *)
open Nx
open Nx.Infix
let () =
(* --- Reshape: flat signal → frames --- *)
let signal = arange_f float64 0.0 12.0 1.0 in
Printf.printf "Flat signal (12 samples):\n%s\n\n" (data_to_string signal);
let frames = reshape [| 3; 4 |] signal in
Printf.printf "Reshaped into 3 frames of 4:\n%s\n\n" (data_to_string frames);
let flat_again = flatten frames in
Printf.printf "Flattened back: %s\n\n" (data_to_string flat_again);
(* --- Transpose: swap rows and columns --- *)
Printf.printf "Transposed:\n%s\n\n" (data_to_string (transpose frames));
(* --- Stacking arrays --- *)
let a = create float64 [| 3 |] [| 1.0; 2.0; 3.0 |] in
let b = create float64 [| 3 |] [| 4.0; 5.0; 6.0 |] in
Printf.printf "vstack [a; b]:\n%s\n" (data_to_string (vstack [ a; b ]));
Printf.printf "hstack [a; b]: %s\n\n" (data_to_string (hstack [ a; b ]));
(* --- Broadcasting: subtract column means to center data --- *)
let data =
create float64 [| 4; 3 |]
  [| 10.0; 200.0; 3000.0;
     20.0; 400.0; 1000.0;
     30.0; 100.0; 2000.0;
     40.0; 300.0; 4000.0 |]
in
Printf.printf "Raw data (4 samples × 3 features):\n%s\n" (data_to_string data);
(* Mean along axis 0 with keepdims: shape [1; 3] broadcasts against [4; 3]. *)
let col_means = mean ~axes:[ 0 ] ~keepdims:true data in
Printf.printf "Column means: %s\n" (data_to_string col_means);
let centered = data - col_means in
Printf.printf "Centered (zero-mean columns):\n%s\n\n"
(data_to_string centered);
(* --- Outer product via broadcasting --- *)
let x = create float64 [| 4 |] [| 1.0; 2.0; 3.0; 4.0 |] in
let y = create float64 [| 3 |] [| 10.0; 20.0; 30.0 |] in
(* x as column [4;1], y as row [1;3] → result is [4;3]. *)
let outer = reshape [| 4; 1 |] x * reshape [| 1; 3 |] y in
Printf.printf "x = %s\n" (data_to_string x);
Printf.printf "y = %s\n" (data_to_string y);
Printf.printf "Outer product (x × y):\n%s\n\n" (data_to_string outer);
(* --- expand_dims / squeeze --- *)
let v = arange float64 0 4 1 in
let row = expand_dims [ 0 ] v in
let col = expand_dims [ 1 ] v in
Printf.printf "Vector: shape %s → %s\n"
(shape_to_string (shape v))
(data_to_string v);
Printf.printf "Row vector: shape %s → %s\n"
(shape_to_string (shape row))
(data_to_string row);
Printf.printf "Col vector: shape %s\n%s\n"
(shape_to_string (shape col))
(data_to_string col)