09-image-processing

Load, transform, and save images as arrays — convolutions, pooling, and pixel math. This example creates a synthetic grayscale image, blurs it, detects edges with Sobel filters, and downsamples with max pooling.

dune exec nx/examples/09-image-processing/main.exe

What You'll Learn

  • Creating synthetic images with init and pixel math
  • Applying 2D convolution with correlate2d (NCHW format)
  • Gaussian blur with a 3x3 kernel
  • Sobel edge detection (horizontal + vertical gradients)
  • Downsampling with max_pool2d
  • Converting between UInt8 and Float32 for computation
  • Saving arrays as PNG files with Nx_io.save_image

Key Functions

Function Purpose
init UInt8 shape f Create an image by computing each pixel
correlate2d ~padding_mode:`Same img kernel 2D convolution (expects NCHW)
max_pool2d ~kernel_size ~stride img Downsample by taking max in each window
cast Float32 t Convert dtype for floating-point operations
clamp ~min ~max t Clamp values to a valid pixel range
contiguous t Ensure contiguous memory layout (required for I/O)
Nx_io.save_image path t Save a 2D (HxW) array as a grayscale PNG

Output Walkthrough

Synthetic image

A 64x64 horizontal gradient with a bright rectangle in the center:

let img = init UInt8 [| h; w |] (fun idx ->
    let y = idx.(0) and x = idx.(1) in
    let base = x * 255 / (w - 1) in
    if y >= 16 && y < 48 && x >= 16 && x < 48 then 220 else base)

NCHW format

Convolution operations expect 4D tensors in NCHW format (batch, channels, height, width). Convert with:

let img_f = cast Float32 img |> contiguous |> reshape [| 1; 1; h; w |]

Gaussian blur

A 3x3 kernel with weights summing to 1, giving more weight to the center:

let blur_kernel = create Float32 [| 1; 1; 3; 3 |]
  [| 1./16.; 2./16.; 1./16.;
     2./16.; 4./16.; 2./16.;
     1./16.; 2./16.; 1./16. |]

Sobel edge detection

Combines horizontal and vertical gradient magnitudes:

let gx = correlate2d ~padding_mode:`Same img_f sobel_x in
let gy = correlate2d ~padding_mode:`Same img_f sobel_y in
let edges = sqrt (add (mul gx gx) (mul gy gy))

Max pooling

2x downsampling by taking the maximum in each 2x2 window:

Saved: pooled.png (64x64 -> 32x32)

Output Files

Running this example creates four PNG files in the current directory:

File Description
gradient.png Original synthetic image
blurred.png After Gaussian blur
edges.png Sobel edge detection result
pooled.png 2x downsampled via max pooling

Try It

  1. Replace the blur kernel with a sharpening kernel: [| 0.; -1.; 0.; -1.; 5.; -1.; 0.; -1.; 0. |]
  2. Try a larger pooling window (4, 4) and observe the effect on image size and detail.
  3. Chain blur and edge detection: blur first to reduce noise, then apply Sobel.

Next Steps

You've completed the Nx examples! For machine learning workflows, see the kaun examples.

(** Load, transform, and save images as arrays — convolutions, pooling, and
    pixel math.

    Create a synthetic grayscale gradient, blur it, detect edges with Sobel
    filters, and downsample with max pooling. Results are saved as PNG files. *)

open Nx

let () =
  let h, w = (64, 64) in

  (* --- Small helpers shared by the stages below --- *)

  (* True when a coordinate lies inside the bright square's [16, 48) span. *)
  let in_rect v = v >= 16 && v < 48 in
  (* Build a 3x3 Float32 kernel from nine weights listed row by row. *)
  let kernel_3x3 weights = create Float32 [| 3; 3 |] weights in
  (* Clamp a float tensor to [0, 255], then convert it to a contiguous UInt8
     tensor so it can be handed to [Nx_io.save_image]. *)
  let to_u8 t = clamp ~min:0.0 ~max:255.0 t |> cast UInt8 |> contiguous in

  (* --- Stage 1: synthetic image --- *)

  (* Left-to-right intensity ramp (0 at x = 0, 255 at x = w - 1) with a
     constant-valued square (220) painted over the middle. *)
  let img =
    init UInt8 [| h; w |] (fun idx ->
        let row = idx.(0) and col = idx.(1) in
        if in_rect row && in_rect col then 220 else col * 255 / (w - 1))
  in
  Printf.printf "Created %dx%d grayscale image\n" h w;
  Nx_io.save_image "gradient.png" (contiguous img);
  Printf.printf "Saved: gradient.png\n";

  (* --- Stage 2: 3x3 Gaussian blur --- *)

  (* Filtering is done on a Float32 copy of the image; the 2D [h; w] tensor is
     passed to [correlate] as-is, without adding batch/channel dimensions. *)
  let img_f = contiguous (cast Float32 img) in
  let same_correlate kernel = correlate ~padding:`Same img_f kernel in

  (* Binomial weights 1-2-1 / 2-4-2 / 1-2-1, each divided by 16 so that the
     kernel sums to 1. *)
  let sixteenth n = n /. 16.0 in
  let blur_kernel =
    kernel_3x3
      [|
        sixteenth 1.0;
        sixteenth 2.0;
        sixteenth 1.0;
        sixteenth 2.0;
        sixteenth 4.0;
        sixteenth 2.0;
        sixteenth 1.0;
        sixteenth 2.0;
        sixteenth 1.0;
      |]
  in
  Nx_io.save_image "blurred.png" (to_u8 (same_correlate blur_kernel));
  Printf.printf "Saved: blurred.png\n";

  (* --- Stage 3: Sobel edge detection --- *)

  (* One kernel per axis: [sobel_x] differences along columns, [sobel_y] along
     rows. *)
  let sobel_x =
    kernel_3x3 [| -1.0; 0.0; 1.0; -2.0; 0.0; 2.0; -1.0; 0.0; 1.0 |]
  in
  let sobel_y =
    kernel_3x3 [| -1.0; -2.0; -1.0; 0.0; 0.0; 0.0; 1.0; 2.0; 1.0 |]
  in
  let gx = same_correlate sobel_x in
  let gy = same_correlate sobel_y in
  (* Gradient magnitude: sqrt (gx^2 + gy^2), element-wise. *)
  let magnitude = add (mul gx gx) (mul gy gy) |> sqrt in
  Nx_io.save_image "edges.png" (to_u8 magnitude);
  Printf.printf "Saved: edges.png\n";

  (* --- Stage 4: 2x downsample via max pooling --- *)

  (* A 2x2 window moved with stride 2 in both directions, applied to the float
     image through [maximum_filter]. *)
  let pooled =
    maximum_filter ~kernel_size:[| 2; 2 |] ~stride:[| 2; 2 |] img_f
  in
  let pooled_dims = shape pooled in
  Nx_io.save_image "pooled.png" (to_u8 pooled);
  Printf.printf "Saved: pooled.png (%dx%d -> %dx%d)\n" h w pooled_dims.(0)
    pooled_dims.(1);

  Printf.printf "\nAll images saved to the current directory.\n"