(* Dqn: Deep Q-Network (DQN) on CartPole-v1, built on the Fehu RL
   environments and the Kaun neural-network library. *)
open Fehu
open Kaun
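(* Hyperparameters. [learning_starts] delays gradient updates until the
   replay buffer holds enough transitions; [target_update_interval] is the
   period (in environment steps) of the hard target-network sync. *)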
let buffer_capacity = 50_000
let batch_size = 64
let gamma = 0.99
let lr = 5e-4
let epsilon_start = 1.0
let epsilon_end = 0.05
let epsilon_decay_steps = 10_000
let target_update_interval = 250
let learning_starts = 1000
let n_steps = 50_000
let eval_interval = 2000
let eval_episodes = 20
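(* Render [values] as a one-line sparkline using the Unicode block
   characters U+2581..U+2588 (the UTF-8 escapes below). Values are min-max
   normalized; a flat series maps everything to a middle block. *)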
let sparkline values =
let blocks =
[|
"\xe2\x96\x81";
"\xe2\x96\x82";
"\xe2\x96\x83";
"\xe2\x96\x84";
"\xe2\x96\x85";
"\xe2\x96\x86";
"\xe2\x96\x87";
"\xe2\x96\x88";
|]
in
let lo = Array.fold_left Float.min Float.infinity values in
let hi = Array.fold_left Float.max Float.neg_infinity values in
let range = hi -. lo in
if range < 1e-9 then
String.concat "" (Array.to_list (Array.map (fun _ -> blocks.(4)) values))
else
String.concat ""
(Array.to_list
(Array.map
(fun v ->
let idx = Float.to_int ((v -. lo) /. range *. 7.0) in
blocks.(max 0 (min 7 idx)))
values))
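(* Q-network: maps the 4-dimensional CartPole observation (cart position,
   cart velocity, pole angle, pole angular velocity) to a Q-value for each
   of the two discrete actions (push left / push right). *)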
let q_network =
Layer.sequential
[
Layer.linear ~in_features:4 ~out_features:128 ();
Layer.relu ();
Layer.linear ~in_features:128 ~out_features:128 ();
Layer.relu ();
Layer.linear ~in_features:128 ~out_features:2 ();
]
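(* Forward pass through the Q-network. The network has no dropout or batch
   norm, so [~training:false] behaves identically during acting and
   learning. *)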
let forward params net_state obs =
let vars = Layer.make_vars ~params ~state:net_state ~dtype:Nx.float32 in
fst (Layer.apply q_network vars ~training:false obs)
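(* Linearly anneal epsilon from [epsilon_start] to [epsilon_end] over
   [epsilon_decay_steps] steps, then hold it at [epsilon_end]. *)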
let epsilon step =
let t =
Float.min 1.0 (Float.of_int step /. Float.of_int epsilon_decay_steps)
in
epsilon_start +. (t *. (epsilon_end -. epsilon_start))
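(* Deep-copy a parameter tree; used to snapshot the online parameters into
   the target network. *)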
let copy_params params = Ptree.map { run = (fun t -> Nx.copy t) } params
let () =
Printf.printf "DQN on CartPole-v1\n";
Printf.printf "===================\n\n";
Printf.printf
"Network: Linear(4 -> 128) -> ReLU -> Linear(128 -> 128) -> ReLU -> \
Linear(128 -> 2)\n";
Printf.printf "Buffer: %d, batch: %d, gamma = %.2f, lr = %.4f\n"
buffer_capacity batch_size gamma lr;
Printf.printf
"Epsilon: %.2f -> %.2f over %d steps, target update every %d steps\n\n"
epsilon_start epsilon_end epsilon_decay_steps target_update_interval;
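  (* Run everything under a fixed RNG seed so results are reproducible. *)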
Nx.Rng.run ~seed:42 @@ fun () ->
let env = Fehu_envs.Cartpole.make () in
let vars = Layer.init q_network ~dtype:Nx.float32 in
let params = ref (Layer.params vars) in
let net_state = Layer.state vars in
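  (* The target network is a periodically refreshed copy of the online
     parameters; computing TD targets against it stabilizes learning. *)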
let target_params = ref (copy_params !params) in
Printf.printf "Parameters: %d\n\n" (Ptree.count_parameters !params);
let algo = Optim.adam ~lr:(Optim.Schedule.constant lr) () in
let opt_state = ref (Optim.init algo !params) in
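  (* Replay buffer of (s, a, r, s', done) transitions for experience
     replay. *)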
let buffer = Buffer.create ~capacity:buffer_capacity in
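  (* Uniform draw in [0, 1) for the epsilon-greedy coin flip. *)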
let sample_uniform () =
let t = Nx.rand Nx.float32 [| 1 |] in
(Nx.to_array t : float array).(0)
in
  (* Greedy action: argmax over the online network's Q-values. *)
  let greedy_policy obs =
    let obs_batch = Nx.reshape [| 1; 4 |] obs in
    let q_values =
      Rune.no_grad (fun () -> forward !params net_state obs_batch)
    in
    let action_idx =
      Nx.argmax q_values ~axis:(-1) ~keepdims:false |> Nx.cast Nx.int32
    in
    Nx.reshape [||] action_idx
  in
  (* Epsilon-greedy behavior policy: explore with probability [eps],
     otherwise act greedily. *)
  let select_action obs eps =
    if sample_uniform () < eps then Space.sample (Env.action_space env)
    else greedy_policy obs
  in
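  (* One DQN update: sample a minibatch from the replay buffer, form the TD
     target with the frozen target network, and take one gradient step on
     the mean squared TD error. *)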
let train_step () =
    let obs_arr, act_arr, rew_arr, next_obs_arr, term_arr, _trunc_arr =
      Buffer.sample_arrays buffer ~batch_size
    in
let n = Array.length obs_arr in
let obs_batch = Nx.stack (Array.to_list obs_arr) in
let next_obs_batch = Nx.stack (Array.to_list next_obs_arr) in
let actions_batch =
Nx.stack
(Array.to_list (Array.map (fun a -> Nx.reshape [| 1 |] a) act_arr))
in
let rewards_t = Nx.create Nx.float32 [| n |] rew_arr in
    (* Bootstrap mask: zero only on true terminations. Truncated episodes
       (time-limit cutoffs) end mid-trajectory, so we still bootstrap from
       the next state. *)
    let done_mask =
      Array.init n (fun i -> if term_arr.(i) then 0.0 else 1.0)
    in
    let done_mask_t = Nx.create Nx.float32 [| n |] done_mask in
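    (* TD target: r + gamma * max_a' Q_target(s', a'), masked on terminal
       transitions. Computed under no_grad so no gradient flows into the
       target network. *)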
let td_target =
Rune.no_grad (fun () ->
let target_q = forward !target_params net_state next_obs_batch in
let max_q = Nx.max target_q ~axes:[ 1 ] ~keepdims:false in
Nx.add rewards_t
(Nx.mul (Nx.scalar Nx.float32 gamma) (Nx.mul max_q done_mask_t)))
in
let td_target = Rune.detach td_target in
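    (* Loss: mean squared error between Q(s, a) for the actions actually
       taken and the (constant) TD target. *)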
let loss_fn p =
let q_values = forward p net_state obs_batch in
let q_selected = Nx.take_along_axis ~axis:1 actions_batch q_values in
let q_selected = Nx.reshape [| n |] q_selected in
let diff = Nx.sub q_selected td_target in
Nx.mean (Nx.mul diff diff)
in
let loss, grads = Grad.value_and_grad loss_fn !params in
let new_params, new_opt_state =
Optim.update algo !opt_state !params grads
in
params := new_params;
opt_state := new_opt_state;
Nx.item [] loss
in
Printf.printf "Filling buffer (%d steps)...\n\n" learning_starts;
let obs = ref (fst (Env.reset env ())) in
let last_loss = ref 0.0 in
let n_evals = n_steps / eval_interval in
let reward_history = Array.make n_evals 0.0 in
let eval_idx = ref 0 in
Printf.printf "Training...\n\n";
for step = 1 to n_steps do
let eps = epsilon step in
let action = select_action !obs eps in
let s = Env.step env action in
Buffer.add buffer
{
observation = !obs;
action;
reward = s.reward;
next_observation = s.observation;
terminated = s.terminated;
truncated = s.truncated;
};
if s.terminated || s.truncated then obs := fst (Env.reset env ())
else obs := s.observation;
if step >= learning_starts then begin
last_loss := train_step ();
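      (* Hard update: copy the online parameters into the target network. *)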
if step mod target_update_interval = 0 then
target_params := copy_params !params
end;
if step mod eval_interval = 0 then begin
let stats =
Eval.run env ~policy:greedy_policy ~n_episodes:eval_episodes ()
in
Printf.printf
" step %5d epsilon = %.2f loss = %6.4f eval: reward = %5.1f +/- \
%4.1f\n\
%!"
step eps !last_loss stats.mean_reward stats.std_reward;
reward_history.(!eval_idx) <- stats.mean_reward;
incr eval_idx;
obs := fst (Env.reset env ())
end
done;
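  (* Summarize the learning curve and run a final greedy evaluation. *)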
Printf.printf "\n reward: %s\n" (sparkline reward_history);
Printf.printf "\nFinal evaluation (%d episodes):\n" 50;
let stats = Eval.run env ~policy:greedy_policy ~n_episodes:50 () in
Printf.printf " mean reward: %5.1f +/- %.1f\n" stats.mean_reward
stats.std_reward;
Printf.printf " mean length: %5.1f\n" stats.mean_length;
if stats.mean_reward >= 475.0 then
Printf.printf "\nSolved! (mean reward >= 475)\n"
else Printf.printf "\nNot solved yet (mean reward < 475).\n";
Env.close env