Gpt2
gpt2.ml
open Kaun
(** [invalid_argf fmt ...] builds a message from a [Printf]-style format and
    its arguments, then raises [Invalid_argument] carrying that message. *)
let invalid_argf fmt = Printf.ksprintf (fun message -> invalid_arg message) fmt
(** Hyper-parameters of the GPT-2 decoder defined in this file. *)
type config = {
(* Number of rows of the token embedding table [wte]. *)
vocab_size : int;
(* Number of rows of the position embedding table [wpe]; [decode] rejects
   sequences longer than this. *)
n_positions : int;
(* Width of the embeddings and of the hidden state. *)
n_embd : int;
(* Number of transformer blocks created by [init_decoder_params]. *)
n_layer : int;
(* Number of attention heads; the per-head width is [n_embd / n_head]. *)
n_head : int;
(* Inner width of the feed-forward sub-layer. *)
n_inner : int;
(* Dropout rate applied to the attention and feed-forward outputs before they
   are added back to the residual stream (training only). *)
resid_pdrop : float;
(* Dropout rate applied to the summed token + position embeddings (training
   only). *)
embd_pdrop : float;
(* Dropout rate handed to the attention primitive (training only). *)
attn_pdrop : float;
(* Epsilon passed to every layer normalisation. *)
layer_norm_eps : float;
}
(** [config ~vocab_size ~n_embd ~n_layer ~n_head ()] builds a validated
    {!config} record.

    Optional arguments default to [n_positions = 1024],
    [n_inner = 4 * n_embd], a rate of [0.1] for the three dropouts and
    [layer_norm_eps = 1e-5].

    @raise Invalid_argument
      if [n_head] is not strictly positive, or if [n_embd] is not divisible by
      [n_head]. *)
let config ~vocab_size ~n_embd ~n_layer ~n_head ?(n_positions = 1024)
    ?(n_inner = 4 * n_embd) ?(resid_pdrop = 0.1) ?(embd_pdrop = 0.1)
    ?(attn_pdrop = 0.1) ?(layer_norm_eps = 1e-5) () =
  (* Check the sign first: [n_embd mod 0] would raise [Division_by_zero], and a
     negative head count could otherwise pass the divisibility test. *)
  if n_head <= 0 then
    invalid_argf "Gpt2.config: n_head (%d) must be positive" n_head;
  if n_embd mod n_head <> 0 then
    invalid_argf "Gpt2.config: n_embd (%d) not divisible by n_head (%d)" n_embd
      n_head;
  {
    vocab_size;
    n_positions;
    n_embd;
    n_layer;
    n_head;
    n_inner;
    resid_pdrop;
    embd_pdrop;
    attn_pdrop;
    layer_norm_eps;
  }
(** [fields ~ctx tree] forwards to [Ptree.Dict.fields_exn], passing [ctx]
    through unchanged (presumably used to label errors; confirm in [Ptree]). *)
let fields ~ctx tree = tree |> Ptree.Dict.fields_exn ~ctx
(** [get entries ~name dtype] fetches the tensor stored under [name] in
    [entries] via [Ptree.Dict.get_tensor_exn], requesting it as [dtype]. *)
let get entries ~name as_dtype =
  Ptree.Dict.get_tensor_exn entries ~name as_dtype
(** [find ~ctx key entries] looks up the sub-tree stored under [key] in
    [entries] via [Ptree.Dict.find_exn]. *)
let find ~ctx key entries = entries |> Ptree.Dict.find_exn ~ctx key
(** [causal_self_attention ~cfg ~dtype ~training ~params x] applies multi-head
    causal self-attention to [x].

    The first two dimensions of [x] are read as [batch] and [seq]; the result
    is reshaped to [[| batch; seq; cfg.n_embd |]] before the output
    projection. [params] must be a dict holding [qkv_weight], [qkv_bias],
    [o_weight] and [o_bias]. Attention dropout at rate [cfg.attn_pdrop] is
    requested only when [training] is true and the rate is positive.

    @raise Invalid_argument
      if the fused QKV projection does not split into exactly three parts. *)
