MNIST Tutorial with Kaun
Welcome to Kaun! In this tutorial, you'll learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset.
1. Load the MNIST Dataset
Load and prepare the MNIST dataset:
(* Device used throughout this tutorial: the CPU, via the C FFI. *)
let device = Rune.c

(* Load the raw train and test splits. Both are requested unflattened
   ([~flatten:false]) and unnormalized ([~normalize:false]) so that the
   images can be visualized as-is. *)
let train_data_raw, test_data_raw =
  let load_split ~train =
    Kaun_datasets.mnist ~train ~flatten:false ~normalize:false ~device ()
  in
  (* Sequence the loads with [let] so the train split is always loaded
     before the test split (tuple components have no guaranteed order). *)
  let train_split = load_split ~train:true in
  let test_split = load_split ~train:false in
  (train_split, test_split)
(* Dataset sizes. These are the well-known MNIST split sizes written down as
   constants; they are not read back from the loaded datasets. *)
let train_size = 60_000 (* MNIST training split *)
let test_size = 10_000 (* MNIST test split *)

(* Print a short summary of what was loaded. *)
let () =
  let open Printf in
  printf "Raw MNIST data loaded (unnormalized):\n";
  printf "Train samples: %d\n" train_size;
  printf "Test samples: %d\n" test_size
01.07.25 03:51:54.910 .datasets.mnist INFO Loading MNIST datasets...
01.07.25 03:51:56.380 .datasets.mnist INFO MNIST loading complete
2. Visualize Sample Images
Let's visualize some sample images from the training set using Hugin:
(* Re-batch the already loaded raw training data into batches of 10 samples;
   only the first batch is needed for the preview below. *)
let small_batch_dataset = Kaun.Dataset.batch_xy 10 train_data_raw

(* [Kaun.Dataset.take 1] returns a list of (images, labels) batches. Match on
   it explicitly instead of silencing the non-exhaustive-pattern warning, so
   that an empty or unexpected result fails with a descriptive message rather
   than an anonymous [Match_failure].
   @raise Failure if the result does not contain exactly one batch. *)
let images_batch, labels_batch =
  match Kaun.Dataset.take 1 small_batch_dataset with
  | [ first_batch ] -> first_batch
  | [] | _ :: _ :: _ ->
      failwith
        "MNIST tutorial: expected exactly one batch from Kaun.Dataset.take 1"
(* Hand the image batch over to Nx for plotting: make the Rune tensor
   contiguous, expose it as a bigarray, then wrap that bigarray as an Nx
   array. *)
let images_nx =
  let contiguous_images = Rune.contiguous images_batch in
  let bigarray = Rune.unsafe_to_bigarray contiguous_images in
  Nx.of_bigarray bigarray
(* [get_label i] returns the label of the [i]-th sample of [labels_batch] as
   an [int]. The one-element slice [i, i+1) is made contiguous, read out as a
   float and then truncated with [int_of_float]. *)
let get_label i =
  let one_label = Rune.slice [ R [ i; i + 1 ] ] labels_batch in
  let value = Rune.unsafe_get [] (Rune.contiguous one_label) in
  int_of_float value
(* Figure that will hold the sample preview. *)
let fig = Hugin.Figure.create ~width:1000 ~height:400 ()

(* Fill a 2x5 grid of subplots with samples 0..9, one image per subplot,
   each titled with its label and drawn without axis ticks. *)
let () =
  let plot_sample idx =
    (* Slice out sample [idx] (index 0 on the second axis), squeeze it and
       copy it; the copy avoids memory layout issues with sliced arrays. *)
    let picture =
      Nx.copy (Nx.squeeze (Nx.slice [ I idx; I 0; R []; R [] ] images_nx))
    in
    let caption = Printf.sprintf "Label: %d" (get_label idx) in
    (* Subplot indices are 1-based, hence [idx + 1]. *)
    let axes =
      Hugin.Figure.add_subplot ~nrows:2 ~ncols:5 ~index:(idx + 1) fig
    in
    let axes = Hugin.Plotting.imshow ~data:picture axes in
    let axes = Hugin.Axes.set_title caption axes in
    let axes = Hugin.Axes.set_xticks [] axes in
    ignore (Hugin.Axes.set_yticks [] axes)
  in
  for idx = 0 to 9 do
    plot_sample idx
  done
;;

(* Final expression of the cell: the populated figure (presumably so the
   notebook/toplevel displays it). *)
fig