(* MNIST: train a small convolutional network on the MNIST digits with Kaun. *)
open Kaun
(* Training hyperparameters. *)

(* Number of samples per mini-batch. *)
let batch_size = 64

(* Number of full passes over the training set. *)
let epochs = 3

(* Learning rate. *)
let lr = 1e-3
(* Network definition: two convolution / ReLU / 2x2 max-pool stages
   (1 -> 16 -> 32 channels), then a flatten and a two-layer linear head that
   ends in 10 outputs.

   NOTE(review): the [32 * 7 * 7] input width of the first linear layer
   presumably assumes 28x28 inputs whose spatial size is only reduced by the
   two 2x2 pools (i.e. shape-preserving convolutions) -- confirm against the
   defaults of [Layer.conv2d]. *)
let model =
  (* One feature-extraction stage: convolution, activation, down-sampling. *)
  let conv_stage ~in_channels ~out_channels =
    [
      Layer.conv2d ~in_channels ~out_channels ();
      Layer.relu ();
      Layer.max_pool2d ~kernel_size:(2, 2) ();
    ]
  in
  let first_stage = conv_stage ~in_channels:1 ~out_channels:16 in
  let second_stage = conv_stage ~in_channels:16 ~out_channels:32 in
  let head =
    [
      Layer.flatten ();
      Layer.linear ~in_features:(32 * 7 * 7) ~out_features:128 ();
      Layer.relu ();
      Layer.linear ~in_features:128 ~out_features:10 ();
    ]
  in
  Layer.sequential (List.concat [ first_stage; second_stage; head ])
(* [collect ds] walks the whole dataset [ds] of [(x, y)] pairs and returns
   [(xs, ys)], where [xs] stacks every [x] along a new axis 0 and [ys] stacks
   every [y] along a new axis 0 and is then cast to [Nx.int32]. Elements keep
   the order in which [Data.iter] yields them. *)
let collect ds =
  (* Samples are accumulated newest-first. *)
  let seen = ref [] in
  Data.iter (fun sample -> seen := sample :: !seen) ds;
  (* Folding over the reversed list while consing restores iteration order
     and splits the pairs in the same pass. *)
  let inputs, targets =
    List.fold_left
      (fun (xs, ys) (x, y) -> (x :: xs, y :: ys))
      ([], []) !seen
  in
  let features = Nx.stack ~axis:0 inputs in
  let labels = Nx.stack ~axis:0 targets |> Nx.cast Nx.int32 in
  (features, labels)
(* Entry point. Inside [Nx.Rng.run ~seed:42] it loads MNIST, materialises the
   train and test splits with [collect], then trains [model] with Adam for
   [epochs] passes. After each pass it prints the mean training loss recorded
   by the metric tracker and the accuracy measured over the test batches. *)
let () =
  Nx.Rng.run ~seed:42 (fun () ->
      Printf.printf "Loading MNIST...\n%!";
      let train_ds, test_ds = Kaun_datasets.mnist () in
      let x_train, y_train = collect train_ds in
      let x_test, y_test = collect test_ds in
      let n_train = (Nx.shape x_train).(0) in
      let n_test = (Nx.shape x_test).(0) in
      Printf.printf " train: %d test: %d\n%!" n_train n_test;
      (* The test batches are built once and rewound before each evaluation. *)
      let test_batches = Data.prepare ~batch_size (x_test, y_test) in
      let optimizer = Optim.adam ~lr:(Optim.Schedule.constant lr) () in
      let trainer = Train.make ~model ~optimizer in
      (* Only used as the denominator of the progress line. *)
      let num_batches = n_train / batch_size in
      (* Runs epochs [epoch .. epochs], threading the training state through. *)
      let rec run_epoch epoch state =
        if epoch <= epochs then begin
          (* Fresh shuffled batches every epoch; each batch carries its own
             loss closure over the batch labels. *)
          let train_data =
            Data.map
              (fun (x, y) ->
                (x, fun logits -> Loss.cross_entropy_sparse logits y))
              (Data.prepare ~shuffle:true ~batch_size (x_train, y_train))
          in
          let tracker = Metric.tracker () in
          let trained =
            Train.fit trainer state
              ~report:(fun ~step ~loss _state ->
                Metric.observe tracker "loss" loss;
                Printf.printf "\r batch %d/%d loss: %.4f%!" step num_batches
                  loss)
              train_data
          in
          Printf.printf "\n%!";
          Data.reset test_batches;
          let test_acc =
            Metric.eval
              (fun (x, y) ->
                Metric.accuracy (Train.predict trainer trained x) y)
              test_batches
          in
          Printf.printf "epoch %d train_loss: %.4f test_acc: %.2f%%\n%!" epoch
            (Metric.mean tracker "loss")
            (test_acc *. 100.);
          run_epoch (epoch + 1) trained
        end
      in
      run_epoch 1 (Train.init trainer ~dtype:Nx.float32);
      Printf.printf "\nDone.\n")