let causal_self_attention (type l) ~(cfg : config)
    ~(dtype : (float, l) Nx.dtype) ~training ~params (x : (float, l) Nx.t) :
    (float, l) Nx.t =
  let shape = Nx.shape x in
  let batch = shape.(0) in
  let seq = shape.(1) in
  let h = cfg.n_embd in
  let heads = cfg.n_head in
  let head_dim = h / heads in
  let fs = fields ~ctx:"Gpt2.attention" params in
  let qkv_w = get fs ~name:"qkv_weight" dtype in
  let qkv_b = get fs ~name:"qkv_bias" dtype in
  (* One fused projection produces Q, K and V side by side on the last axis. *)
  let qkv = Nx.add (Nx.matmul x qkv_w) qkv_b in
  let q, k, v =
    (* Destructure once instead of walking the list with three [List.nth]. *)
    match Nx.split ~axis:(-1) 3 qkv with
    | [ q; k; v ] -> (q, k, v)
    | parts ->
        invalid_argf
          "Gpt2.attention: expected 3 parts from the QKV split, got %d"
          (List.length parts)
  in
  (* [batch; seq; h] -> [batch; heads; seq; head_dim] *)
  let split_heads t =
    Nx.reshape [| batch; seq; heads; head_dim |] t
    |> Nx.transpose ~axes:[ 0; 2; 1; 3 ]
  in
  let q = split_heads q in
  let k = split_heads k in
  let v = split_heads v in
  let dropout_rate =
    if training && cfg.attn_pdrop > 0.0 then Some cfg.attn_pdrop else None
  in
  let attn =
    Kaun.Fn.dot_product_attention ~is_causal:true ?dropout_rate q k v
  in
  (* Undo [split_heads]: back to [batch; seq; h]. *)
  let merged =
    Nx.transpose attn ~axes:[ 0; 2; 1; 3 ]
    |> Nx.contiguous
    |> Nx.reshape [| batch; seq; h |]
  in
  let o_w = get fs ~name:"o_weight" dtype in
  let o_b = get fs ~name:"o_bias" dtype in
  Nx.add (Nx.matmul merged o_w) o_b
(** [transformer_block ~cfg ~dtype ~training ~params x] runs one decoder block
    on [x].

    Each sub-layer (self-attention, then the feed-forward network with
    [gelu_approx]) is applied to a layer-normalised copy of its input, and its
    output is added back to that input. When [training] is true and
    [cfg.resid_pdrop] is positive, dropout is applied to each sub-layer output
    before the addition. [params] must be a dict holding the [ln1_*], [ln2_*]
    and [ffn_*] tensors plus an ["attention"] sub-tree. *)
let transformer_block (type l) ~(cfg : config) ~(dtype : (float, l) Nx.dtype)
    ~training ~params (x : (float, l) Nx.t) : (float, l) Nx.t =
  let entries = fields ~ctx:"Gpt2.block" params in
  let tensor name = get entries ~name dtype in
  let normalise ~gamma ~beta t =
    Kaun.Fn.layer_norm ~gamma ~beta ~epsilon:cfg.layer_norm_eps t
  in
  let residual_dropout t =
    if training && cfg.resid_pdrop > 0.0 then
      Kaun.Fn.dropout ~rate:cfg.resid_pdrop t
    else t
  in
  (* Attention sub-layer. *)
  let gamma1 = tensor "ln1_gamma" in
  let beta1 = tensor "ln1_beta" in
  let normed = normalise ~gamma:gamma1 ~beta:beta1 x in
  let attention_params = find ~ctx:"Gpt2.block" "attention" entries in
  let attended =
    causal_self_attention ~cfg ~dtype ~training ~params:attention_params normed
    |> residual_dropout
  in
  let after_attention = Nx.add x attended in
  (* Feed-forward sub-layer. *)
  let gamma2 = tensor "ln2_gamma" in
  let beta2 = tensor "ln2_beta" in
  let normed = normalise ~gamma:gamma2 ~beta:beta2 after_attention in
  let up_w = tensor "ffn_up_weight" in
  let up_b = tensor "ffn_up_bias" in
  let down_w = tensor "ffn_down_weight" in
  let down_b = tensor "ffn_down_bias" in
  let expanded = Kaun.Activation.gelu_approx (Nx.add (Nx.matmul normed up_w) up_b) in
  let projected = Nx.add (Nx.matmul expanded down_w) down_b in
  Nx.add after_attention (residual_dropout projected)
