Q Learning
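
Tabular Q-learning on CartPole-v1: clip and discretize the four-dimensional observation into 12 bins per dimension, learn a flat Q-table online with an epsilon-greedy behavior policy whose epsilon decays exponentially, and evaluate the greedy policy every 500 episodes.
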
open Fehu
let n_bins = 12 (* bins per observation dimension *)
let n_actions = 2 (* CartPole: push left / push right *)
let alpha = 0.1 (* learning rate *)
let gamma = 0.99 (* discount factor *)
let epsilon_start = 1.0 (* initial exploration rate *)
let epsilon_end = 0.01 (* floor of the epsilon schedule *)
let epsilon_decay = 2000.0 (* decay time constant, in episodes *)
let n_episodes = 10_000
let eval_interval = 500 (* evaluate the greedy policy every 500 episodes *)
(* Render an array of floats as a one-line Unicode sparkline. *)
let sparkline values =
  let blocks =
    (* U+2581 .. U+2588: the eight block characters from ▁ to █ *)
    [|
      "\xe2\x96\x81";
      "\xe2\x96\x82";
      "\xe2\x96\x83";
      "\xe2\x96\x84";
      "\xe2\x96\x85";
      "\xe2\x96\x86";
      "\xe2\x96\x87";
      "\xe2\x96\x88";
    |]
  in
  let lo = Array.fold_left Float.min Float.infinity values in
  let hi = Array.fold_left Float.max Float.neg_infinity values in
  let range = hi -. lo in
  if range < 1e-9 then
    (* All values equal: draw a flat mid-height line. *)
    String.concat "" (Array.to_list (Array.map (fun _ -> blocks.(4)) values))
  else
    String.concat ""
      (Array.to_list
         (Array.map
            (fun v ->
              let idx = Float.to_int ((v -. lo) /. range *. 7.0) in
              blocks.(max 0 (min 7 idx)))
            values))
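(* Flat Q-table over the discretized state space, stored row-major:
   the entry for (state, action) lives at index state * n_actions + action. *)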
let n_states = n_bins * n_bins * n_bins * n_bins
let q = Array.make (n_states * n_actions) 0.0
let q_get s a = q.((s * n_actions) + a)
let q_set s a v = q.((s * n_actions) + a) <- v
(* Per-dimension clip ranges: cart position, cart velocity, pole angle (rad),
   pole angular velocity. *)
let clip_ranges = [| (-2.4, 2.4); (-3.0, 3.0); (-0.21, 0.21); (-3.0, 3.0) |]
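(* Map a continuous observation to a single state index: clip each of the
   four dimensions to its range, bin it, and combine the four bin indices
   as a base-n_bins number. *)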
let discretize obs =
  let arr = (Nx.to_array obs : float array) in
  let bin i =
    let lo, hi = clip_ranges.(i) in
    let v = Float.max lo (Float.min hi arr.(i)) in
    let normalized = (v -. lo) /. (hi -. lo) in
    Float.to_int (normalized *. Float.of_int (n_bins - 1))
    |> max 0
    |> min (n_bins - 1)
  in
  let b0 = bin 0 in
  let b1 = bin 1 in
  let b2 = bin 2 in
  let b3 = bin 3 in
  (b0 * n_bins * n_bins * n_bins) + (b1 * n_bins * n_bins) + (b2 * n_bins) + b3
(* Greedy action; ties break toward action 0. *)
let best_action s = if q_get s 0 >= q_get s 1 then 0 else 1
let () =
  Printf.printf "Q-Learning on CartPole-v1\n";
  Printf.printf "=========================\n\n";
  Printf.printf "States: %d bins/dim (%d total), Actions: left/right\n" n_bins
    n_states;
  Printf.printf "alpha = %.2f, gamma = %.2f, episodes = %d\n\n" alpha gamma
    n_episodes;
  Nx.Rng.run ~seed:42 @@ fun () ->
  (* Scalar sampling helpers: a uniform float for the epsilon test and a
     uniformly random action index for exploration. *)
  let sample_uniform () =
    let t = Nx.rand Nx.float32 [| 1 |] in
    (Nx.to_array t : float array).(0)
  in
  let sample_random_action () =
    let t = Nx.randint Nx.int32 ~high:n_actions [| 1 |] 0 in
    Int32.to_int (Nx.to_array t : Int32.t array).(0)
  in
  let env = Fehu_envs.Cartpole.make () in
  let n_evals = n_episodes / eval_interval in
  let reward_history = Array.make n_evals 0.0 in
  let eval_idx = ref 0 in
  Printf.printf "Training...\n\n";
  for episode = 1 to n_episodes do
    (* Decay epsilon exponentially from epsilon_start toward epsilon_end. *)
    let epsilon =
      epsilon_end
      +. (epsilon_start -. epsilon_end)
         *. exp (-.Float.of_int episode /. epsilon_decay)
    in
    let obs, _info = Env.reset env () in
    let state = ref (discretize obs) in
    let done_ = ref false in
    while not !done_ do
      (* Epsilon-greedy: explore with probability epsilon, else act greedily. *)
      let a =
        if sample_uniform () < epsilon then sample_random_action ()
        else best_action !state
      in
      let s = Env.step env (Space.Discrete.of_int a) in
      let next_state = discretize s.observation in
      let done_flag = s.terminated || s.truncated in
      (* No bootstrapping past the end of an episode (truncation is treated
         like termination here, a common simplification). *)
      let bootstrap =
        if done_flag then 0.0
        else Float.max (q_get next_state 0) (q_get next_state 1)
      in
      let target = s.reward +. (gamma *. bootstrap) in
      let old_q = q_get !state a in
      q_set !state a (old_q +. (alpha *. (target -. old_q)));
      state := next_state;
      done_ := done_flag
    done;
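    (* Every eval_interval episodes, run the greedy policy for 20 episodes
       and record the mean reward for the final sparkline. *)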
    if episode mod eval_interval = 0 then begin
      let greedy_policy obs =
        Space.Discrete.of_int (best_action (discretize obs))
      in
      let stats = Eval.run env ~policy:greedy_policy ~n_episodes:20 () in
      Printf.printf
        " episode %5d eps = %.2f eval: reward = %5.1f +/- %4.1f\n%!" episode
        epsilon stats.mean_reward stats.std_reward;
      reward_history.(!eval_idx) <- stats.mean_reward;
      incr eval_idx
    end
  done;
Printf.printf "\n reward: %s\n" (sparkline reward_history);
Printf.printf "\nFinal evaluation (100 episodes):\n";
let greedy_policy obs =
Space.Discrete.of_int (best_action (discretize obs))
in
let stats = Eval.run env ~policy:greedy_policy ~n_episodes:100 () in
Printf.printf " mean reward: %5.1f +/- %.1f\n" stats.mean_reward
stats.std_reward;
Printf.printf " mean length: %5.1f\n" stats.mean_length;
if stats.mean_reward >= 195.0 then
Printf.printf "\nSolved! (mean reward >= 195)\n"
else Printf.printf "\nNot solved yet (mean reward < 195).\n";
Env.close env