MNIST Tutorial with Kaun

Welcome to Kaun! In this tutorial, you'll learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset.

1. Load the MNIST Dataset

Load and prepare the MNIST dataset:

(* Backend used throughout this tutorial: CPU through the C FFI. *)
let device = Rune.c

(* Raw MNIST splits, kept unflattened and unnormalized because they are only
   used for visualization. The training split is loaded first, then the test
   split. *)
let train_data_raw, test_data_raw =
  let load_raw ~train =
    Kaun_datasets.mnist ~train ~flatten:false ~normalize:false ~device ()
  in
  let train_split = load_raw ~train:true in
  let test_split = load_raw ~train:false in
  (train_split, test_split)

(* Dataset sizes are hard-coded here rather than read back from the loaded
   data: MNIST has 60k training samples and 10k test samples. *)
let train_size = 60_000
let test_size = 10_000

(* Print a short summary of the raw (unnormalized) dataset sizes to stdout. *)
let () =
  let open Printf in
  print_string "Raw MNIST data loaded (unnormalized):\n";
  printf "Train samples: %d\n" train_size;
  printf "Test samples: %d\n" test_size
01.07.25 03:51:54.910 .datasets.mnist  INFO Loading MNIST datasets...
01.07.25 03:51:56.380 .datasets.mnist  INFO MNIST loading complete

2. Visualize Sample Images

Let's visualize some sample images from the training set using Hugin:

(* Re-batch the already loaded raw training set with a batch size of 10; only
   the first batch is needed for plotting. *)
let small_batch_dataset = Kaun.Dataset.batch_xy 10 train_data_raw

(* [Kaun.Dataset.take 1] returns a list of batches, each an
   (images, labels) pair. Match on it explicitly so that an unexpected result
   (for example an empty dataset) fails with a clear message instead of a
   [Match_failure] hidden behind a silenced warning 8. *)
let images_batch, labels_batch =
  match Kaun.Dataset.take 1 small_batch_dataset with
  | [ first_batch ] -> first_batch
  | batches ->
      failwith
        (Printf.sprintf "Expected exactly 1 batch of MNIST samples, got %d"
           (List.length batches))

(* Hand the image batch over to Nx for visualization: make the Rune tensor
   contiguous, expose it as a bigarray, then wrap that bigarray as an Nx
   array. *)
let images_nx =
  let contiguous_images = Rune.contiguous images_batch in
  let raw_buffer = Rune.unsafe_to_bigarray contiguous_images in
  Nx.of_bigarray raw_buffer

(* [get_label i] reads the label stored at position [i] of [labels_batch] and
   returns it as an [int]. The entry is selected with the range [i; i+1]
   (presumably half-open, i.e. a single element - confirm against the Rune
   slicing docs), made contiguous, read out as a float and truncated. *)
let get_label i =
  let entry = Rune.slice [R [i; i + 1]] labels_batch in
  let value = Rune.unsafe_get [] (Rune.contiguous entry) in
  int_of_float value

(* Figure that will hold the sample digits. *)
let fig = Hugin.Figure.create ~width:1000 ~height:400 ()

(* Fill a 2x5 grid of subplots with the first ten samples of the batch. *)
let () =
  (* Draw sample [i] into subplot [i + 1] (subplot indices start at 1 here). *)
  let draw_sample i =
    (* Take a single image and copy it to avoid memory layout issues with
       sliced arrays. *)
    let img =
      images_nx |> Nx.slice [I i; I 0; R []; R []] |> Nx.squeeze |> Nx.copy
    in
    let title = Printf.sprintf "Label: %d" (get_label i) in
    let axes = Hugin.Figure.add_subplot ~nrows:2 ~ncols:5 ~index:(i + 1) fig in
    (* Show the digit with its label as title and hide the tick marks. *)
    axes
    |> Hugin.Plotting.imshow ~data:img
    |> Hugin.Axes.set_title title
    |> Hugin.Axes.set_xticks []
    |> Hugin.Axes.set_yticks []
    |> ignore
  in
  List.iter draw_sample (List.init 10 (fun i -> i))
;;

fig

figure