(** [decode ~cfg ~params ~dtype ~training input_ids] runs the GPT-2 decoder
    stack and returns the hidden states after the final layer normalisation.

    [input_ids] must have shape [[| batch; seq |]]. Token and position
    embeddings are summed, optionally passed through dropout
    ([cfg.embd_pdrop], training only), then fed through every block found
    under ["layers"] in [params]. [params] must also hold [wte], [wpe],
    [ln_f_gamma] and [ln_f_beta].

    @raise Invalid_argument
      if [input_ids] is not rank 2, or if [seq] exceeds [cfg.n_positions]. *)
let decode (type l in_elt) ~(cfg : config) ~params
    ~(dtype : (float, l) Nx.dtype) ~training (input_ids : (int32, in_elt) Nx.t)
    : (float, l) Nx.t =
  let input_ids = Nx.cast Nx.int32 input_ids in
  (* Validate the rank up front so a malformed input produces a descriptive
     error instead of a bare out-of-bounds access or a misread shape. *)
  let batch, seq =
    match Nx.shape input_ids with
    | [| batch; seq |] -> (batch, seq)
    | shape ->
        invalid_argf
          "Gpt2.decode: expected input_ids of shape [batch; seq], got rank %d"
          (Array.length shape)
  in
  if seq > cfg.n_positions then
    invalid_argf "Gpt2.decode: seq_len=%d exceeds n_positions=%d" seq
      cfg.n_positions;
  let root = fields ~ctx:"Gpt2.decode" params in
  let wte = get root ~name:"wte" dtype in
  let wpe = get root ~name:"wpe" dtype in
  let layers_t = find ~ctx:"Gpt2.decode" "layers" root in
  (* Positions 0 .. seq-1, repeated for every batch row. *)
  let position_ids =
    Nx.arange_f Nx.float32 0.0 (float_of_int seq) 1.0
    |> Nx.cast Nx.int32
    |> Nx.reshape [| 1; seq |]
    |> Nx.broadcast_to [| batch; seq |]
    |> Nx.contiguous
  in
  let tok = Kaun.Fn.embedding ~scale:false ~embedding:wte input_ids in
  let pos = Kaun.Fn.embedding ~scale:false ~embedding:wpe position_ids in
  let x = Nx.add tok pos in
  let x =
    if training && cfg.embd_pdrop > 0.0 then
      Kaun.Fn.dropout ~rate:cfg.embd_pdrop x
    else x
  in
  let blocks = Ptree.List.items_exn ~ctx:"Gpt2.decode.layers" layers_t in
  let x =
    List.fold_left
      (fun h block_params ->
        transformer_block ~cfg ~dtype ~training ~params:block_params h)
      x blocks
  in
  let ln_f_g = get root ~name:"ln_f_gamma" dtype in
  let ln_f_b = get root ~name:"ln_f_beta" dtype in
  Kaun.Fn.layer_norm ~gamma:ln_f_g ~beta:ln_f_b ~epsilon:cfg.layer_norm_eps x
