MLP

Train a multi-layer perceptron from scratch using Rune's automatic differentiation. Computes MSE loss, derives gradients with Rune's `value_and_grads`, and updates parameters in a manual training loop.

(*---------------------------------------------------------------------------
  Copyright (c) 2026 The Raven authors. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

open Nx
open Rune

(* Forward pass: computes the MLP output.

   [params] must be the four-element list [w1; b1; w2; b2]; any other
   shape raises via [failwith]. *)
let forward params inputs =
  match params with
  | [ w1; b1; w2; b2 ] ->
      (* Affine map into the hidden layer, followed by a ReLU activation *)
      let hidden = maximum (scalar Float32 0.0) (add (matmul inputs w1) b1) in
      (* Affine map from the hidden layer to the outputs (no activation) *)
      add (matmul hidden w2) b2
  | _ -> failwith "Invalid parameters"

(* Mean Squared Error loss: the mean of the element-wise squared
   residuals between predictions and targets. *)
let mse_loss y_pred y_true =
  let residual = sub y_pred y_true in
  mean (mul residual residual)

(* Training function: fits a one-hidden-layer MLP to [(inputs, y_true)]
   with plain gradient descent and returns the trained parameter list
   [w1; b1; w2; b2]. Parameters are updated in place each epoch, and the
   loss is printed per epoch. *)
let train_mlp inputs y_true learning_rate epochs =
  (* Layer sizes derived from the data: feature count, a fixed hidden
     width of 3, and the target dimensionality. *)
  let n_features = dim 1 inputs in
  let n_hidden = 3 in
  let n_outputs = dim 1 y_true in
  (* Random weights, zero biases *)
  let params =
    [
      rand Float32 [| n_features; n_hidden |];
      zeros Float32 [| n_hidden |];
      rand Float32 [| n_hidden; n_outputs |];
      zeros Float32 [| n_outputs |];
    ]
  in

  (* Express the loss as a function of the parameter list so Rune can
     differentiate it with respect to every parameter at once. *)
  let loss_fn ps = mse_loss (forward ps inputs) y_true in

  (* Manual SGD loop: evaluate loss and gradients together, then apply
     p <- p - lr * g, writing the result back into each parameter tensor. *)
  for epoch = 1 to epochs do
    let loss, grads = value_and_grads loss_fn params in
    Printf.printf "Epoch %d: Loss = %f\n" epoch (item [] loss);
    List.iter2
      (fun p g ->
        ignore (sub ~out:p p (mul (scalar Float32 learning_rate) g)))
      params grads
  done;
  params

(* Example usage *)
let () =
  (* Toy dataset: 4 samples with 2 features each, and one target value
     per sample. *)
  let inputs =
    create Float32 [| 4; 2 |] [| 1.0; 2.0; 3.0; 4.0; 5.0; 6.0; 7.0; 8.0 |]
  in
  let y_true = create Float32 [| 4; 1 |] [| 1.0; 2.0; 3.0; 4.0 |] in

  (* Train for 100 epochs with a fixed step size of 0.01 *)
  let trained_params = train_mlp inputs y_true 0.01 100 in

  (* Report the fitted model's predictions on the training inputs *)
  let y_pred = forward trained_params inputs in
  print_endline "Predictions after training:";
  print y_pred