Module Nx
N-dimensional arrays.
Nx provides n-dimensional arrays (tensors) with NumPy-like semantics. A tensor ('a, 'b) t holds elements of OCaml type 'a stored in a buffer with element kind 'b.
Tensors, views, and contiguity. A tensor is a view over a flat buffer described by a shape, strides, and an offset. Operations that only rearrange metadata (reshape, transpose, slice, …) return views in O(1) without copying data. Use is_c_contiguous to test whether elements are laid out contiguously in row-major order, and contiguous to obtain a contiguous copy when needed.
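To make the view model concrete, here is a plain-OCaml sketch (not Nx's implementation) of row-major strides and the offset-plus-strides address computation. For readability it counts strides in elements, whereas strides and stride below report bytes. c_strides and flat_index are hypothetical helpers.

```ocaml
(* Illustrative sketch, not Nx code: strides are counted in elements here,
   whereas Nx's strides/stride report bytes. *)

(* Strides of a C-contiguous (row-major) layout: the last axis has stride 1
   and each earlier axis skips over the product of all later dimensions. *)
let c_strides (shape : int array) : int array =
  let n = Array.length shape in
  let strides = Array.make n 1 in
  for i = n - 2 downto 0 do
    strides.(i) <- strides.(i + 1) * shape.(i + 1)
  done;
  strides

(* Buffer position of a multi-index: offset + sum over k of idx.(k) * strides.(k). *)
let flat_index ~offset ~(strides : int array) (idx : int array) : int =
  let pos = ref offset in
  Array.iteri (fun k i -> pos := !pos + (i * strides.(k))) idx;
  !pos

(* A transposed 2-D view reuses the same buffer with swapped strides:
   element (j, i) of the transpose is stored where (i, j) of the original is. *)
let () =
  let strides = c_strides [| 2; 3 |] in
  let transposed = [| strides.(1); strides.(0) |] in
  assert (flat_index ~offset:0 ~strides:transposed [| 2; 1 |]
          = flat_index ~offset:0 ~strides [| 1; 2 |])
```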
Broadcasting. Binary operations automatically broadcast operands whose shapes differ: dimensions are aligned from the right and each pair must be equal or one of them must be 1.
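A minimal sketch of this rule using only the OCaml standard library. broadcast_shape is a hypothetical helper that computes the result shape of a binary operation; it is not part of Nx.

```ocaml
(* Illustrative sketch, not Nx code. Shapes are aligned from the right;
   a missing leading dimension counts as 1. *)
let broadcast_shape (a : int array) (b : int array) : int array =
  let na = Array.length a and nb = Array.length b in
  let n = max na nb in
  Array.init n (fun i ->
      (* read both shapes right-aligned, padding with 1 on the left *)
      let da = if i < n - na then 1 else a.(i - (n - na)) in
      let db = if i < n - nb then 1 else b.(i - (n - nb)) in
      if da = db then da
      else if da = 1 then db
      else if db = 1 then da
      else invalid_arg (Printf.sprintf "broadcast: %d vs %d" da db))
```

Shapes [|2|] and [|3|] are rejected because the dimensions differ and neither is 1.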
The ?out convention. Many operations accept an optional ?out tensor. When provided, the result is written into out instead of allocating a fresh tensor; the shape of out must match the result shape.
Types
type ('a, 'b) t = ('a, 'b) Nx_effect.t
The type for tensors with OCaml element type 'a and buffer element kind 'b.
Element kinds
Witnesses for the buffer element representation. Used as the second type parameter of t.
type float16_elt = Nx_buffer.float16_elt
type float32_elt = Nx_buffer.float32_elt
type float64_elt = Nx_buffer.float64_elt
type bfloat16_elt = Nx_buffer.bfloat16_elt
type float8_e4m3_elt = Nx_buffer.float8_e4m3_elt
type float8_e5m2_elt = Nx_buffer.float8_e5m2_elt
type int4_elt = Nx_buffer.int4_signed_elt
type uint4_elt = Nx_buffer.int4_unsigned_elt
type int8_elt = Nx_buffer.int8_signed_elt
type uint8_elt = Nx_buffer.int8_unsigned_elt
type int16_elt = Nx_buffer.int16_signed_elt
type uint16_elt = Nx_buffer.int16_unsigned_elt
type int32_elt = Nx_buffer.int32_elt
type uint32_elt = Nx_buffer.uint32_elt
type int64_elt = Nx_buffer.int64_elt
type uint64_elt = Nx_buffer.uint64_elt
type complex32_elt = Nx_buffer.complex32_elt
type complex64_elt = Nx_buffer.complex64_elt
type bool_elt = Nx_buffer.bool_elt
Data types
type ('a, 'b) dtype = ('a, 'b) Nx_core.Dtype.t =
  | Float16 : (float, float16_elt) dtype
  | Float32 : (float, float32_elt) dtype
  | Float64 : (float, float64_elt) dtype
  | BFloat16 : (float, bfloat16_elt) dtype
  | Float8_e4m3 : (float, float8_e4m3_elt) dtype
  | Float8_e5m2 : (float, float8_e5m2_elt) dtype
  | Int4 : (int, int4_elt) dtype
  | UInt4 : (int, uint4_elt) dtype
  | Int8 : (int, int8_elt) dtype
  | UInt8 : (int, uint8_elt) dtype
  | Int16 : (int, int16_elt) dtype
  | UInt16 : (int, uint16_elt) dtype
  | Int32 : (int32, int32_elt) dtype
  | UInt32 : (int32, uint32_elt) dtype
  | Int64 : (int64, int64_elt) dtype
  | UInt64 : (int64, uint64_elt) dtype
  | Complex64 : (Stdlib.Complex.t, complex32_elt) dtype
  | Complex128 : (Stdlib.Complex.t, complex64_elt) dtype
  | Bool : (bool, bool_elt) dtype
The type for data type descriptors. A ('a, 'b) dtype links the OCaml element type 'a to its buffer representation 'b.
Tensor aliases
type float16_t = (float, float16_elt) t
type float32_t = (float, float32_elt) t
type float64_t = (float, float64_elt) t
type bfloat16_t = (float, bfloat16_elt) t
type float8_e4m3_t = (float, float8_e4m3_elt) t
type float8_e5m2_t = (float, float8_e5m2_elt) t
type uint16_t = (int, uint16_elt) t
type uint32_t = (int32, uint32_elt) t
type uint64_t = (int64, uint64_elt) t
type complex64_t = (Stdlib.Complex.t, complex32_elt) t
type complex128_t = (Stdlib.Complex.t, complex64_elt) t
Data type values
val float16 : (float, float16_elt) dtype
val float32 : (float, float32_elt) dtype
val float64 : (float, float64_elt) dtype
val bfloat16 : (float, bfloat16_elt) dtype
val float8_e4m3 : (float, float8_e4m3_elt) dtype
val float8_e5m2 : (float, float8_e5m2_elt) dtype
val uint16 : (int, uint16_elt) dtype
val uint32 : (int32, uint32_elt) dtype
val uint64 : (int64, uint64_elt) dtype
val complex64 : (Stdlib.Complex.t, complex32_elt) dtype
val complex128 : (Stdlib.Complex.t, complex64_elt) dtype
Index specifications
type index =
  | I of int  (* I i selects a single index, reducing the dimension. *)
  | L of int list  (* L [i0; i1; …] gathers the listed indices. *)
  | R of int * int  (* R (start, stop) selects the half-open range [start, stop). *)
  | Rs of int * int * int  (* Rs (start, stop, step) selects a strided range. *)
  | A  (* A selects the full axis. *)
  | M of (bool, bool_elt) t  (* M mask selects positions where mask is true. *)
  | N  (* N inserts a new axis of size 1 (does not consume an input axis). *)
Properties
val data : ('a, 'b) t -> ('a, 'b) Nx_buffer.t
data t is the underlying flat buffer of t.
The buffer is shared: mutations through the buffer are visible through t and vice-versa. The buffer may be larger than the tensor's logical extent when t is a strided view.
val shape : ('a, 'b) t -> int array
shape t is the dimensions of t. A scalar tensor has shape [||].
val strides : ('a, 'b) t -> int array
strides t is the byte stride for each dimension of t.
Raises Invalid_argument if t does not have computable strides (e.g. after certain non-contiguous view operations). Use is_c_contiguous or call contiguous first.
See also stride.
val stride : int -> ('a, 'b) t -> int
stride i t is the byte stride of dimension i.
Raises Invalid_argument if i is out of bounds or t does not have computable strides.
See also strides.
val dims : ('a, 'b) t -> int array
dims t is the same as shape t: the dimensions of t.
val dim : int -> ('a, 'b) t -> int
dim i t is the size of dimension i.
Raises Invalid_argument if i is out of bounds.
val ndim : ('a, 'b) t -> int
ndim t is the number of dimensions of t.
val itemsize : ('a, 'b) t -> int
itemsize t is the number of bytes per element.
val size : ('a, 'b) t -> int
size t is the total number of elements.
val nbytes : ('a, 'b) t -> int
nbytes t is size t * itemsize t.
val offset : ('a, 'b) t -> int
offset t is the element offset of t in its underlying buffer.
val is_c_contiguous : ('a, 'b) t -> bool
is_c_contiguous t is true iff t's elements are laid out contiguously in row-major (C) order.
See also contiguous.
val to_bigarray :
('a, 'b) t ->
('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t
to_bigarray t is a contiguous bigarray with the same shape and data as t. Always copies.
Raises Invalid_argument if t's dtype is an extended type not supported by Bigarray.
See also of_bigarray.
val to_buffer : ('a, 'b) t -> ('a, 'b) Nx_buffer.t
to_buffer t is a flat, contiguous buffer of t's data.
Returns the underlying buffer directly when t is already contiguous, has zero offset, and its size equals the buffer length; copies otherwise.
val to_array : ('a, 'b) t -> 'a array
to_array t is a fresh OCaml array containing the elements of t in row-major order. Always copies.
# let t =
create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |]
in
to_array t
- : int32 array = [|1l; 2l; 3l; 4l|]
Creation
create dtype shape data is a tensor of the given dtype and shape initialised from data in row-major order.
Raises Invalid_argument if Array.length data does not equal the product of shape.
# create float32 [| 2; 3 |]
[| 1.; 2.; 3.; 4.; 5.; 6. |]
- : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
init dtype shape f is a tensor where the element at multi-index i is f i.
# init int32 [| 2; 3 |]
(fun i -> Int32.of_int (i.(0) + i.(1)))
- : (int32, int32_elt) t = [[0, 1, 2],
[1, 2, 3]]
empty dtype shape is an uninitialized tensor.
Warning. Elements contain arbitrary values until written.
full dtype shape v is a tensor filled with v.
# full float32 [| 2; 3 |] 3.14
- : (float, float32_elt) t = [[3.14, 3.14, 3.14],
[3.14, 3.14, 3.14]]
zeros dtype shape is a tensor filled with zeros.
scalar dtype v is a 0-dimensional tensor containing v. The result has shape [||].
full_like t v is a tensor with the same dtype and shape as t, filled with v.
eye ?m ?k dtype n is an n × m matrix with ones on the k-th diagonal and zeros elsewhere. m defaults to n. k defaults to 0 (main diagonal); positive k selects an upper diagonal, negative k a lower one.
# eye int32 3
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
# eye ~k:1 int32 3
- : (int32, int32_elt) t = [[0, 1, 0],
[0, 0, 1],
[0, 0, 0]]
diag ?k v extracts or constructs a diagonal.
When v is 1-D, returns a 2-D tensor with v on the k-th diagonal. When v is 2-D, returns the k-th diagonal as a 1-D tensor. k defaults to 0.
Raises Invalid_argument if v is not 1-D or 2-D.
# let v = create int32 [| 3 |] [| 1l; 2l; 3l |] in
diag v
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 2, 0],
[0, 0, 3]]
# let x =
arange int32 0 9 1 |> reshape [| 3; 3 |]
in
diag x
- : (int32, int32_elt) t = [0, 4, 8]
arange dtype start stop step is a 1-D tensor of values from start (inclusive) to stop (exclusive) in increments of step.
Raises Invalid_argument if step = 0.
# arange int32 0 10 2
- : (int32, int32_elt) t = [0, 2, 4, 6, 8]
# arange int32 5 0 (-1)
- : (int32, int32_elt) t = [5, 4, 3, 2, 1]
linspace dtype ?endpoint start stop n is n values evenly spaced from start to stop. endpoint defaults to true (include stop).
Raises Invalid_argument if n is negative.
# linspace float32 0. 10. 5
- : (float, float32_elt) t = [0, 2.5, 5, 7.5, 10]
# linspace float32 ~endpoint:false 0. 10. 5
- : (float, float32_elt) t = [0, 2, 4, 6, 8]
val logspace :
(float, 'a) dtype ->
?endpoint:bool ->
?base:float ->
float ->
float ->
int ->
(float, 'a) t
logspace dtype ?endpoint ?base start stop n is n values evenly spaced on a logarithmic scale: base^x where x ranges from start to stop. endpoint defaults to true. base defaults to 10.0.
Raises Invalid_argument if n is negative.
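The spacing rules for linspace and logspace can be sketched on plain float arrays. This assumes values are computed as start + i * step; Nx may order its floating-point operations differently, and linspace_sketch and logspace_sketch are hypothetical names.

```ocaml
(* Illustrative sketch, not Nx code: evenly spaced values on plain floats. *)
let linspace_sketch ?(endpoint = true) start stop n =
  if n < 0 then invalid_arg "linspace: n must be non-negative";
  (* with the endpoint, n points span n - 1 intervals; without it, n intervals *)
  let intervals = if endpoint then n - 1 else n in
  let step = if intervals > 0 then (stop -. start) /. float_of_int intervals else 0. in
  let values = Array.init n (fun i -> start +. (float_of_int i *. step)) in
  (* pin the last value to stop so rounding cannot leave it slightly off *)
  if endpoint && n > 1 then values.(n - 1) <- stop;
  values

(* logspace raises base to each exponent produced by linspace. *)
let logspace_sketch ?endpoint ?(base = 10.0) start stop n =
  Array.map (fun x -> base ** x) (linspace_sketch ?endpoint start stop n)
```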
# logspace float32 0. 2. 3
- : (float, float32_elt) t = [1, 10, 100]
# logspace float32 ~base:2.0 0. 3. 4
- : (float, float32_elt) t = [1, 2, 4, 8]
meshgrid ?indexing x y is a pair of 2-D coordinate grids built from 1-D arrays x and y. indexing defaults to `xy (Cartesian: X varies along columns, Y along rows). With `ij (matrix), X varies along rows, Y along columns.
Raises Invalid_argument if x or y is not 1-D.
# let x = linspace float32 0. 2. 3 in
let y = linspace float32 0. 1. 2 in
meshgrid x y
- : (float, float32_elt) t * (float, float32_elt) t =
([[0, 1, 2],
[0, 1, 2]], [[0, 0, 0],
[1, 1, 1]])
tril ?k x is the lower-triangular part of x with elements above the k-th diagonal set to zero. k defaults to 0.
Raises Invalid_argument if x has fewer than 2 dimensions.
See also triu.
triu ?k x is the upper-triangular part of x with elements below the k-th diagonal set to zero. k defaults to 0.
Raises Invalid_argument if x has fewer than 2 dimensions.
See also tril.
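A sketch of the masking rule shared by tril and triu, on plain int array array values rather than Nx tensors (tril_sketch and triu_sketch are hypothetical). It uses the convention that element (i, j) lies on diagonal j - i, consistent with the eye ~k:1 example earlier.

```ocaml
(* Illustrative sketch, not Nx code: element (i, j) lies on diagonal j - i,
   so k = 0 is the main diagonal and k > 0 lies above it. *)
let tril_sketch ?(k = 0) (m : int array array) : int array array =
  (* keep entries on or below the k-th diagonal *)
  Array.mapi (fun i row -> Array.mapi (fun j x -> if j - i <= k then x else 0) row) m

let triu_sketch ?(k = 0) (m : int array array) : int array array =
  (* keep entries on or above the k-th diagonal *)
  Array.mapi (fun i row -> Array.mapi (fun j x -> if j - i >= k then x else 0) row) m
```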
val of_bigarray :
('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t ->
('a, 'b) t
of_bigarray ba is a tensor sharing memory with ba.
Zero-copy: mutations through either are visible to both.
See also to_bigarray.
val of_buffer : ('a, 'b) Nx_buffer.t -> shape:int array -> ('a, 'b) t
of_buffer buf ~shape is a tensor viewing buf with the given shape. The product of shape must equal the buffer length.
one_hot ~num_classes indices is a one-hot encoded tensor.
Appends a new trailing dimension of size num_classes. Values in indices are expected to lie in [0, num_classes); an out-of-range value produces an all-zero row.
Raises Invalid_argument if indices is not an integer dtype or num_classes <= 0.
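The encoding rule, including the all-zero rows for out-of-range values, can be sketched with plain arrays. one_hot_sketch is a hypothetical helper and ignores dtype and shape handling.

```ocaml
(* Illustrative sketch, not Nx code: one-hot rows as plain int arrays.
   An index outside [0, num_classes) matches no column, so its row is all zeros. *)
let one_hot_sketch ~num_classes (indices : int array) : int array array =
  if num_classes <= 0 then invalid_arg "one_hot: num_classes must be positive";
  Array.map
    (fun idx -> Array.init num_classes (fun c -> if c = idx then 1 else 0))
    indices
```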
# let idx =
create int32 [| 3 |] [| 0l; 1l; 3l |]
in
one_hot ~num_classes:4 idx
- : (int, uint8_elt) t = [[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]]
Random number generation
Sampling functions use the implicit RNG state managed by Rng. Wrap calls in Rng.run for reproducibility:
Rng.run ~seed:42 (fun () -> rand float32 [| 3 |])
module Rng = Nx_core.Rng
Splittable RNG keys with implicit-state management.
rand dtype shape samples uniformly from [0, 1).
Raises Invalid_argument if dtype is not a float type.
randn dtype shape samples from the standard normal distribution (mean 0, variance 1) via the Box–Muller transform.
Raises Invalid_argument if dtype is not a float type.
randint dtype ?high shape low samples integers uniformly from [low, high). high defaults to 10.
Raises Invalid_argument if dtype is not an integer type or low >= high.
val bernoulli : p:float -> int array -> bool_t
bernoulli ~p shape samples booleans that are true with probability p.
Raises Invalid_argument if p is not in [0, 1].
val permutation : int -> int32_t
permutation n is a random permutation of [0, n-1].
Raises Invalid_argument if n <= 0.
shuffle t is a copy of t with the first axis randomly permuted. No-op on scalars.
categorical ?axis ?shape logits samples category indices from unnormalised log-probabilities using the Gumbel-max trick. axis defaults to -1 (last axis). shape prepends extra batch dimensions.
Raises Invalid_argument if logits is not a float type or axis is out of bounds.
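The Gumbel-max trick adds independent standard Gumbel noise to each logit and takes the argmax, which samples from the softmax distribution without normalising. The sketch below uses the standard library's Random module in place of Rng and handles a single row of logits; gumbel and sample_categorical are hypothetical names.

```ocaml
(* Illustrative sketch, not Nx code: stdlib Random stands in for Nx's Rng. *)

(* Standard Gumbel noise: -log(-log u) for u drawn uniformly from (0, 1). *)
let gumbel () =
  let u = Random.float 1.0 in
  (* clamp u away from 0 and 1 so both logarithms stay finite *)
  let u = if u <= 0. then Float.min_float else if u >= 1. then Float.pred 1. else u in
  -. log (-. log u)

(* Gumbel-max: argmax over logits.(i) + g_i picks index i with
   probability softmax(logits).(i). *)
let sample_categorical (logits : float array) : int =
  if Array.length logits = 0 then invalid_arg "categorical: empty logits";
  let best = ref 0 and best_v = ref Float.neg_infinity in
  Array.iteri
    (fun i l ->
      let v = l +. gumbel () in
      if v > !best_v then begin best := i; best_v := v end)
    logits;
  !best
```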
truncated_normal dtype ~lower ~upper shape samples from a standard normal distribution truncated to [lower, upper].
Raises Invalid_argument if dtype is not a float type or lower >= upper.
Shape manipulation
reshape shape t is a view of t with the given shape.
At most one dimension may be -1; it is inferred from the total number of elements. The product of shape must equal size t.
Raises Invalid_argument if shape is incompatible or contains more than one -1.
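How the -1 dimension could be resolved, as a plain-OCaml sketch. infer_shape is a hypothetical helper that only computes the target shape and does not model views.

```ocaml
(* Illustrative sketch, not Nx code: resolve a reshape target in which at
   most one dimension is -1, given the total number of elements. *)
let infer_shape ~size (shape : int array) : int array =
  let unknown = ref (-1) and known = ref 1 in
  Array.iteri
    (fun i d ->
      if d = -1 then begin
        if !unknown >= 0 then invalid_arg "reshape: more than one -1";
        unknown := i
      end
      else if d < 0 then invalid_arg "reshape: negative dimension"
      else known := !known * d)
    shape;
  let result = Array.copy shape in
  (* the -1 dimension absorbs whatever the known dimensions leave over *)
  if !unknown >= 0 then begin
    if !known = 0 || size mod !known <> 0 then invalid_arg "reshape: incompatible shape";
    result.(!unknown) <- size / !known
  end;
  if Array.fold_left ( * ) 1 result <> size then invalid_arg "reshape: incompatible shape";
  result
```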
# create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
|> reshape [| 2; 3 |]
- : (int32, int32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
|> reshape [| 3; -1 |]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
broadcast_to shape t is a view of t broadcast to shape.
Dimensions are aligned from the right; each dimension of t must be 1 or equal to the corresponding target dimension. Broadcast dimensions have zero byte-stride (no copy).
Raises Invalid_argument if the shapes are incompatible.
# create int32 [| 1; 3 |] [| 1l; 2l; 3l |]
|> broadcast_to [| 3; 3 |]
- : (int32, int32_elt) t = [[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]
See also broadcasted, expand.
broadcasted ?reverse t1 t2 is (t1', t2') where both are broadcast to their common shape. When reverse is true (default false), returns (t2', t1').
Raises Invalid_argument if the shapes are incompatible.
See also broadcast_to, broadcast_arrays.
expand shape t is like broadcast_to but -1 in shape preserves the corresponding dimension of t.
Raises Invalid_argument if any dimension in shape is negative (other than -1).
# ones float32 [| 1; 4; 1 |]
|> expand [| 3; -1; 5 |] |> shape
- : int array = [|3; 4; 5|]
See also broadcast_to.
flatten ?start_dim ?end_dim t collapses dimensions start_dim through end_dim (inclusive) into a single dimension. start_dim defaults to 0. end_dim defaults to -1 (last). Negative indices count from the end.
Raises Invalid_argument if indices are out of bounds.
# zeros float32 [| 2; 3; 4 |] |> flatten |> shape
- : int array = [|24|]
# zeros float32 [| 2; 3; 4; 5 |]
|> flatten ~start_dim:1 ~end_dim:2 |> shape
- : int array = [|2; 12; 5|]
unflatten dim sizes t expands dimension dim into multiple dimensions given by sizes. At most one element of sizes may be -1 (inferred). The product of sizes must equal the size of dimension dim.
Raises Invalid_argument if the product mismatches or dim is out of bounds.
# zeros float32 [| 2; 12; 5 |]
|> unflatten 1 [| 3; 4 |] |> shape
- : int array = [|2; 3; 4; 5|]
See also flatten.
ravel t is t reshaped to 1-D. Returns a view when possible.
Raises Invalid_argument if t cannot be flattened without copying; call contiguous first.
See also flatten, contiguous.
squeeze ?axes t removes dimensions of size 1. When axes is given, only those axes are removed. Negative indices count from the end.
Raises Invalid_argument if a specified axis does not have size 1.
# ones float32 [| 1; 3; 1; 4 |]
|> squeeze |> shape
- : int array = [|3; 4|]
# ones float32 [| 1; 3; 1; 4 |]
|> squeeze ~axes:[ 0 ] |> shape
- : int array = [|3; 1; 4|]
See also unsqueeze.
unsqueeze ?axes t inserts dimensions of size 1 at the positions listed in axes. Positions refer to the result tensor.
Raises Invalid_argument if axes is not specified, contains duplicates, or values are out of bounds.
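Because axes refer to positions in the result, the output rank is ndim t plus the number of axes. A sketch of that shape computation follows; unsqueeze_shape is hypothetical, and negative axes are not handled since the text above does not specify them.

```ocaml
(* Illustrative sketch, not Nx code: positions in axes refer to the result,
   whose rank is the input rank plus the number of inserted axes. *)
let unsqueeze_shape ~(axes : int list) (shape : int array) : int array =
  if axes = [] then invalid_arg "unsqueeze: no axes given";
  let out_ndim = Array.length shape + List.length axes in
  List.iter
    (fun a -> if a < 0 || a >= out_ndim then invalid_arg "unsqueeze: axis out of bounds")
    axes;
  if List.length (List.sort_uniq compare axes) <> List.length axes then
    invalid_arg "unsqueeze: duplicate axes";
  (* walk the result positions in order, taking input dimensions in order *)
  let next = ref 0 in
  Array.init out_ndim (fun i ->
      if List.mem i axes then 1
      else begin
        let d = shape.(!next) in
        incr next;
        d
      end)
```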
# create float32 [| 3 |] [| 1.; 2.; 3. |]
|> unsqueeze ~axes:[ 0; 2 ] |> shape
- : int array = [|1; 3; 1|]
See also squeeze, expand_dims.
squeeze_axis i t removes dimension i if its size is 1.
Raises Invalid_argument if dimension i is not 1.
See also squeeze.
unsqueeze_axis i t inserts a dimension of size 1 at position i.
See also unsqueeze.
transpose ?axes t permutes the dimensions of t.
axes must be a permutation of [0; …; ndim t - 1]. When omitted, reverses all dimensions. Returns a view (no copy).
Raises Invalid_argument if axes is not a valid permutation.
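Transposing is cheap because only metadata moves. This sketch uses a hypothetical view record, not Nx's internal representation, and permutes shape and strides while leaving the offset and buffer untouched.

```ocaml
(* Illustrative sketch, not Nx code: a hypothetical view record. *)
type view = { shape : int array; strides : int array; offset : int }

let transpose_view ?axes (v : view) : view =
  let n = Array.length v.shape in
  let axes =
    match axes with
    | Some l -> Array.of_list l
    | None -> Array.init n (fun i -> n - 1 - i) (* default: reverse all axes *)
  in
  (* axes must be a permutation of 0 .. n - 1 *)
  let sorted = Array.copy axes in
  Array.sort compare sorted;
  if sorted <> Array.init n Fun.id then invalid_arg "transpose: invalid permutation";
  { v with
    shape = Array.map (fun a -> v.shape.(a)) axes;
    strides = Array.map (fun a -> v.strides.(a)) axes }
```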
# create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
|> transpose
- : (int32, int32_elt) t = [[1, 4],
[2, 5],
[3, 6]]
See also matrix_transpose, moveaxis, swapaxes.
flip ?axes t reverses elements along the given axes. When omitted, flips all dimensions.
Raises Invalid_argument if any axis is out of bounds.
# create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
|> flip ~axes:[ 1 ]
- : (int32, int32_elt) t = [[3, 2, 1],
[6, 5, 4]]
roll ?axis shift t shifts elements along axis by shift positions, wrapping around. When axis is omitted, operates on the flattened tensor. Negative shift rolls backward.
Raises Invalid_argument if axis is out of bounds.
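The wrap-around rule for a single axis, sketched on a 1-D OCaml array (roll_sketch is hypothetical):

```ocaml
(* Illustrative sketch, not Nx code: roll on a 1-D array. The element at
   position i ends up at position (i + shift) mod n. *)
let roll_sketch shift (a : 'a array) : 'a array =
  let n = Array.length a in
  if n = 0 then [||]
  else begin
    (* normalise shift into [0, n) so negative shifts roll backward *)
    let s = ((shift mod n) + n) mod n in
    (* out.(i) = a.(i - s), wrapping around *)
    Array.init n (fun i -> a.((i - s + n) mod n))
  end
```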
# create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |]
|> roll 2
- : (int32, int32_elt) t = [4, 5, 1, 2, 3]
pad widths value t pads t with value. widths.(i) = (before, after) gives the number of elements added before and after dimension i.
Raises Invalid_argument if Array.length widths does not match ndim t or any width is negative.
# create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
|> pad [| (1, 1); (1, 1) |] 0. |> shape
- : int array = [|4; 4|]
See also shrink.
shrink ranges t extracts a slice in which ranges.(i) = (start, stop) selects the half-open range [start, stop) of dimension i. Returns a view.
# create int32 [| 3; 3 |]
[| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |]
|> shrink [| (1, 3); (0, 2) |]
- : (int32, int32_elt) t = [[4, 5],
[7, 8]]
See also pad.
tile reps t is t repeated according to reps. reps.(i) gives the repetition count along dimension i. If reps is longer than ndim t, dimensions are prepended.
Raises Invalid_argument if any repetition count is negative.
# create int32 [| 1; 2 |] [| 1l; 2l |]
|> tile [| 2; 3 |]
- : (int32, int32_elt) t = [[1, 2, 1, 2, 1, 2],
[1, 2, 1, 2, 1, 2]]
See also repeat.
repeat ?axis n t repeats each element n times along axis. When axis is omitted, operates on the flattened tensor.
Raises Invalid_argument if n is negative or axis is out of bounds.
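tile and repeat are easy to confuse. This 1-D sketch on plain arrays (tile_1d and repeat_1d are hypothetical) shows the difference: tile copies the whole sequence, while repeat duplicates each element in place, matching the examples for both functions.

```ocaml
(* Illustrative sketch, not Nx code. *)

(* tile: the whole array, reps times in a row *)
let tile_1d reps (a : 'a array) : 'a array =
  if reps < 0 then invalid_arg "tile: negative repetition count";
  Array.concat (List.init reps (fun _ -> a))

(* repeat: each element n times, in place *)
let repeat_1d n (a : 'a array) : 'a array =
  if n < 0 then invalid_arg "repeat: negative count";
  Array.concat (Array.to_list (Array.map (fun x -> Array.make n x) a))
```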
# create int32 [| 3 |] [| 1l; 2l; 3l |]
|> repeat 2
- : (int32, int32_elt) t = [1, 1, 2, 2, 3, 3]
See also tile.
Combining and splitting
concatenate ?axis ts joins tensors along an existing axis. All tensors must have the same shape except on the concatenation axis. When axis is omitted, every tensor is flattened first. Always copies.
Raises Invalid_argument if the list is empty or shapes are incompatible.
# let a = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let b = create int32 [| 1; 2 |] [| 5l; 6l |] in
concatenate ~axis:0 [ a; b ]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
stack ?axis ts joins tensors along a new axis. All tensors must have identical shape. axis defaults to 0. Negative values count from the end of the result shape.
Raises Invalid_argument if the list is empty, shapes differ, or axis is out of bounds.
# let a = create int32 [| 2 |] [| 1l; 2l |] in
let b = create int32 [| 2 |] [| 3l; 4l |] in
stack [ a; b ]
- : (int32, int32_elt) t = [[1, 2],
[3, 4]]
# let a = create int32 [| 2 |] [| 1l; 2l |] in
let b = create int32 [| 2 |] [| 3l; 4l |] in
stack ~axis:1 [ a; b ]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
See also concatenate.
vstack ts stacks vertically (along axis 0). 1-D tensors are treated as row vectors (shape [1; n]).
Raises Invalid_argument if shapes are incompatible.
# let a = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let b = create int32 [| 3 |] [| 4l; 5l; 6l |] in
vstack [ a; b ]
- : (int32, int32_elt) t = [[1, 2, 3],
[4, 5, 6]]
See also hstack, dstack, concatenate.
hstack ts stacks horizontally. 1-D tensors are concatenated directly; higher-D tensors concatenate along axis 1.
Raises Invalid_argument if shapes are incompatible.
# let a = create int32 [| 2; 1 |] [| 1l; 2l |] in
let b = create int32 [| 2; 1 |] [| 3l; 4l |] in
hstack [ a; b ]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
See also vstack, dstack, concatenate.
dstack ts stacks depth-wise (along axis 2). Tensors are reshaped to at least 3-D before concatenation: 1-D [n] → [1; n; 1], 2-D [m; n] → [m; n; 1].
Raises Invalid_argument if the resulting shapes are incompatible.
See also vstack, hstack, concatenate.
broadcast_arrays ts broadcasts every tensor to their common shape. Returns views (no copies).
Raises Invalid_argument if shapes are incompatible.
See also broadcast_to, broadcasted.
val array_split :
axis:int ->
[< `Count of int | `Indices of int list ] ->
('a, 'b) t ->
('a, 'b) t list
array_split ~axis spec t splits t into sub-tensors.
With `Count n, divides the axis as evenly as possible: when its length is not divisible by n, the first sections each receive one extra element. With `Indices [i0; i1; …], splits at the given indices, producing [0, i0), [i0, i1), …, [ik, end).
Raises Invalid_argument if axis is out of bounds or spec is invalid.
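One reading of the `Count rule, consistent with the example that follows, is that every section gets len / n elements and the first len mod n sections get one more. A sketch under that assumption (count_sections is hypothetical) returning half-open bounds:

```ocaml
(* Illustrative sketch, not Nx code: half-open (start, stop) bounds of the
   sections produced by `Count n on an axis of length len. *)
let count_sections ~len n : (int * int) list =
  if n <= 0 then invalid_arg "array_split: count must be positive";
  let base = len / n and extra = len mod n in
  let rec go i start acc =
    if i = n then List.rev acc
    else
      (* the first [extra] sections are one element longer *)
      let size = base + (if i < extra then 1 else 0) in
      go (i + 1) (start + size) ((start, start + size) :: acc)
  in
  go 0 0 []
```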
# create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |]
|> array_split ~axis:0 (`Count 3)
- : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5]]
See also split.
split ~axis n t splits t into n equal parts along axis.
Raises Invalid_argument if the axis size is not divisible by n.
See also array_split.
Type conversion and copying
cast dtype t is a copy of t with elements converted to dtype.
# create float32 [| 3 |] [| 1.5; 2.7; 3.1 |]
|> cast int32
- : (int32, int32_elt) t = [1, 2, 3]
See also contiguous, copy.
contiguous t is t if it is already C-contiguous, or a fresh contiguous copy otherwise.
See also is_c_contiguous, copy.
copy t is a deep copy of t. Always allocates new memory; the result is contiguous.
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
let y = copy x in
set_item [ 0 ] 999. y;
x, y
- : (float, float32_elt) t * (float, float32_elt) t =
([1, 2, 3], [999, 2, 3])
See also contiguous.
blit src dst copies the elements of src into dst in-place. Shapes must match exactly.
Raises Invalid_argument if shapes differ.
fill v t is a fresh copy of t with every element set to v. Does not mutate t.
Indexing and slicing
get indices t is the sub-tensor at indices, indexing from the outermost dimension inward. Returns a scalar tensor when all dimensions are indexed; otherwise a view of the remaining dimensions. Negative indices count from the end.
Raises Invalid_argument if any index is out of bounds.
# let x =
create int32 [| 2; 3 |]
[| 1l; 2l; 3l; 4l; 5l; 6l |]
in
get [ 1 ] x
- : (int32, int32_elt) t = [4, 5, 6]
set indices t v writes v at the position given by indices.
Raises Invalid_argument if indices are out of bounds.
slice specs t extracts a sub-tensor using advanced indexing.
Each element of specs addresses one axis from left to right:
- I i: single index (reduces the dimension; negative indices count from the end).
- L [i0; i1; …]: gather the listed indices.
- R (start, stop): half-open range [start, stop).
- Rs (start, stop, step): strided range.
- A: full axis (the default for trailing axes).
- M mask: boolean mask selecting positions where mask is true.
- N: insert a new axis of size 1.
Returns a view when possible.
Raises Invalid_argument if specs are out of bounds, if step is zero, or if a mask spec is used (not yet supported).
# let x =
create int32 [| 3; 3 |]
[| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |]
in
slice [ R (0, 2); L [ 0; 2 ] ] x
- : (int32, int32_elt) t = [[1, 3],
[4, 6]]
set_slice specs t v writes v into the region of t selected by specs. v is broadcast if needed.
Raises Invalid_argument if N (new-axis) specs are used (not supported for writes).
See also slice.
val item : int list -> ('a, 'b) t -> 'a
item indices t is the element of t at indices, returned as an OCaml value.
val set_item : int list -> 'a -> ('a, 'b) t -> unit
set_item indices v t sets the element at indices to v in-place. Indices must cover all dimensions.
Raises Invalid_argument if the number of indices is wrong or any index is out of bounds.
See also item.
val take :
?axis:int ->
?mode:[ `raise | `wrap | `clip ] ->
(int32, int32_elt) t ->
('a, 'b) t ->
('a, 'b) t
take ?axis ?mode indices t gathers elements from t at indices along axis. When axis is omitted, t is flattened first. mode controls out-of-bounds indices: `raise (default) raises, `wrap uses modular indexing, `clip clamps to bounds.
Raises Invalid_argument if mode is `raise and any index is out of bounds.
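How each mode could map an index i onto an axis of length n, as a hypothetical helper. Treating negative indices as out of bounds under `raise is an assumption of this sketch; the text above does not say whether take accepts negative indices.

```ocaml
(* Illustrative sketch, not Nx code. Assumption: under `raise, negative
   indices are rejected rather than counted from the end. *)
let resolve_index ~mode ~n i =
  match mode with
  | `raise -> if i < 0 || i >= n then invalid_arg "take: index out of bounds" else i
  | `wrap -> ((i mod n) + n) mod n (* modular, always in [0, n) *)
  | `clip -> max 0 (min (n - 1) i) (* clamp to the nearest valid index *)
```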
# let x =
create int32 [| 5 |]
[| 0l; 1l; 2l; 3l; 4l |]
in
take
(create int32 [| 3 |] [| 1l; 3l; 0l |])
x
- : (int32, int32_elt) t = [1, 3, 0]
See also put, take_along_axis.
take_along_axis ~axis indices t gathers values from t along axis using indices. indices must match t's shape except along axis. Useful for gathering from argmax/argmin results.
Raises Invalid_argument if shapes are incompatible.
# let x =
create float32 [| 2; 3 |]
[| 4.; 1.; 2.; 3.; 5.; 6. |]
in
let idx =
create int32 [| 2; 1 |] [| 1l; 0l |]
in
take_along_axis ~axis:1 idx x
- : (float, float32_elt) t = [[1],
[3]]
See also take, put_along_axis.
val put :
?axis:int ->
indices:(int32, int32_elt) t ->
values:('a, 'b) t ->
?mode:[ `raise | `wrap | `clip ] ->
('a, 'b) t ->
unit
put ?axis ~indices ~values ?mode t writes values into t at positions given by indices. When axis is omitted, t is flattened first. mode defaults to `raise. Modifies t in-place.
Raises Invalid_argument if mode is `raise and any index is out of bounds.
See also take, put_along_axis, index_put.
val index_put :
indices:(int32, int32_elt) t array ->
values:('a, 'b) t ->
?mode:[ `raise | `wrap | `clip ] ->
('a, 'b) t ->
unit
index_put ~indices ~values ?mode t writes values into t at the coordinates given by indices.
indices contains one index tensor per axis of t; they are broadcast to a common shape that determines the number of updates. values is broadcast to the same shape. Duplicate coordinates overwrite. mode defaults to `raise.
Raises Invalid_argument if the number of index tensors does not match ndim t.
# let t = zeros float32 [| 3; 3 |] in
let rows =
create int32 [| 3 |] [| 0l; 2l; 1l |]
in
let cols =
create int32 [| 3 |] [| 1l; 0l; 2l |]
in
index_put ~indices:[| rows; cols |]
~values:(create float32 [| 3 |]
[| 10.; 20.; 30. |])
t;
t
- : (float, float32_elt) t = [[0, 10, 0],
[0, 0, 30],
[20, 0, 0]]
See also put.
val put_along_axis :
axis:int ->
indices:(int32, int32_elt) t ->
values:('a, 'b) t ->
('a, 'b) t ->
unit
put_along_axis ~axis ~indices ~values t writes values into t at positions selected by indices along axis. Modifies t in-place.
Raises Invalid_argument if shapes are incompatible.
See also take_along_axis, put.
compress ?axis ~condition t selects elements where condition is true along axis. condition must be 1-D. When axis is omitted, t is flattened first.
Raises Invalid_argument if the condition length is incompatible.
# let x =
create int32 [| 5 |]
[| 1l; 2l; 3l; 4l; 5l |]
in
compress
~condition:(create bool [| 5 |]
[| true; false; true; false; true |])
x
- : (int32, int32_elt) t = [1, 3, 5]
nonzero t is an array of 1-D index tensors, one per dimension, giving the coordinates of non-zero elements.
# let x =
create int32 [| 3; 3 |]
[| 0l; 1l; 0l;
2l; 0l; 3l;
0l; 0l; 4l |]
in
let idx = nonzero x in
idx.(0), idx.(1)
- : (int32, int32_elt) t * (int32, int32_elt) t =
([0, 1, 1, 2], [1, 0, 2, 2])
See also argwhere.
argwhere t is a 2-D tensor of shape [k; ndim t] whose rows are the coordinates of the k non-zero elements.
See also nonzero.
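The relationship between the two layouts, sketched for a 2-D int array (coords_2d, nonzero_2d and argwhere_2d are hypothetical): both list the same coordinates in row-major order, nonzero grouped by dimension and argwhere grouped by element.

```ocaml
(* Illustrative sketch, not Nx code, for a 2-D int array. *)

(* Coordinates of non-zero elements, in row-major order. *)
let coords_2d (m : int array array) : (int * int) list =
  let acc = ref [] in
  Array.iteri
    (fun i row -> Array.iteri (fun j x -> if x <> 0 then acc := (i, j) :: !acc) row)
    m;
  List.rev !acc

(* nonzero layout: one coordinate array per dimension. *)
let nonzero_2d m =
  let c = coords_2d m in
  (Array.of_list (List.map fst c), Array.of_list (List.map snd c))

(* argwhere layout: one [| row; col |] pair per non-zero element. *)
let argwhere_2d m = Array.of_list (List.map (fun (i, j) -> [| i; j |]) (coords_2d m))
```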
Arithmetic
Element-wise arithmetic with broadcasting. Each operation op has variants:
- op_s t s: tensor-scalar.
- rop_s s t: scalar-tensor (reversed operands).
add ?out a b is the element-wise sum of a and b. out defaults to a fresh allocation.
add_s ?out t s adds scalar s to each element of t. out defaults to a fresh allocation.
sub ?out a b is the element-wise difference a - b. out defaults to a fresh allocation.
sub_s ?out t s subtracts scalar s from each element. out defaults to a fresh allocation.
rsub_s ?out s t is s - t element-wise. out defaults to a fresh allocation.
mul ?out a b is the element-wise product of a and b. out defaults to a fresh allocation.
mul_s ?out t s multiplies each element by scalar s. out defaults to a fresh allocation.
div ?out a b is the element-wise quotient a / b. out defaults to a fresh allocation.
Float dtypes use true division. Integer dtypes truncate toward zero.
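OCaml's own Int32.div also truncates toward zero, so it reproduces the example that follows (-7 / 2 = -3). floor_div is a hypothetical helper included only to show the floor-division alternative, which div does not use.

```ocaml
(* Stdlib Int32.div truncates toward zero, like div on integer dtypes. *)
let trunc_div a b = Int32.div a b

(* For contrast only (not what div computes): floor division. *)
let floor_div a b =
  let q = Int32.div a b in
  (* step down when the division is inexact and the operands' signs differ *)
  if Int32.rem a b <> 0l && (Int32.compare a 0l < 0) <> (Int32.compare b 0l < 0)
  then Int32.pred q
  else q
```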
# let x = create int32 [| 2 |] [| -7l; 8l |] in
let y = create int32 [| 2 |] [| 2l; 2l |] in
div x y
- : (int32, int32_elt) t = [-3, 4]
div_s ?out t s divides each element by scalar s. out defaults to a fresh allocation.
rdiv_s ?out s t is s / t element-wise. out defaults to a fresh allocation.
pow ?out base exp is base raised to exp element-wise. out defaults to a fresh allocation.
pow_s ?out t s raises each element to scalar power s. out defaults to a fresh allocation.
rpow_s ?out s t is s^t element-wise. out defaults to a fresh allocation.
mod_ ?out a b is the element-wise remainder of a / b. out defaults to a fresh allocation.
mod_s ?out t s is the remainder of each element divided by scalar s. out defaults to a fresh allocation.
rmod_s ?out s t is s mod t element-wise. out defaults to a fresh allocation.
neg ?out t is the element-wise negation of t. out defaults to a fresh allocation.
conjugate t is the complex conjugate of t. For complex dtypes, negates the imaginary part. For real dtypes, returns t unchanged.
Mathematical functions
Basic
abs ?out t is the element-wise absolute value. out defaults to a fresh allocation.
sign ?out t is -1, 0, or 1 according to the sign of each element. For unsigned types, returns 1 for non-zero, 0 for zero. out defaults to a fresh allocation.
# create float32 [| 3 |] [| -2.; 0.; 3.5 |]
|> sign
- : (float, float32_elt) t = [-1, 0, 1]
square ?out t is the element-wise square. out defaults to a fresh allocation.
sqrt ?out t is the element-wise square root. out defaults to a fresh allocation.
rsqrt ?out t is the element-wise reciprocal square root (1 / sqrt t). out defaults to a fresh allocation.
recip ?out t is the element-wise reciprocal (1 / t). out defaults to a fresh allocation.
Exponential and logarithmic
log ?out t is the element-wise natural logarithm. out defaults to a fresh allocation.
log2 ?out t is the element-wise base-2 logarithm. out defaults to a fresh allocation.
exp ?out t is the element-wise exponential. out defaults to a fresh allocation.
exp2 ?out t is 2^t element-wise. out defaults to a fresh allocation.
Trigonometric
sin ?out t is the element-wise sine. out defaults to a fresh allocation.
cos ?out t is the element-wise cosine. out defaults to a fresh allocation.
tan ?out t is the element-wise tangent. out defaults to a fresh allocation.
asin ?out t is the element-wise arcsine. out defaults to a fresh allocation.
acos ?out t is the element-wise arccosine. out defaults to a fresh allocation.
atan ?out t is the element-wise arctangent. out defaults to a fresh allocation.
atan2 ?out y x is the element-wise two-argument arctangent, returning angles in [-π, π]. out defaults to a fresh allocation.
Hyperbolic
sinh ?out t is the element-wise hyperbolic sine. out defaults to a fresh allocation.
cosh ?out t is the element-wise hyperbolic cosine. out defaults to a fresh allocation.
tanh ?out t is the element-wise hyperbolic tangent. out defaults to a fresh allocation.
asinh ?out t is the element-wise inverse hyperbolic sine. out defaults to a fresh allocation.
acosh ?out t is the element-wise inverse hyperbolic cosine. out defaults to a fresh allocation.
atanh ?out t is the element-wise inverse hyperbolic tangent. out defaults to a fresh allocation.
Rounding
trunc ?out t rounds each element toward zero. out defaults to a fresh allocation.
ceil ?out t rounds each element toward positive infinity. out defaults to a fresh allocation.
floor ?out t rounds each element toward negative infinity. out defaults to a fresh allocation.
round ?out t rounds each element to the nearest integer. Ties round away from zero (not banker's rounding). out defaults to a fresh allocation.
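OCaml's Stdlib Float.round uses the same tie-breaking rule, so it can serve as a reference for the example that follows. round_half_even is a hypothetical helper showing the banker's rounding that round does not use.

```ocaml
(* Stdlib Float.round (not Nx) rounds ties away from zero. *)
let away_from_zero = Array.map Float.round [| 2.5; 3.5; -2.5; -3.5 |]

(* For contrast only: round half to even, where exact ties go to the
   nearest even integer. *)
let round_half_even x =
  if Float.abs (x -. Float.trunc x) = 0.5 then 2. *. Float.round (x /. 2.)
  else Float.round x
```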
# create float32 [| 4 |] [| 2.5; 3.5; -2.5; -3.5 |]
|> round
- : (float, float32_elt) t = [3, 4, -3, -4]
Other
hypot ?out x y is sqrt(x² + y²) computed without intermediate overflow. out defaults to a fresh allocation.
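One standard way to avoid intermediate overflow is to scale by the larger magnitude before squaring. This sketch illustrates that technique; it is not necessarily how Nx computes hypot, and hypot_sketch is a hypothetical name.

```ocaml
(* Illustrative sketch: dividing by the larger magnitude keeps the squared
   ratio at most 1, so it cannot overflow even when x *. x would. *)
let hypot_sketch x y =
  let ax = Float.abs x and ay = Float.abs y in
  let big = Float.max ax ay and small = Float.min ax ay in
  if big = 0. then 0.
  else if big = Float.infinity then Float.infinity
  else begin
    let r = small /. big in
    big *. Float.sqrt (1. +. (r *. r))
  end
```

With x = y = 1e200, the naive sqrt (x *. x +. y *. y) overflows to infinity, while the scaled form stays finite.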
# hypot (scalar float32 3.) (scalar float32 4.)
|> item []
- : float = 5.
lerp ?out a b w is the linear interpolation a + w * (b - a). w is typically in [0, 1]. out defaults to a fresh allocation.
# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2 |] [| 5.; 8. |] in
lerp a b (scalar float32 0.25)
- : (float, float32_elt) t = [2, 3.5]
lerp_scalar_weight ?out a b w is like lerp with a scalar weight. out defaults to a fresh allocation.
isinf ?out t is true where t is positive or negative infinity, false elsewhere. Non-float dtypes always return all false. out defaults to a fresh allocation.
# create float32 [| 4 |]
[| 1.; Float.infinity;
Float.neg_infinity; Float.nan |]
|> isinf
- : (bool, bool_elt) t = [false, true, true, false]
Comparison and logic
cmplt ?out a b is true where a < b, false elsewhere. out defaults to a fresh allocation.
less a b is cmplt.
less_s ?out t s is true where t < s. out defaults to a fresh allocation.
cmpne ?out a b is true where a ≠ b, false elsewhere. out defaults to a fresh allocation.
not_equal a b is cmpne.
not_equal_s ?out t s is true where t ≠ s. out defaults to a fresh allocation.
cmpeq ?out a b is true where a = b, false elsewhere. out defaults to a fresh allocation.
equal a b is cmpeq.
equal_s ?out t s is true where t = s. out defaults to a fresh allocation.
cmpgt ?out a b is true where a > b, false elsewhere. out defaults to a fresh allocation.
greater a b is cmpgt.
greater_s ?out t s is true where t > s. out defaults to a fresh allocation.
cmple ?out a b is true where a ≤ b, false elsewhere. out defaults to a fresh allocation.
less_equal a b is cmple.
less_equal_s ?out t s is true where t ≤ s. out defaults to a fresh allocation.
cmpge ?out a b is true where a ≥ b, false elsewhere. out defaults to a fresh allocation.
greater_equal a b is cmpge.
greater_equal_s ?out t s is true where t ≥ s. out defaults to a fresh allocation.
array_equal a b is a scalar boolean tensor that is true iff a and b have the same shape and all their elements are equal; it is false when the shapes differ.
# let a = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let b = create int32 [| 3 |] [| 1l; 2l; 3l |] in
array_equal a b |> item []
- : bool = true
maximum ?out a b is the element-wise maximum of a and b. out defaults to a fresh allocation.
maximum_s ?out t s is the element-wise maximum of t and scalar s. out defaults to a fresh allocation.
rmaximum_s ?out s t is maximum_s ?out t s.
minimum ?out a b is the element-wise minimum of a and b. out defaults to a fresh allocation.
minimum_s ?out t s is the element-wise minimum of t and scalar s. out defaults to a fresh allocation.
rminimum_s ?out s t is minimum_s ?out t s.
logical_and ?out a b is the element-wise logical AND. Non-zero is true. out defaults to a fresh allocation.
logical_or ?out a b is the element-wise logical OR. out defaults to a fresh allocation.
logical_xor ?out a b is the element-wise logical XOR. out defaults to a fresh allocation.
logical_not ?out t is the element-wise logical NOT: non-zero becomes 0, zero becomes 1. out defaults to a fresh allocation.
where ?out cond if_true if_false selects elements from if_true where cond is true and from if_false elsewhere. All three inputs broadcast to a common shape. out defaults to a fresh allocation.
# let x =
create float32 [| 4 |] [| -1.; 2.; -3.; 4. |]
in
where
(cmpgt x (scalar float32 0.))
x (scalar float32 0.)
- : (float, float32_elt) t = [0, 2, 0, 4]
clamp ?out ?min ?max t clamps elements to [min, max]. Either bound may be omitted. out defaults to a fresh allocation.
See also clip.
clip ?out ?min ?max t is clamp.
Bitwise operations
bitwise_xor ?out a b is the element-wise bitwise XOR. out defaults to a fresh allocation.
bitwise_or ?out a b is the element-wise bitwise OR. out defaults to a fresh allocation.
bitwise_and ?out a b is the element-wise bitwise AND. out defaults to a fresh allocation.
bitwise_not ?out t is the element-wise bitwise NOT. out defaults to a fresh allocation.
invert ?out t is bitwise_not.
lshift ?out t n left-shifts each element by n bits. out defaults to a fresh allocation.
Raises Invalid_argument if n is negative or the dtype is not an integer type.
# create int32 [| 3 |] [| 1l; 2l; 3l |]
|> Fun.flip lshift 2
- : (int32, int32_elt) t = [4, 8, 12]
See also rshift.
rshift ?out t n right-shifts each element by n bits. out defaults to a fresh allocation.
Raises Invalid_argument if n is negative or the dtype is not an integer type.
See also lshift.
Infix operators
module Infix : sig ... end
Reductions
sum ?out ?axes ?keepdims t sums elements along axes. When axes is omitted, reduces all axes (returns a scalar). When keepdims is true, reduced axes are kept with size 1. keepdims defaults to false. Negative axes count from the end. out defaults to a fresh allocation.
# create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
|> sum |> item []
- : float = 10.
# create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
|> sum ~axes:[ 0 ]
- : (float, float32_elt) t = [4, 6]
# create float32 [| 1; 2 |] [| 1.; 2. |]
|> sum ~axes:[ 1 ] ~keepdims:true
- : (float, float32_elt) t = [[3]]
max ?out ?axes ?keepdims t is the maximum along axes. NaN propagates. keepdims defaults to false. out defaults to a fresh allocation.
# create float32 [| 2; 3 |]
[| 1.; 2.; 3.; 4.; 5.; 6. |]
|> max |> item []
- : float = 6.
min ?out ?axes ?keepdims t is the minimum along axes. NaN propagates. keepdims defaults to false. out defaults to a fresh allocation.
prod ?out ?axes ?keepdims t is the product along axes. keepdims defaults to false. out defaults to a fresh allocation.
# create int32 [| 3 |] [| 2l; 3l; 4l |]
|> prod |> item []
- : int32 = 24l
cumsum ?axis t is the inclusive cumulative sum along axis. When axis is omitted, operates on the flattened tensor.
See also cumprod.
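To make the "inclusive" scan concrete, here is a stdlib-only OCaml sketch of the flattened case (axis omitted). It works on plain float arrays rather than Nx tensors, and the helper names are hypothetical, not part of Nx.

```ocaml
(* Hypothetical helper: inclusive scan over a flat array,
   out.(i) = a.(0) op a.(1) op ... op a.(i). *)
let inclusive_scan (op : float -> float -> float) (a : float array) : float array =
  let n = Array.length a in
  let out = Array.make n 0. in
  for i = 0 to n - 1 do
    (* The first output is a.(0) itself; each later output extends the running result. *)
    out.(i) <- (if i = 0 then a.(0) else op out.(i - 1) a.(i))
  done;
  out

(* Flattened analogues of cumsum, cumprod and cummax. Float.max returns nan
   if either argument is nan, so NaN propagates through cummax_flat. *)
let cumsum_flat = inclusive_scan ( +. )
let cumprod_flat = inclusive_scan ( *. )
let cummax_flat = inclusive_scan Float.max
```

For example, cumsum_flat [| 1.; 2.; 3.; 4. |] is [| 1.; 3.; 6.; 10. |].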
cumprod ?axis t is the inclusive cumulative product along axis. When axis is omitted, operates on the flattened tensor.
See also cumsum.
cummax ?axis t is the inclusive cumulative maximum along axis. NaN propagates for floating-point dtypes. When axis is omitted, operates on the flattened tensor.
See also cummin.
cummin ?axis t is the inclusive cumulative minimum along axis. NaN propagates for floating-point dtypes. When axis is omitted, operates on the flattened tensor.
See also cummax.
mean ?out ?axes ?keepdims t is the arithmetic mean along axes. NaN propagates. keepdims defaults to false. out defaults to a fresh allocation.
# create float32 [| 4 |] [| 1.; 2.; 3.; 4. |]
|> mean |> item []
- : float = 2.5
val var :
?out:('a, 'b) t ->
?axes:int list ->
?keepdims:bool ->
?ddof:int ->
('a, 'b) t ->
('a, 'b) t
var ?out ?axes ?keepdims ?ddof t is the variance along axes. ddof (delta degrees of freedom) defaults to 0 (population variance); use 1 for sample variance. Computed as Σ(X - mean)² / (N - ddof), where N is the number of elements reduced over. keepdims defaults to false. out defaults to a fresh allocation.
Raises Invalid_argument if ddof >= N.
# create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |]
|> var |> item []
- : float = 2.
# create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |]
|> var ~ddof:1 |> item []
- : float = 2.5
See also std.
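The ddof convention can be checked with a small stdlib-only sketch over a flat float array. `variance` and `std` here are hypothetical helpers that mirror the formula above; they are not the Nx functions.

```ocaml
(* Hypothetical helper: variance with delta degrees of freedom,
   sum of squared deviations from the mean divided by (n - ddof). *)
let variance ?(ddof = 0) (a : float array) : float =
  let n = Array.length a in
  if ddof >= n then invalid_arg "variance: ddof >= n";
  let mean = Array.fold_left ( +. ) 0. a /. float_of_int n in
  let sum_sq =
    Array.fold_left (fun acc x -> let d = x -. mean in acc +. (d *. d)) 0. a
  in
  sum_sq /. float_of_int (n - ddof)

(* Hypothetical helper: standard deviation with the same ddof. *)
let std ?(ddof = 0) (a : float array) : float = sqrt (variance ~ddof a)
```

With the data from the examples, variance [| 1.; 2.; 3.; 4.; 5. |] is 2. and variance ~ddof:1 on the same array is 2.5.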
val std :
?out:('a, 'b) t ->
?axes:int list ->
?keepdims:bool ->
?ddof:int ->
('a, 'b) t ->
('a, 'b) t
std ?out ?axes ?keepdims ?ddof t is the standard deviation: sqrt (var ~ddof t). ddof defaults to 0. keepdims defaults to false. out defaults to a fresh allocation.
See also var.
val all :
?out:(bool, bool_elt) t ->
?axes:int list ->
?keepdims:bool ->
('a, 'b) t ->
(bool, bool_elt) t
all ?out ?axes ?keepdims t is true iff every element along axes is non-zero. keepdims defaults to false. out defaults to a fresh allocation.
# create int32 [| 3 |] [| 1l; 2l; 3l |]
|> all |> item []
- : bool = true
# create int32 [| 3 |] [| 1l; 0l; 3l |]
|> all |> item []
- : bool = false
See also any.
val any :
?out:(bool, bool_elt) t ->
?axes:int list ->
?keepdims:bool ->
('a, 'b) t ->
(bool, bool_elt) t
any ?out ?axes ?keepdims t is true iff at least one element along axes is non-zero. keepdims defaults to false. out defaults to a fresh allocation.
See also all.
argmax ?axis ?keepdims t is the index of the maximum along axis. Returns the first occurrence for ties. When axis is omitted, operates on the flattened tensor. keepdims defaults to false.
Raises Invalid_argument if axis is out of bounds.
# create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |]
|> argmax |> item []
- : int32 = 4l
See also argmin.
argmin ?axis ?keepdims t is the index of the minimum along axis. Returns the first occurrence for ties. When axis is omitted, operates on the flattened tensor. keepdims defaults to false.
Raises Invalid_argument if axis is out of bounds.
See also argmax.
Sorting and searching
sort ?descending ?axis t sorts elements along axis and returns (sorted, indices) where indices maps sorted positions back to originals. descending defaults to false. axis defaults to -1 (last).
The sort is stable (equal elements preserve their relative order). NaN sorts to the end in ascending order and to the beginning in descending order.
Raises Invalid_argument if axis is out of bounds.
# create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |]
|> sort
- : (int32, int32_elt) t * (int32, int32_elt) t =
([1, 1, 3, 4, 5], [1, 3, 0, 2, 4])
See also argsort.
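A minimal stdlib-only sketch of the stable (sorted, indices) contract on a flat array. It relies on Array.stable_sort being a stable merge sort; `sort_with_indices` and `argsort` here are hypothetical helpers, and the documented NaN placement is not modelled.

```ocaml
(* Hypothetical helper: stable sort returning (sorted, indices), where
   indices.(k) is the original position of sorted.(k). Stdlib.compare
   orders nan before every other float, which differs from Nx's
   documented NaN placement, so use this only on NaN-free data. *)
let sort_with_indices ?(descending = false) (a : 'a array) : 'a array * int array =
  let pairs = Array.mapi (fun i x -> (x, i)) a in
  let cmp (x, _) (y, _) = if descending then compare y x else compare x y in
  (* Equal keys compare as 0, so the stable sort keeps their input order. *)
  Array.stable_sort cmp pairs;
  (Array.map fst pairs, Array.map snd pairs)

(* Hypothetical helper: only the permutation, as argsort is described. *)
let argsort ?descending a = snd (sort_with_indices ?descending a)
```

On [| 3; 1; 4; 1; 5 |] it returns ([| 1; 1; 3; 4; 5 |], [| 1; 3; 0; 2; 4 |]), the same pairing as the sort example.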
argsort ?descending ?axis t is snd (sort ?descending ?axis t).
See also sort.
Linear algebra
Products
dot ?out a b is the generalised dot product. out defaults to a fresh allocation.
Contracts the last axis of a with:
- the only axis of b when b is 1-D,
- the second-to-last axis of b otherwise.
Dimension rules:
- 1-D × 1-D → scalar (inner product).
- 2-D × 2-D → matrix multiplication.
- N-D × M-D → contraction; output axes are the non-contracted axes of a followed by those of b.
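These rules can be written as a small function on int array shapes. This sketch only computes the output shape under the rules stated here (no data and no dtype checks), and `dot_shape` is a hypothetical name.

```ocaml
(* Hypothetical helper: output shape of dot a b under the rules above. *)
let dot_shape (a : int array) (b : int array) : int array =
  let na = Array.length a and nb = Array.length b in
  if na = 0 || nb = 0 then invalid_arg "dot_shape: 0-D input";
  (* Axis of b that is contracted with the last axis of a. *)
  let kb = if nb = 1 then 0 else nb - 2 in
  if a.(na - 1) <> b.(kb) then invalid_arg "dot_shape: contraction size mismatch";
  (* Remaining axes of a, in order, followed by the remaining axes of b. *)
  let a_rest = Array.sub a 0 (na - 1) in
  let b_rest = List.filteri (fun i _ -> i <> kb) (Array.to_list b) in
  Array.append a_rest (Array.of_list b_rest)
```

For example, dot_shape [| 3; 4; 5 |] [| 5; 6 |] is [| 3; 4; 6 |], and two 1-D inputs give [||], the shape of a scalar.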
Note. Unlike matmul, dot does not broadcast batch dimensions; it concatenates them.
Raises Invalid_argument if contraction axes differ in size or either input is 0-D.
# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2 |] [| 3.; 4. |] in
dot a b |> item []
- : float = 11.
# dot (ones float32 [| 3; 4; 5 |])
(ones float32 [| 5; 6 |]) |> shape
- : int array = [|3; 4; 6|]
matmul ?out a b is the matrix product of a and b with batch broadcasting. out defaults to a fresh allocation; ignored when either input is 1-D.
Dimension rules:
- 1-D × 1-D → scalar (inner product).
- 1-D × N-D → a is treated as a row vector.
- N-D × 1-D → b is treated as a column vector.
- N-D × M-D → matrix multiply on last two axes; leading axes are broadcast.
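A shape-only sketch of these rules, including NumPy-style right-aligned broadcasting of the leading axes. It assumes the axis added when promoting a 1-D operand is dropped again from the result, which matches the examples below ([2; 2] × [2] gives [2]); `broadcast` and `matmul_shape` are hypothetical helpers.

```ocaml
(* Hypothetical helper: NumPy-style broadcast of two batch shapes,
   aligned from the right and padded on the left with 1s. *)
let broadcast (a : int array) (b : int array) : int array =
  let na = Array.length a and nb = Array.length b in
  let n = max na nb in
  Array.init n (fun i ->
      let da = if i < n - na then 1 else a.(i - (n - na)) in
      let db = if i < n - nb then 1 else b.(i - (n - nb)) in
      if da = db || db = 1 then da
      else if da = 1 then db
      else invalid_arg "broadcast: incompatible batch dimensions")

(* Hypothetical helper: output shape of matmul a b under the rules above. *)
let matmul_shape (a : int array) (b : int array) : int array =
  let na = Array.length a and nb = Array.length b in
  if na = 0 || nb = 0 then invalid_arg "matmul_shape: 0-D input";
  (* Promote 1-D operands: a to a row vector [1; k], b to a column vector [k; 1]. *)
  let a2 = if na = 1 then [| 1; a.(0) |] else a in
  let b2 = if nb = 1 then [| b.(0); 1 |] else b in
  let la = Array.length a2 and lb = Array.length b2 in
  let m = a2.(la - 2) and k = a2.(la - 1) in
  let kb = b2.(lb - 2) and n = b2.(lb - 1) in
  if k <> kb then invalid_arg "matmul_shape: inner dimensions mismatch";
  let batch = broadcast (Array.sub a2 0 (la - 2)) (Array.sub b2 0 (lb - 2)) in
  (* Drop the promoted axes again, so 1-D inputs yield 1-D or 0-D results. *)
  let rows = if na = 1 then [||] else [| m |] in
  let cols = if nb = 1 then [||] else [| n |] in
  Array.concat [ batch; rows; cols ]
```

For instance, matmul_shape [| 1; 3; 4 |] [| 5; 4; 2 |] is [| 5; 3; 2 |]: the batch axes 1 and 5 broadcast to 5.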
Raises Invalid_argument if inputs are 0-D or inner dimensions mismatch.
# let a =
create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
in
let b = create float32 [| 2 |] [| 5.; 6. |] in
matmul a b
- : (float, float32_elt) t = [17, 39]
# matmul (ones float32 [| 1; 3; 4 |])
(ones float32 [| 5; 4; 2 |]) |> shape
- : int array = [|5; 3; 2|]
matrix_transpose t swaps the last two axes: […; m; n] → […; n; m]. For 1-D tensors, returns t unchanged.
See also transpose.
outer ?out a b is the outer product. Inputs are flattened to 1-D; the result has shape [numel a; numel b]. out defaults to a fresh allocation.
See also inner.
tensordot ?axes a b contracts a and b along the specified axis pairs. axes defaults to contracting the last axis of a with the first axis of b.
Raises Invalid_argument if the contracted axes have different sizes.
kron a b is the Kronecker product. The result has shape [a.shape.(i) * b.shape.(i)] for each i.
multi_dot ts is the chained matrix product of ts, automatically choosing the association order that minimises computation.
Raises Invalid_argument if the array is empty, shapes are incompatible, or dtypes are not floating-point or complex.
See also matmul.
matrix_power t n raises square matrix t to integer power n. n = 0 returns the identity; n < 0 uses the inverse.
Raises Invalid_argument if t is not square, the dtype is not floating-point or complex, or n < 0 and t is singular.
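For n ≥ 0 the power can be computed with O(log n) matrix products by repeated squaring. The sketch below uses plain float array array square matrices and does not cover the n < 0 inverse case; it illustrates the technique and is not necessarily how Nx evaluates matrix_power. All names are hypothetical.

```ocaml
(* Hypothetical helper: dense product of two n x n matrices. *)
let mat_mul (a : float array array) (b : float array array) : float array array =
  let n = Array.length a in
  Array.init n (fun i ->
      Array.init n (fun j ->
          let s = ref 0. in
          for k = 0 to n - 1 do
            s := !s +. (a.(i).(k) *. b.(k).(j))
          done;
          !s))

(* Hypothetical helper: n x n identity matrix. *)
let identity n = Array.init n (fun i -> Array.init n (fun j -> if i = j then 1. else 0.))

(* Exponentiation by squaring: t^n = (t^(n/2))^2, times t when n is odd. *)
let rec matrix_power (t : float array array) (n : int) : float array array =
  if n < 0 then invalid_arg "matrix_power: negative powers are not handled in this sketch"
  else if n = 0 then identity (Array.length t)
  else
    let half = matrix_power t (n / 2) in
    let sq = mat_mul half half in
    if n mod 2 = 0 then sq else mat_mul sq t
```

Raising the Fibonacci matrix [| [| 1.; 1. |]; [| 1.; 0. |] |] to the 10th power gives [| [| 89.; 55. |]; [| 55.; 34. |] |] after only a handful of products.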
cross ?out ?axis a b is the cross product of 3-element vectors along axis. axis defaults to -1. out defaults to a fresh allocation.
Raises Invalid_argument if the axis dimension is not 3.
Decompositions
cholesky ?upper a is the Cholesky factor of positive-definite matrix a. When upper is true, returns the upper-triangular factor U such that a = Uᵀ U; otherwise (default) returns the lower-triangular factor L such that a = L Lᵀ.
Raises Invalid_argument if a is not positive-definite or the dtype is not floating-point or complex.
See also solve.
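The lower-triangular factor can be computed entry by entry with the Cholesky-Banachiewicz recurrence. This stdlib-only sketch handles real symmetric matrices stored as float array array and raises Invalid_argument when a non-positive pivot shows the input is not positive-definite; `cholesky_lower` is a hypothetical name, not Nx's implementation.

```ocaml
(* Hypothetical helper: lower-triangular l with a = l * transpose l,
   for a real symmetric positive-definite matrix a. *)
let cholesky_lower (a : float array array) : float array array =
  let n = Array.length a in
  let l = Array.make_matrix n n 0. in
  for i = 0 to n - 1 do
    for j = 0 to i do
      (* s = a(i,j) - sum_{k<j} l(i,k) * l(j,k) *)
      let s = ref a.(i).(j) in
      for k = 0 to j - 1 do
        s := !s -. (l.(i).(k) *. l.(j).(k))
      done;
      if i = j then begin
        (* A non-positive diagonal pivot means a is not positive-definite. *)
        if !s <= 0. then invalid_arg "cholesky_lower: matrix is not positive-definite";
        l.(i).(j) <- sqrt !s
      end
      else l.(i).(j) <- !s /. l.(j).(j)
    done
  done;
  l
```

For a = [| [| 4.; 2. |]; [| 2.; 3. |] |] this yields L = [| [| 2.; 0. |]; [| 1.; √2 |] |].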
qr ?mode a is (Q, R) where a = Q R, Q is orthogonal, and R is upper-triangular. mode defaults to `Reduced.
Raises Invalid_argument if the dtype is not floating-point or complex.
See also svd.
val svd :
?full_matrices:bool ->
('a, 'b) t ->
('a, 'b) t * (float, float64_elt) t * ('a, 'b) t
val svdvals : ('a, 'b) t -> (float, float64_elt) t
svdvals a is the singular values of a in descending order. More efficient than svd when only the values are needed.
Raises Invalid_argument if the dtype is not floating-point or complex.
Eigenvalues and eigenvectors
val eig :
('a, 'b) t ->
(Stdlib.Complex.t, complex64_elt) t * (Stdlib.Complex.t, complex64_elt) t
val eigh :
?uplo:[ `U | `L ] ->
('a, 'b) t ->
(float, float64_elt) t * ('a, 'b) t
val eigvals : ('a, 'b) t -> (Stdlib.Complex.t, complex64_elt) t
val eigvalsh : ?uplo:[ `U | `L ] -> ('a, 'b) t -> (float, float64_elt) t
Norms and invariants
val norm :
?ord:
[ `Fro
| `Nuc
| `One
| `Two
| `Inf
| `NegOne
| `NegTwo
| `NegInf
| `P of float ] ->
?axes:int list ->
?keepdims:bool ->
('a, 'b) t ->
('a, 'b) t
norm ?ord ?axes ?keepdims t is the matrix or vector norm. ord defaults to Frobenius for matrices, 2-norm for vectors. keepdims defaults to false.
- `Fro: Frobenius norm.
- `Nuc: nuclear norm (sum of singular values).
- `One: max absolute column sum (matrix) or 1-norm (vector).
- `Two: largest singular value (matrix) or 2-norm (vector).
- `Inf: max absolute row sum (matrix) or ∞-norm (vector).
- `P p: p-norm (vectors only).
- `NegOne, `NegTwo, `NegInf: the corresponding minimum norms.
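For vectors, these definitions reduce to simple folds over absolute values. The sketch below covers only `One, `Two, `Inf, `NegInf and `P p on a flat float array; the type `vector_ord` and the function `vector_norm` are hypothetical, not Nx's norm.

```ocaml
(* Hypothetical subset of the ord cases, restricted to vectors. *)
type vector_ord = [ `One | `Two | `Inf | `NegInf | `P of float ]

let vector_norm (ord : vector_ord) (x : float array) : float =
  let ax = Array.map Float.abs x in
  match ord with
  | `One -> Array.fold_left ( +. ) 0. ax (* sum of |x_i| *)
  | `Two -> sqrt (Array.fold_left (fun acc v -> acc +. (v *. v)) 0. ax)
  | `Inf -> Array.fold_left Float.max 0. ax (* largest |x_i| *)
  | `NegInf -> Array.fold_left Float.min Float.infinity ax (* smallest |x_i| *)
  | `P p -> (Array.fold_left (fun acc v -> acc +. (v ** p)) 0. ax) ** (1. /. p)
```

For x = [| 3.; -4. |] the `One, `Two, `Inf and `NegInf norms are 7, 5, 4 and 3.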
Raises Invalid_argument if ord requires a floating-point or complex dtype and t has neither.
val cond :
?p:[ `One | `Two | `Inf | `NegOne | `NegTwo | `NegInf | `Fro ] ->
('a, 'b) t ->
('a, 'b) t
cond ?p a is the condition number of a in the p-norm. p defaults to `Two.
Raises Invalid_argument if the dtype is not floating-point or complex.
det a is the determinant of square matrix a.
Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.
val slogdet : ('a, 'b) t -> (float, float32_elt) t * (float, float32_elt) t
slogdet a is (sign, log_abs_det) where det a = sign * exp(log_abs_det). More numerically stable than det for matrices with very large or small determinants.
Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.
val matrix_rank :
?tol:float ->
?rtol:float ->
?hermitian:bool ->
('a, 'b) t ->
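int
matrix_rank ?tol ?rtol ?hermitian a is the rank of a, counting singular values above the tolerance. By default the threshold is max(M, N) * ε * σ_max, where M × N is the shape of a, ε is the machine epsilon of the dtype, and σ_max is the largest singular value. When hermitian is true (default false), uses a more efficient eigenvalue-based algorithm.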
Raises Invalid_argument if the dtype is not floating-point or complex.
trace ?out ?offset t is the sum along the offset-th diagonal. offset defaults to 0. out defaults to a fresh allocation.
Raises Invalid_argument if t has fewer than 2 dimensions.
See also diagonal.
Solving
val lstsq :
?rcond:float ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t * ('a, 'b) t * int * (float, float64_elt) t
lstsq ?rcond a b is (x, residuals, rank, sv), where x is the least-squares solution to a @@ x ≈ b. rcond defaults to machine precision.
Raises Invalid_argument if the dtype is not floating-point or complex.
See also solve.
pinv ?rtol ?hermitian a is the Moore–Penrose pseudoinverse of a. Handles non-square and singular matrices. hermitian defaults to false.
Raises Invalid_argument if the dtype is not floating-point or complex.
See also inv.
tensorsolve ?axes a b solves the tensor equation tensordot a x axes = b for x.
Raises Invalid_argument if shapes are incompatible or the dtype is not floating-point or complex.
tensorinv ?ind a is the tensor inverse such that tensordot a (tensorinv a) ind is the identity. ind defaults to 2.
Raises Invalid_argument if the result is not square in the specified dimensions or the dtype is not floating-point or complex.
Fourier transforms
FFT normalisation mode.
- `Backward: normalise by 1/n on the inverse transform (default).
- `Forward: normalise by 1/n on the forward transform.
- `Ortho: normalise by 1/√n on both.
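A naive O(n²) DFT shows where each mode applies its factor: only the scale changes between modes and directions, and any forward/inverse pair with the same mode round-trips. The local `fft_norm` type below mirrors the three modes but is a hypothetical stand-in, and `dft` is not an Nx function.

```ocaml
(* Hypothetical local copy of the three normalisation modes. *)
type fft_norm = [ `Backward | `Forward | `Ortho ]

(* Naive 1-D DFT.
   Forward: X_k = scale * sum_j x_j * exp(-2*pi*i*j*k/n)
   Inverse: x_j = scale * sum_k X_k * exp(+2*pi*i*j*k/n) *)
let dft ?(inverse = false) ?(norm : fft_norm = `Backward) (x : Complex.t array) :
    Complex.t array =
  let n = Array.length x in
  let fn = float_of_int n in
  (* The only mode-dependent part: which direction carries 1/n, or 1/sqrt n on both. *)
  let scale =
    match (norm, inverse) with
    | `Backward, false | `Forward, true -> 1.
    | `Backward, true | `Forward, false -> 1. /. fn
    | `Ortho, _ -> 1. /. sqrt fn
  in
  let sign = if inverse then 1. else -1. in
  Array.init n (fun k ->
      let acc = ref Complex.zero in
      for j = 0 to n - 1 do
        let angle = sign *. 2. *. Float.pi *. float_of_int (j * k) /. fn in
        acc := Complex.add !acc (Complex.mul x.(j) (Complex.polar 1. angle))
      done;
      let s = !acc in
      { Complex.re = scale *. s.Complex.re; im = scale *. s.Complex.im })
```

With `Ortho, the transform of a length-4 unit impulse is 0.5 in every bin; with the default `Backward it would be 1 in every bin.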
val ifft2 :
?out:(Stdlib.Complex.t, 'a) t ->
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Stdlib.Complex.t, 'a) t ->
(Stdlib.Complex.t, 'a) t
ifft2 ?out ?axes ?s ?norm x is the inverse of fft2. out defaults to a fresh allocation.
val fftn :
?out:(Stdlib.Complex.t, 'a) t ->
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Stdlib.Complex.t, 'a) t ->
(Stdlib.Complex.t, 'a) t
fftn ?out ?axes ?s ?norm x is the N-D FFT. axes defaults to all. out defaults to a fresh allocation.
See also ifftn.
val ifftn :
?out:(Stdlib.Complex.t, 'a) t ->
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Stdlib.Complex.t, 'a) t ->
(Stdlib.Complex.t, 'a) t
ifftn ?out ?axes ?s ?norm x is the inverse of fftn. out defaults to a fresh allocation.
val rfft :
?out:(Stdlib.Complex.t, complex64_elt) t ->
?axis:int ->
?n:int ->
?norm:fft_norm ->
(float, 'a) t ->
(Stdlib.Complex.t, complex64_elt) t
rfft ?out ?axis ?n ?norm x is the 1-D FFT of real input. Returns only the non-redundant, non-negative frequencies; the output size along the transformed axis is n/2 + 1 (rounded down). out defaults to a fresh allocation.
# create float64 [| 4 |] [| 0.; 1.; 2.; 3. |]
|> rfft |> shape
- : int array = [|3|]
val irfft :
?out:(float, float64_elt) t ->
?axis:int ->
?n:int ->
?norm:fft_norm ->
(Stdlib.Complex.t, 'a) t ->
(float, float64_elt) t
val rfft2 :
?out:(Stdlib.Complex.t, complex64_elt) t ->
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(float, 'a) t ->
(Stdlib.Complex.t, complex64_elt) t
val irfft2 :
?out:(float, float64_elt) t ->
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Stdlib.Complex.t, 'a) t ->
(float, float64_elt) t
irfft2 ?out ?axes ?s ?norm x is the inverse of rfft2. out defaults to a fresh allocation.
val rfftn :
?out:(Stdlib.Complex.t, complex64_elt) t ->
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(float, 'a) t ->
(Stdlib.Complex.t, complex64_elt) t
val irfftn :
?out:(float, float64_elt) t ->
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Stdlib.Complex.t, 'a) t ->
(float, float64_elt) t
irfftn ?out ?axes ?s ?norm x is the inverse of rfftn. out defaults to a fresh allocation.
val hfft :
?axis:int ->
?n:int ->
?norm:fft_norm ->
(Stdlib.Complex.t, 'a) t ->
(float, float64_elt) t
hfft ?axis ?n ?norm x is the FFT of a signal with Hermitian symmetry, producing real output.
val ihfft :
?axis:int ->
?n:int ->
?norm:fft_norm ->
(float, 'a) t ->
(Stdlib.Complex.t, complex64_elt) t
ihfft ?axis ?n ?norm x is the inverse of hfft.
val fftfreq : ?d:float -> int -> (float, float64_elt) t
fftfreq ?d n is the DFT sample frequencies for window length n and sample spacing d (default 1.0).
# fftfreq 4
- : (float, float64_elt) t = [0, 0.25, -0.5, -0.25]
See also rfftfreq.
val rfftfreq : ?d:float -> int -> (float, float64_elt) t
rfftfreq ?d n is the non-negative DFT sample frequencies: [0, 1, …, n/2] / (d * n).
See also fftfreq.
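Both frequency layouts can be reproduced with Array.init. This stdlib-only sketch assumes the NumPy-style ordering that the fftfreq example shows (non-negative frequencies first, then negative ones); the two functions are hypothetical plain-array versions, not the Nx bindings.

```ocaml
(* Hypothetical plain-array fftfreq: [0, 1, ..., (n-1)/2, -(n/2), ..., -1] / (d * n). *)
let fftfreq ?(d = 1.0) (n : int) : float array =
  let denom = d *. float_of_int n in
  (* Number of non-negative frequencies: 0 .. (n - 1) / 2. *)
  let n_nonneg = ((n - 1) / 2) + 1 in
  Array.init n (fun i ->
      let k = if i < n_nonneg then i else i - n in
      float_of_int k /. denom)

(* Hypothetical plain-array rfftfreq: [0, 1, ..., n/2] / (d * n). *)
let rfftfreq ?(d = 1.0) (n : int) : float array =
  let denom = d *. float_of_int n in
  Array.init ((n / 2) + 1) (fun i -> float_of_int i /. denom)
```

fftfreq 4 evaluates to [| 0.; 0.25; -0.5; -0.25 |], matching the example above, and rfftfreq 4 to [| 0.; 0.25; 0.5 |].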
fftshift ?axes t shifts the zero-frequency component to the centre. axes defaults to all.
# fftfreq 5 |> fftshift
- : (float, float64_elt) t = [-0.4, -0.2, 0, 0.2, 0.4]
See also ifftshift.
ifftshift ?axes t is the inverse of fftshift.
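Assuming NumPy's convention, both shifts are rotations of each axis by half its length: fftshift rotates right by n/2 (rounded down) and ifftshift rotates left by the same amount, so the two undo each other even for odd n. A 1-D sketch on plain arrays, with hypothetical helper names:

```ocaml
(* Hypothetical helper: rotate right by [shift], so out.((i + shift) mod n) = a.(i). *)
let roll (a : 'a array) (shift : int) : 'a array =
  let n = Array.length a in
  if n = 0 then [||]
  else
    (* Normalise shift into 0 .. n - 1, also for negative shifts. *)
    let s = ((shift mod n) + n) mod n in
    Array.init n (fun i -> a.((i - s + n) mod n))

let fftshift a = roll a (Array.length a / 2)
let ifftshift a = roll a (-(Array.length a / 2))
```

fftshift [| 0; 1; 2; -2; -1 |] gives [| -2; -1; 0; 1; 2 |], the same reordering as the fftfreq 5 |> fftshift example.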
Activation functions
relu ?out t is max(0, t) element-wise. out defaults to a fresh allocation.
# create float32 [| 5 |]
[| -2.; -1.; 0.; 1.; 2. |]
|> relu
- : (float, float32_elt) t = [0, 0, 0, 1, 2]
sigmoid ?out t is 1 / (1 + exp(-t)) element-wise. Output in (0, 1). out defaults to a fresh allocation.
# sigmoid (scalar float32 0.) |> item []
- : float = 0.5
softmax ?out ?axes ?scale t is the softmax normalisation exp(scale * (t - max t)) / Σ exp(scale * (t - max t)). axes defaults to [-1]. scale defaults to 1.0. Output sums to 1 along the specified axes. out defaults to a fresh allocation.
# create float32 [| 3 |] [| 1.; 2.; 3. |]
|> softmax |> sum |> item []
- : float = 1.
See also log_softmax.
logsumexp ?out ?axes ?keepdims t is log(Σ exp(t)) computed in a numerically stable way. axes defaults to all. keepdims defaults to false. out defaults to a fresh allocation.
See also logmeanexp, log_softmax.
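The numerical stability comes from subtracting the maximum before exponentiating, so the largest term is exp 0 = 1 and nothing overflows. The sketch below shows this for a flat float array and derives softmax from it via softmax x = exp (x - logsumexp x); both helpers are hypothetical and ignore the axes and keepdims arguments.

```ocaml
(* Hypothetical helper: stable log (sum (exp x)) over a flat float array. *)
let logsumexp (x : float array) : float =
  let m = Array.fold_left Float.max Float.neg_infinity x in
  (* Empty or all -infinity input gives -infinity; NaN or +infinity is returned as is. *)
  if not (Float.is_finite m) then m
  else m +. log (Array.fold_left (fun acc v -> acc +. exp (v -. m)) 0. x)

(* Hypothetical helper: softmax of scale * x, computed as
   exp (scale * x - logsumexp (scale * x)). *)
let softmax ?(scale = 1.0) (x : float array) : float array =
  let sx = Array.map (fun v -> scale *. v) x in
  let lse = logsumexp sx in
  Array.map (fun v -> exp (v -. lse)) sx
```

A naive log (exp 1000. +. exp 1000.) overflows to infinity, whereas logsumexp [| 1000.; 1000. |] returns 1000 + log 2.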
val standardize :
?out:('a, 'b) t ->
?axes:int list ->
?mean:('a, 'b) t ->
?variance:('a, 'b) t ->
?epsilon:float ->
('a, 'b) t ->
('a, 'b) t
standardize ?out ?axes ?mean ?variance ?epsilon t is (t - mean) / sqrt(variance + epsilon). When mean or variance are omitted, they are computed along axes (default all). epsilon defaults to 1e-5. out defaults to a fresh allocation.
erf ?out t is the error function erf(x) = (2/√π) ∫₀ˣ e^(-u²) du. out defaults to a fresh allocation.
# erf (scalar float32 0.) |> item []
- : float = 0.
Sliding windows
Patches
val extract_patches :
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t
extract_patches ~kernel_size ~stride ~dilation ~padding t extracts sliding windows from the last K spatial dimensions where K = Array.length kernel_size.
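Input: [leading…; spatial…]. Output: [leading…; prod(kernel_size); L], where L is the number of window positions, that is, the product over the K spatial axes of the per-axis window counts.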
# arange_f float32 0. 16. 1.
|> reshape [| 1; 1; 4; 4 |]
|> extract_patches
~kernel_size:[| 2; 2 |]
~stride:[| 1; 1 |]
~dilation:[| 1; 1 |]
~padding:[| (0, 0); (0, 0) |]
|> shape
- : int array = [|1; 1; 4; 9|]
See also combine_patches.
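The window count L can be predicted per spatial axis with the usual convolution output-size formula, (size + pad_lo + pad_hi - dilation * (kernel - 1) - 1) / stride + 1. That formula is an assumption here (it reproduces the 4 × 4 example: 3 windows per axis, so L = 9), and `windows_along` and `patches_shape` are hypothetical helpers.

```ocaml
(* Hypothetical helper: number of window positions along one spatial axis,
   using the standard convolution output-size formula (assumed, not from Nx). *)
let windows_along ~size ~kernel ~stride ~dilation ~(pad : int * int) =
  let lo, hi = pad in
  let effective = (dilation * (kernel - 1)) + 1 in
  ((size + lo + hi - effective) / stride) + 1

(* Hypothetical helper: output shape [leading...; prod kernel_size; L] for an
   input shape whose last K = Array.length kernel_size axes are spatial. *)
let patches_shape ~kernel_size ~stride ~dilation ~padding (shape : int array) =
  let k = Array.length kernel_size in
  let nd = Array.length shape in
  let leading = Array.sub shape 0 (nd - k) in
  let l = ref 1 and block = ref 1 in
  for i = 0 to k - 1 do
    block := !block * kernel_size.(i);
    l := !l * windows_along ~size:shape.(nd - k + i) ~kernel:kernel_size.(i)
               ~stride:stride.(i) ~dilation:dilation.(i) ~pad:padding.(i)
  done;
  Array.append leading [| !block; !l |]
```

With a 3 × 3 kernel and one cell of padding on each side, a 4 × 4 input keeps 4 window positions per axis, giving L = 16.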
val combine_patches :
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t
combine_patches ~output_size ~kernel_size ~stride ~dilation ~padding t is the inverse of extract_patches. Overlapping values are summed.
See also extract_patches.
Cross-correlation and convolution
correlate ?padding x kernel is the N-D cross-correlation (no kernel flip). Spatial dimensions K = ndim kernel. Leading dimensions of x beyond K are batch dimensions. padding defaults to `Valid.
See also convolve.
Filters
maximum_filter ~kernel_size ?stride t is the sliding-window maximum over the last K dimensions. stride defaults to kernel_size.
See also minimum_filter, uniform_filter.
minimum_filter ~kernel_size ?stride t is the sliding-window minimum over the last K dimensions. stride defaults to kernel_size.
See also maximum_filter.
uniform_filter ~kernel_size ?stride t is the sliding-window mean over the last K dimensions. stride defaults to kernel_size.
See also maximum_filter, minimum_filter.
Iteration
map_item f t applies f to each scalar element of t and returns a fresh tensor of the results.
val iter_item : ('a -> unit) -> ('a, 'b) t -> unit
iter_item f t applies f to each scalar element of t for its side effects.
val fold_item : ('a -> 'b -> 'a) -> 'a -> ('b, 'c) t -> 'a
fold_item f init t folds f over the scalar elements of t in row-major order, starting with init.
map f t applies tensor function f to each element of t, presented as a scalar tensor.
See also map_item.
iter f t applies tensor function f to each element of t, presented as a scalar tensor.
See also iter_item.
fold f init t folds tensor function f over the elements of t, each presented as a scalar tensor.
See also fold_item.
Formatting
val pp_data : Stdlib.Format.formatter -> ('a, 'b) t -> unit
pp_data fmt t formats the data of t.
format_to_string pp x is the string produced by pp.
print_with_formatter pp x prints x to stdout using pp.
val data_to_string : ('a, 'b) t -> string
data_to_string t is the data of t as a string.
val print_data : ('a, 'b) t -> unit
print_data t prints the data of t to stdout.
val pp_dtype : Stdlib.Format.formatter -> ('a, 'b) dtype -> unit
pp_dtype fmt dt formats dt.
val dtype_to_string : ('a, 'b) dtype -> string
dtype_to_string dt is dt as a string.
val pp : Stdlib.Format.formatter -> ('a, 'b) t -> unit
pp fmt t formats t for debugging (dtype, shape, and data).
val print : ('a, 'b) t -> unit
print t prints t to stdout.
val to_string : ('a, 'b) t -> string
to_string t is t formatted as a string (dtype, shape, and data).