(** [init_block_params ~dtype ~n_embd ~n_inner] creates the parameter tree of
    one transformer block.

    Weight matrices come from [Init.normal ~stddev:0.02]; biases and
    layer-norm betas are zeros, layer-norm gammas are ones. The key names match
    those read by [transformer_block] and [causal_self_attention]. *)
let init_block_params ~dtype ~n_embd ~n_inner =
  let init = Init.normal ~stddev:0.02 () in
  let weight rows cols = Ptree.tensor (init.f [| rows; cols |] dtype) in
  let zero_vector width = Ptree.tensor (Nx.zeros dtype [| width |]) in
  let one_vector width = Ptree.tensor (Nx.ones dtype [| width |]) in
  let attention =
    Ptree.dict
      [
        ("qkv_weight", weight n_embd (3 * n_embd));
        ("qkv_bias", zero_vector (3 * n_embd));
        ("o_weight", weight n_embd n_embd);
        ("o_bias", zero_vector n_embd);
      ]
  in
  Ptree.dict
    [
      ("attention", attention);
      ("ln1_gamma", one_vector n_embd);
      ("ln1_beta", zero_vector n_embd);
      ("ffn_up_weight", weight n_embd n_inner);
      ("ffn_up_bias", zero_vector n_inner);
      ("ffn_down_weight", weight n_inner n_embd);
      ("ffn_down_bias", zero_vector n_embd);
      ("ln2_gamma", one_vector n_embd);
      ("ln2_beta", zero_vector n_embd);
    ]
(** [init_decoder_params ~cfg ~dtype] creates the full decoder parameter tree:
    the [wte] and [wpe] embedding tables (normal, stddev [0.02]),
    [cfg.n_layer] freshly initialised blocks under ["layers"], and the final
    layer-norm gamma (ones) and beta (zeros). *)
let init_decoder_params ~cfg ~dtype =
  let width = cfg.n_embd in
  let init = Init.normal ~stddev:0.02 () in
  let token_table = init.f [| cfg.vocab_size; width |] dtype in
  let position_table = init.f [| cfg.n_positions; width |] dtype in
  let new_block _ =
    init_block_params ~dtype ~n_embd:width ~n_inner:cfg.n_inner
  in
  let layers = List.init cfg.n_layer new_block in
  let final_gamma = Nx.ones dtype [| width |] in
  let final_beta = Nx.zeros dtype [| width |] in
  Ptree.dict
    [
      ("wte", Ptree.tensor token_table);
      ("wpe", Ptree.tensor position_table);
      ("layers", Ptree.list layers);
      ("ln_f_gamma", Ptree.tensor final_gamma);
      ("ln_f_beta", Ptree.tensor final_beta);
    ]
(** [decoder cfg ()] is a layer mapping token ids to the decoder's final
    hidden states (see [decode]).

    [init] creates fresh parameters with [init_decoder_params] and an empty
    state. [apply] does not read the incoming state or [ctx] and always
    returns [Ptree.empty] as the new state. *)
let decoder (cfg : config) () : (int32, float) Layer.t =
  {
    Layer.init =
      (fun ~dtype ->
        let params = init_decoder_params ~cfg ~dtype in
        Layer.make_vars ~params ~state:Ptree.empty ~dtype);
    apply =
      (fun ~params ~state:_ ~dtype ~training ?ctx:_ x ->
        let hidden = decode ~cfg ~params ~dtype ~training x in
        (hidden, Ptree.empty));
  }
(** [for_causal_lm cfg ()] is a layer mapping token ids to next-token logits.

    [apply] runs [decode] and multiplies the hidden states by the transpose of
    the token embedding table [wte], so the output head reuses the embedding
    weights. The incoming state and [ctx] are not read; the returned state is
    always [Ptree.empty]. *)
let for_causal_lm (cfg : config) () : (int32, float) Layer.t =
  {
    Layer.init =
      (fun ~dtype ->
        let params = init_decoder_params ~cfg ~dtype in
        Layer.make_vars ~params ~state:Ptree.empty ~dtype);
    apply =
      (fun ~params ~state:_ ~dtype ~training ?ctx:_ x ->
        let hidden = decode ~cfg ~params ~dtype ~training x in
        let token_table =
          get (fields ~ctx:"Gpt2.lm_head" params) ~name:"wte" dtype
        in
        let token_table_t = Nx.transpose token_table ~axes:[ 1; 0 ] in
        (Nx.matmul hidden token_table_t, Ptree.empty));
  }
(** [json_mem name json] returns the value of member [name] when [json] is an
    object containing it, and a JSON null otherwise (including when [json] is
    not an object). *)
let json_mem name json =
  let null = Jsont.Null ((), Jsont.Meta.none) in
  match json with
  | Jsont.Object (members, _) ->
      Jsont.Json.find_mem name members
      |> Option.fold ~none:null ~some:(fun (_, value) -> value)
  | _ -> null
(** [json_to_int json] converts a JSON number holding an integral value to an
    [int].

    @raise Failure
      if [json] is not a number, or if the number is not a finite integer
      (previously such values were silently truncated by [int_of_float]). *)
let json_to_int = function
  | Jsont.Number (f, _) ->
      (* [Float.is_integer] is false for NaN, infinities and fractional
         values, for which [int_of_float] would lose information. *)
      if Float.is_integer f then int_of_float f
      else failwith (Printf.sprintf "expected int, got %g" f)
  | _ -> failwith "expected int"
(** [json_to_int_option json] is [Some n] when [json] is a number (converted
    with [int_of_float]) and [None] for any other JSON value. *)
let json_to_int_option json =
  match json with
  | Jsont.Number (value, _) -> Some (int_of_float value)
  | _ -> None
(** [json_to_float_option json] is [Some f] when [json] is the number [f] and
    [None] for any other JSON value. *)
let json_to_float_option json =
  match json with
  | Jsont.Number (value, _) -> Some value
  | _ -> None
(** [parse_config json] builds a {!config} from a JSON object.

    [n_embd], [vocab_size], [n_layer] and [n_head] are read with
    [json_to_int] and are therefore required. [n_positions], [n_inner], the
    three [*_pdrop] rates and [layer_norm_epsilon] are optional; when absent
    or not a number, the defaults of [config] apply. *)
let parse_config json =
  let member name = json_mem name json in
  let required_int name = json_to_int (member name) in
  let optional_int name = json_to_int_option (member name) in
  let optional_float name = json_to_float_option (member name) in
  let n_embd = required_int "n_embd" in
  config
    ~vocab_size:(required_int "vocab_size")
    ~n_embd
    ~n_layer:(required_int "n_layer")
    ~n_head:(required_int "n_head")
    ?n_positions:(optional_int "n_positions")
    ?n_inner:(optional_int "n_inner")
    ?resid_pdrop:(optional_float "resid_pdrop")
    ?embd_pdrop:(optional_float "embd_pdrop")
    ?attn_pdrop:(optional_float "attn_pdrop")
    ?layer_norm_eps:(optional_float "layer_norm_epsilon")
    ()
(** [cast_tensor dtype packed] unpacks [packed], casts the tensor to [dtype]
    with [Nx.cast], and packs the result again. *)
let cast_tensor dtype packed =
  match packed with Ptree.P tensor -> Ptree.P (Nx.cast dtype tensor)
(** [map_hf_weights ~cfg ~dtype hf_weights] rearranges a flat list of
    [(name, tensor)] pairs into the parameter tree read by [decode].

    Every tensor is cast to [dtype]. Per-layer tensors are looked up under
    ["h.<i>.<suffix>"] for [i] in [0 .. cfg.n_layer - 1]; the embedding tables
    and final layer norm are looked up under ["wte.weight"], ["wpe.weight"],
    ["ln_f.weight"] and ["ln_f.bias"].

    @raise Invalid_argument if a required name is missing from [hf_weights]. *)
let map_hf_weights ~cfg ~dtype hf_weights =
  let by_name = Hashtbl.of_seq (List.to_seq hf_weights) in
  let node name =
    match Hashtbl.find_opt by_name name with
    | None -> invalid_argf "from_pretrained: missing HF weight %S" name
    | Some packed -> Ptree.Tensor (cast_tensor dtype packed)
  in
  let block index =
    let key suffix = Printf.sprintf "h.%d.%s" index suffix in
    let attention =
      Ptree.dict
        [
          ("qkv_weight", node (key "attn.c_attn.weight"));
          ("qkv_bias", node (key "attn.c_attn.bias"));
          ("o_weight", node (key "attn.c_proj.weight"));
          ("o_bias", node (key "attn.c_proj.bias"));
        ]
    in
    Ptree.dict
      [
        ("attention", attention);
        ("ln1_gamma", node (key "ln_1.weight"));
        ("ln1_beta", node (key "ln_1.bias"));
        ("ffn_up_weight", node (key "mlp.c_fc.weight"));
        ("ffn_up_bias", node (key "mlp.c_fc.bias"));
        ("ffn_down_weight", node (key "mlp.c_proj.weight"));
        ("ffn_down_bias", node (key "mlp.c_proj.bias"));
        ("ln2_gamma", node (key "ln_2.weight"));
        ("ln2_beta", node (key "ln_2.bias"));
      ]
  in
  Ptree.dict
    [
      ("wte", node "wte.weight");
      ("wpe", node "wpe.weight");
      ("layers", Ptree.list (List.init cfg.n_layer block));
      ("ln_f_gamma", node "ln_f.weight");
      ("ln_f_beta", node "ln_f.bias");
    ]
(** [from_pretrained ?model_id ()] loads the configuration and weights of
    [model_id] (default ["gpt2"]) through [Kaun_hf] and returns
    [(cfg, params)], with every weight cast to [Nx.float32]. *)
let from_pretrained ?(model_id = "gpt2") () =
  let cfg = Kaun_hf.load_config ~model_id () |> parse_config in
  let params =
    Kaun_hf.load_weights ~model_id ()
    |> map_hf_weights ~cfg ~dtype:Nx.float32
  in
  (cfg, params)
main.ml
open Kaun
(** [load_tokenizer model_id] downloads [vocab.json] and [merges.txt] for
    [model_id] and builds a [Brot] tokenizer with a byte-level pre-tokenizer
    (no prefix space, regex enabled) and a byte-level decoder. *)
let load_tokenizer model_id =
  let fetch filename = Kaun_hf.download_file ~model_id ~filename () in
  let vocab = fetch "vocab.json" in
  let merges = fetch "merges.txt" in
  let pre =
    Brot.Pre_tokenizer.byte_level ~add_prefix_space:false ~use_regex:true ()
  in
  let decoder = Brot.Decoder.byte_level () in
  Brot.from_model_file ~vocab ~merges ~pre ~decoder ()
(** [encode tokenizer text] tokenises [text] and returns the ids as [int32]. *)
let encode tokenizer text =
  Brot.encode_ids tokenizer text |> Array.map Int32.of_int
(** [decode tokenizer ids] converts [int32] token ids back to text. *)
let decode tokenizer ids =
  let plain_ids = Array.map Int32.to_int ids in
  Brot.decode tokenizer plain_ids
(** [generate model vars ~max_tokens prompt] greedily extends [prompt] by
    [max_tokens] tokens and returns the prompt followed by the generated ids.

    At every step the whole sequence so far is fed to [model] (no caching) and
    the arg-max of the logits at the last position is appended. A non-positive
    [max_tokens] returns a copy of [prompt].

    @raise Invalid_argument if [prompt] is empty: there is no last position to
    read logits from. *)
let generate model vars ~max_tokens prompt =
  let prompt_len = Array.length prompt in
  if prompt_len = 0 then
    invalid_arg "generate: prompt must contain at least one token";
  let steps = max 0 max_tokens in
  (* Preallocate the final sequence instead of growing a list with [@] and
     converting it back to an array on every step. *)
  let tokens = Array.make (prompt_len + steps) 0l in
  Array.blit prompt 0 tokens 0 prompt_len;
  for n = prompt_len to prompt_len + steps - 1 do
    (* [n] is the number of tokens known so far and the slot to fill next. *)
    let input = Nx.create Nx.int32 [| 1; n |] (Array.sub tokens 0 n) in
    let logits, _ = Layer.apply model vars ~training:false input in
    let last = Nx.slice [ I 0; I (n - 1) ] logits in
    let next : int32 = Nx.item [] (Nx.argmax ~axis:0 last) in
    tokens.(n) <- next
  done;
  tokens
(** [print_top_k ~k model vars input_ids ~pos] runs [model] on [input_ids] and
    prints the [k] most likely next tokens at sequence position [pos] of the
    first batch row, one per line, with their softmax probability.

    [k] is clamped to the number of logits, so asking for more entries than
    exist prints every token instead of indexing past the end. *)
let print_top_k ~k model vars input_ids ~pos =
  let logits, _ = Layer.apply model vars ~training:false input_ids in
  let row = Nx.slice [ I 0; I pos ] logits in
  let sorted = Nx.argsort ~descending:true ~axis:0 row in
  let probs = Nx.softmax ~axes:[ 0 ] row in
  (* [row] is indexed with a single coordinate below, so its first dimension
     is the number of candidate tokens. *)
  let k = min k (Nx.shape row).(0) in
  for i = 0 to k - 1 do
    let idx = Int32.to_int (Nx.item [ i ] sorted) in
    let prob : float = Nx.item [ idx ] probs in
    Printf.printf " #%d token %-6d p=%.4f\n" (i + 1) idx prob
  done
(* Demo entry point: load the pretrained "gpt2" tokenizer and weights, print
   the five most likely continuations of "Hello world", then greedily generate
   30 tokens for each of three fixed prompts. *)
let () =
  let model_id = "gpt2" and dtype = Nx.float32 in
  Printf.printf "Loading %s...\n%!" model_id;
  let tokenizer = load_tokenizer model_id in
  let cfg, params = Gpt2.from_pretrained ~model_id () in
  let { Gpt2.vocab_size; n_embd; n_layer; n_head; _ } = cfg in
  Printf.printf " vocab=%d n_embd=%d layers=%d heads=%d\n\n" vocab_size n_embd
    n_layer n_head;
  let model = Gpt2.for_causal_lm cfg () in
  let vars = Layer.make_vars ~params ~state:Ptree.empty ~dtype in
  (* Wrap a flat id array as a single-row batch. *)
  let as_batch ids = Nx.create Nx.int32 [| 1; Array.length ids |] ids in
  Printf.printf "=== Next-token predictions ===\n";
  Printf.printf " Prompt: \"Hello world\"\n";
  Printf.printf " Top 5 continuations:\n";
  let hello_ids = encode tokenizer "Hello world" in
  let last_pos = Array.length hello_ids - 1 in
  print_top_k ~k:5 model vars (as_batch hello_ids) ~pos:last_pos;
  Printf.printf "\n=== Greedy generation (30 tokens each) ===\n\n";
  let complete text =
    let prompt = encode tokenizer text in
    let generated = generate model vars ~max_tokens:30 prompt in
    (* Keep only the tokens produced after the prompt. *)
    let skip = Array.length prompt in
    let continuation =
      Array.sub generated skip (Array.length generated - skip)
    in
    Printf.printf " \"%s\" ->\n" text;
    Printf.printf " %s\n\n" (decode tokenizer continuation)
  in
  List.iter complete
    [ "The meaning of life is"; "Once upon a time"; "The quick brown fox" ]