Parameter Make_frontend.B
Types
'a is the OCaml element type (e.g., float, int32). 'b is a phantom type that tags the dtype for type safety.
Backend execution context.
Carries backend-specific state such as memory pools, device handles, command queues, or computation graphs.
Tensor Properties
view t returns the strided view metadata describing t's logical layout (shape, strides, offset) over its underlying buffer.
val to_host : ('a, 'b) t -> ('a, 'b) Nx_buffer.t
to_host t returns t's data as a flat, C-contiguous host buffer.
Use view to interpret the logical structure. CPU backends may return a direct reference (zero-copy); GPU backends copy from device to host.
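As a rough illustration of how view metadata maps a logical index onto the flat buffer returned by to_host, the sketch below uses a hypothetical view record. The record and its field names are assumptions made for this example, not the library's actual type.

```ocaml
(* Hypothetical view record; field names are illustrative only. *)
type view = { shape : int array; strides : int array; offset : int }

(* Flat buffer position of a multi-dimensional index under a strided view. *)
let flat_index v idx =
  let pos = ref v.offset in
  Array.iteri (fun d i -> pos := !pos + (i * v.strides.(d))) idx;
  !pos

(* C-contiguous (row-major) strides for a shape, counted in elements. *)
let c_strides shape =
  let n = Array.length shape in
  let s = Array.make n 1 in
  for d = n - 2 downto 0 do
    s.(d) <- s.(d + 1) * shape.(d + 1)
  done;
  s

let () =
  let shape = [| 2; 3 |] in
  let v = { shape; strides = c_strides shape; offset = 0 } in
  (* Element (1, 2) of a 2x3 row-major tensor sits at 1*3 + 2 = 5. *)
  Printf.printf "%d\n" (flat_index v [| 1; 2 |])
```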
Tensor Creation
buffer ctx dtype shape allocates an uninitialized tensor.
Contents are undefined. Used internally by the frontend to pre-allocate ~out buffers before calling operations.
Backend must: return a tensor with the given shape and dtype whose view is C-contiguous.
full ctx dtype shape value creates a tensor where every element is value.
For scalars, shape is [||]. Subsumes zeros, ones, and constant fill.
Backend must: return a C-contiguous tensor of the given shape and dtype with all elements set to value.
val from_host : context -> ('a, 'b) Nx_buffer.t -> ('a, 'b) t
from_host ctx buf creates a tensor from a flat, C-contiguous host buffer.
CPU backends may share the buffer directly (zero-copy). GPU backends copy from host to device.
Frontend guarantees: buf is C-contiguous.
Element-wise Binary Operations
Frontend guarantees: out, a, and b have identical shapes (after broadcasting) and compatible dtypes (after promotion). out is C-contiguous and pre-allocated with the correct shape.
Backend must: write exactly numel elements to out, respecting the strides of a and b (which may be non-contiguous or broadcast).
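The contract above can be pictured with a small reference loop over plain float arrays. This is only a sketch: add2d and its tuple-based operand encoding are hypothetical, and a real backend would work on its own tensor type. A stride of 0 models a broadcast dimension.

```ocaml
(* Reference element-wise add for 2-D operands stored in flat float arrays.
   Each operand is (buffer, (row_stride, col_stride), offset), in elements. *)
let add2d ~out ~rows ~cols (a, a_strides, a_off) (b, b_strides, b_off) =
  for i = 0 to rows - 1 do
    for j = 0 to cols - 1 do
      let ai = a_off + (i * fst a_strides) + (j * snd a_strides) in
      let bi = b_off + (i * fst b_strides) + (j * snd b_strides) in
      (* out is C-contiguous, so its position is simply i * cols + j. *)
      out.((i * cols) + j) <- a.(ai) +. b.(bi)
    done
  done

let () =
  let a = [| 1.; 2.; 3.; 4.; 5.; 6. |] in (* 2x3, strides (3, 1) *)
  let b = [| 10.; 20.; 30. |] in (* one row broadcast over both rows: strides (0, 1) *)
  let out = Array.make 6 0. in
  add2d ~out ~rows:2 ~cols:3 (a, (3, 1), 0) (b, (0, 1), 0);
  Array.iter (Printf.printf "%g ") out;
  print_newline ()
```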
Arithmetic
add ~out a b computes out.{i} <- a.{i} + b.{i}.
sub ~out a b computes out.{i} <- a.{i} - b.{i}.
mul ~out a b computes out.{i} <- a.{i} * b.{i}.
div ~out a b computes out.{i} <- a.{i} / b.{i}.
Integer dtypes use truncation toward zero (C division). Floating-point dtypes use IEEE 754 division.
mod_ ~out a b computes the remainder of a / b.
Integers use C's % operator (truncated division). Floats use fmod. The sign of the result follows the dividend a.
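OCaml's own integer operators and Float.rem follow the same truncating convention, so they can serve as a quick reference when checking a backend. The snippet below only exercises the standard library; it does not call the backend.

```ocaml
let () =
  (* Integer division truncates toward zero; the remainder takes the
     sign of the dividend. *)
  Printf.printf "%d %d\n" ((-7) / 2) ((-7) mod 2); (* -3 -1 *)
  Printf.printf "%d %d\n" (7 / (-2)) (7 mod (-2)); (* -3 1 *)
  (* Float.rem behaves like C's fmod. *)
  Printf.printf "%g\n" (Float.rem (-7.5) 2.0) (* -1.5 *)
```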
pow ~out base exponent computes out.{i} <- base.{i} ^ exponent.{i}.
atan2 ~out y x computes out.{i} <- atan2(y.{i}, x.{i}).
Returns the angle in radians in (-π, π], handling all quadrants.
Comparison
Comparison operations produce boolean tensors.
Frontend guarantees: out is a (bool, bool_elt) tensor with the same shape as a and b.
val cmpeq : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmpeq ~out a b computes out.{i} <- (a.{i} = b.{i}).
val cmpne : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmpne ~out a b computes out.{i} <- (a.{i} <> b.{i}).
val cmplt : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmplt ~out a b computes out.{i} <- (a.{i} < b.{i}).
val cmple : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmple ~out a b computes out.{i} <- (a.{i} <= b.{i}).
Min/Max
max ~out a b computes out.{i} <- max(a.{i}, b.{i}).
min ~out a b computes out.{i} <- min(a.{i}, b.{i}).
Bitwise
Operate on the binary representation of integer and boolean dtypes. For booleans, these are equivalent to logical AND/OR/XOR.
Element-wise Unary Operations
Frontend guarantees: out and x have the same shape and dtype. out is C-contiguous.
Backend must: write exactly numel elements to out, respecting the strides of x.
Arithmetic
sign ~out x computes the sign function: -1 for negative, 0 for zero, 1 for positive. Returns NaN for floating-point NaN inputs.
Exponential and Logarithm
Trigonometric
All inputs are in radians.
asin ~out x computes out.{i} <- arcsin(x.{i}).
Returns values in [-π/2, π/2].
acos ~out x computes out.{i} <- arccos(x.{i}).
Returns values in [0, π].
atan ~out x computes out.{i} <- arctan(x.{i}).
Returns values in [-π/2, π/2].
Hyperbolic
Rounding
For integer dtypes, all rounding operations are the identity.
round ~out x rounds to nearest integer, half away from zero (C's round).
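The standard library's Float.round uses the same tie-breaking rule (ties go away from zero, unlike round-half-to-even), which makes it a convenient reference for the expected values:

```ocaml
let () =
  (* Ties are rounded away from zero, not to the nearest even integer. *)
  List.iter
    (fun x -> Printf.printf "%g -> %g\n" x (Float.round x))
    [ 0.5; 2.5; -2.5; 2.4 ]
```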
Special Functions
erf ~out x computes the error function erf(x) = 2/√π ∫₀ˣ e^(-t²) dt.
Ternary Operations
val where :
out:('a, 'b) t ->
(bool, Dtype.bool_elt) t ->
('a, 'b) t ->
('a, 'b) t ->
unit
where ~out cond if_true if_false selects elements: if_true.{i} where cond.{i} is true, if_false.{i} otherwise.
Frontend guarantees: all four tensors have identical shapes. cond is boolean. out, if_true, if_false share the same dtype.
Reduction Operations
Reductions aggregate values along one or more axes.
Frontend guarantees: axes contains valid, non-negative, deduplicated axis indices. out is pre-allocated with the correct shape: reduced axes are either removed or kept as size-1 dimensions depending on keepdims.
reduce_sum ~out ~axes ~keepdims x sums elements along axes.
reduce_prod ~out ~axes ~keepdims x multiplies elements along axes.
reduce_max ~out ~axes ~keepdims x finds maximum along axes.
reduce_min ~out ~axes ~keepdims x finds minimum along axes.
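The shape that out is expected to have follows directly from axes and keepdims. A minimal sketch of that rule, using a hypothetical helper reduced_shape that is not part of the interface:

```ocaml
(* Shape of a reduction result: reduced axes are dropped, or kept as
   size-1 dimensions when keepdims is true. *)
let reduced_shape ~axes ~keepdims shape =
  shape
  |> Array.to_list
  |> List.mapi (fun d n -> (d, n))
  |> List.filter_map (fun (d, n) ->
         if not (Array.mem d axes) then Some n
         else if keepdims then Some 1
         else None)
  |> Array.of_list

let () =
  let show s =
    print_endline
      (String.concat "x" (Array.to_list (Array.map string_of_int s)))
  in
  show (reduced_shape ~axes:[| 1 |] ~keepdims:false [| 2; 3; 4 |]); (* 2x4 *)
  show (reduced_shape ~axes:[| 1 |] ~keepdims:true [| 2; 3; 4 |]) (* 2x1x4 *)
```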
val argmax :
out:(int32, Dtype.int32_elt) t ->
axis:int ->
keepdims:bool ->
('a, 'b) t ->
unit
argmax ~out ~axis ~keepdims x writes int32 indices of maximum values along axis to out. For ties, returns the first occurrence.
Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.
val argmin :
out:(int32, Dtype.int32_elt) t ->
axis:int ->
keepdims:bool ->
('a, 'b) t ->
unit
argmin ~out ~axis ~keepdims x writes int32 indices of minimum values along axis to out. For ties, returns the first occurrence.
Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.
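The first-occurrence rule for ties comes down to using a strict comparison while scanning. The one-dimensional sketch below uses hypothetical helpers over plain float arrays that return native int indices rather than int32, and it assumes a non-empty input.

```ocaml
(* Index of the first maximum of a non-empty array. *)
let argmax (x : float array) =
  let best = ref 0 in
  for i = 1 to Array.length x - 1 do
    (* Strict comparison keeps the earliest index when values tie. *)
    if x.(i) > x.(!best) then best := i
  done;
  !best

(* Index of the first minimum of a non-empty array. *)
let argmin (x : float array) =
  let best = ref 0 in
  for i = 1 to Array.length x - 1 do
    if x.(i) < x.(!best) then best := i
  done;
  !best

let () =
  let x = [| 1.; 3.; 3.; 1. |] in
  Printf.printf "%d %d\n" (argmax x) (argmin x) (* 1 0 *)
```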
val associative_scan :
out:('a, 'b) t ->
axis:int ->
op:[ `Sum | `Prod | `Max | `Min ] ->
('a, 'b) t ->
unit
associative_scan ~out ~axis ~op x computes an inclusive prefix scan along axis. `Sum for cumulative sum, `Prod for cumulative product, `Max/`Min for running max/min.
Frontend guarantees: axis is valid and non-negative. out has the same shape as x.
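For a single axis, an inclusive scan means that position i of the result combines elements 0 through i of the input. A sequential reference sketch over a float array follows; inclusive_scan and combine are hypothetical names, and a real backend may well use a parallel algorithm instead.

```ocaml
type op = [ `Sum | `Prod | `Max | `Min ]

let combine (op : op) a b =
  match op with
  | `Sum -> a +. b
  | `Prod -> a *. b
  | `Max -> Float.max a b
  | `Min -> Float.min a b

(* out.(i) = x.(0) op x.(1) op ... op x.(i) *)
let inclusive_scan op (x : float array) =
  let out = Array.copy x in
  for i = 1 to Array.length x - 1 do
    out.(i) <- combine op out.(i - 1) x.(i)
  done;
  out

let () =
  Array.iter (Printf.printf "%g ") (inclusive_scan `Sum [| 1.; 2.; 3.; 4. |]);
  print_newline ()
```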
Sort Operations
Frontend guarantees: axis is valid and non-negative. out is pre-allocated with the correct shape and dtype.
sort ~out ~axis ~descending x sorts elements along axis. NaN values are placed at the end regardless of sort direction.
Frontend guarantees: out has the same shape and dtype as x.
val argsort :
out:(int32, Dtype.int32_elt) t ->
axis:int ->
descending:bool ->
('a, 'b) t ->
unit
argsort ~out ~axis ~descending x writes int32 indices that would sort elements along axis to out.
Frontend guarantees: out has the same shape as x with int32 dtype.
Movement Operations
Movement operations manipulate view metadata (shape, strides, offset) without copying data when possible. They return new tensor handles sharing the underlying buffer.
Frontend guarantees: all parameters are validated (axes in range, shapes compatible, bounds within limits).
Backend must: return a tensor with the correct view metadata. May share the underlying buffer (zero-copy) or allocate if necessary.
expand t shape broadcasts dimensions of size 1 to match shape by setting their stride to 0. Non-singleton dimensions must already match. Zero-copy.
reshape t shape changes the logical shape, preserving element count.
Zero-copy when t is C-contiguous or the reshape is compatible with the current strides. May copy if t is non-contiguous.
permute t axes reorders dimensions according to axes, which must be a permutation of [0, ..., ndim-1]. Zero-copy.
shrink t ranges extracts a contiguous slice. ranges.(i) is (start, stop) with exclusive stop. Zero-copy (adjusts offset and shape).
flip t axes reverses dimensions where axes.(i) = true by negating strides. Zero-copy.
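The zero-copy operations above can be read as pure transformations of view metadata. The sketch below models them on a hypothetical view record (field names are assumptions, strides are counted in elements, and expand assumes the target shape has the same rank as the input); it illustrates the arithmetic and is not the library's implementation.

```ocaml
type view = { shape : int array; strides : int array; offset : int }

(* expand: size-1 dimensions that grow get stride 0. *)
let expand v shape =
  let strides =
    Array.mapi
      (fun d s -> if v.shape.(d) = 1 && shape.(d) <> 1 then 0 else s)
      v.strides
  in
  { v with shape; strides }

(* permute: reorder shape and strides together. *)
let permute v axes =
  { v with
    shape = Array.map (fun a -> v.shape.(a)) axes;
    strides = Array.map (fun a -> v.strides.(a)) axes }

(* shrink: advance the offset to each start; sizes become stop - start. *)
let shrink v ranges =
  let offset = ref v.offset in
  Array.iteri
    (fun d (start, _) -> offset := !offset + (start * v.strides.(d)))
    ranges;
  { v with
    shape = Array.map (fun (start, stop) -> stop - start) ranges;
    offset = !offset }

(* flip: negate the stride and start from the last element of the axis. *)
let flip v axes =
  let offset = ref v.offset in
  let strides =
    Array.mapi
      (fun d s ->
        if axes.(d) then begin
          offset := !offset + ((v.shape.(d) - 1) * s);
          -s
        end
        else s)
      v.strides
  in
  { v with strides; offset = !offset }

let () =
  let v = { shape = [| 2; 3 |]; strides = [| 3; 1 |]; offset = 0 } in
  let f = flip v [| false; true |] in
  Printf.printf "strides=(%d,%d) offset=%d\n" f.strides.(0) f.strides.(1) f.offset
```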
pad t padding fill_value extends t with fill_value. padding.(i) is (before, after) for dimension i.
Backend must: allocate a new buffer and copy data.
cat ~out tensors ~axis concatenates tensors along axis into out.
Frontend guarantees: all tensors have the same shape except along axis. axis is valid. The list is non-empty. out is pre-allocated with the correct concatenated shape.
Type Conversion and Memory
cast ~out x converts elements of x to the dtype of out.
Float-to-int truncates toward zero. Int-to-float may lose precision for large values.
Frontend guarantees: out is pre-allocated with the correct shape and target dtype.
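Both conversion caveats can be reproduced with the standard library's fixed-width integer modules, which is handy when writing expectations for a backend test. The snippet below uses only Stdlib functions, not the backend itself.

```ocaml
let () =
  (* Float-to-int discards the fractional part (truncation toward zero). *)
  Printf.printf "%ld %ld\n" (Int32.of_float 2.9) (Int32.of_float (-2.9));
  (* Int-to-float: 2^53 + 1 has no exact float64 representation, so the
     conversion rounds it to 2^53. *)
  let big = Int64.add (Int64.shift_left 1L 53) 1L in
  Printf.printf "%Ld -> %.0f\n" big (Int64.to_float big)
```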
contiguous t returns a C-contiguous version of t.
May return t unchanged if already C-contiguous. Otherwise allocates and copies.
Backend must: return a C-contiguous tensor with the same data.
copy t creates an independent copy with its own buffer.
Backend must: always allocate a new buffer, even if t is already contiguous.
assign dst src copies elements from src into dst in-place.
Frontend guarantees: dst and src have matching shapes and dtypes.
Backend must: write src's data into dst's buffer, respecting both tensors' strides.
Random Number Generation
val threefry :
out:(int32, Dtype.int32_elt) t ->
(int32, Dtype.int32_elt) t ->
(int32, Dtype.int32_elt) t ->
unit
threefry ~out key counter applies the Threefry-2x32 hash function.
Frontend guarantees: key and counter are int32 tensors with compatible shapes. out is pre-allocated with the same shape as counter.
Indexed Access Operations
val gather :
out:('a, 'b) t ->
('a, 'b) t ->
(int32, Dtype.int32_elt) t ->
axis:int ->
unit
gather ~out data indices ~axis selects elements from data along axis using indices and writes them to out.
Frontend guarantees: rank data = rank indices. axis is valid. Index values are in range for data's size along axis. out has the same shape as indices and the same dtype as data.
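The text does not spell out the index formula. The two-dimensional sketch below assumes the common take-along-axis convention, which is consistent with the shape guarantees above: for axis 0, out.(i).(j) = data.(indices.(i).(j)).(j), and for axis 1, out.(i).(j) = data.(i).(indices.(i).(j)). gather2d is a hypothetical helper over nested arrays.

```ocaml
(* Take-along-axis gather for 2-D data (assumed convention). *)
let gather2d ~axis (data : float array array) (indices : int array array) =
  Array.mapi
    (fun i row ->
      Array.mapi
        (fun j k -> if axis = 0 then data.(k).(j) else data.(i).(k))
        row)
    indices

let () =
  let data = [| [| 1.; 2. |]; [| 3.; 4. |] |] in
  let indices = [| [| 0; 0 |]; [| 1; 0 |] |] in
  let out = gather2d ~axis:1 data indices in
  Array.iter
    (fun row ->
      Array.iter (Printf.printf "%g ") row;
      print_newline ())
    out
```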
val scatter :
?mode:[ `Set | `Add ] ->
?unique_indices:bool ->
('a, 'b) t ->
indices:(int32, Dtype.int32_elt) t ->
updates:('a, 'b) t ->
axis:int ->
('a, 'b) t
scatter ?mode ?unique_indices template ~indices ~updates ~axis places updates into a tensor shaped like template along axis.
`Set (default) uses the last update for duplicate indices. `Add accumulates. unique_indices = true hints that indices are unique.
Frontend guarantees: rank indices = rank updates. axis is valid. template has the desired output shape.
Backend must: allocate and return the result tensor, initialized from template's data.
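The difference between the two modes shows up only when indices repeat. A one-dimensional sketch with a hypothetical scatter1d over float arrays, processing updates in order so that the last write wins under `Set:

```ocaml
(* Result starts as a copy of template; updates are applied in order. *)
let scatter1d ?(mode = `Set) template ~indices ~updates =
  let out = Array.copy template in
  Array.iteri
    (fun n i ->
      match mode with
      | `Set -> out.(i) <- updates.(n)
      | `Add -> out.(i) <- out.(i) +. updates.(n))
    indices;
  out

let () =
  let template = Array.make 4 0. in
  let indices = [| 1; 1; 3 |] and updates = [| 10.; 20.; 30. |] in
  let show a =
    Array.iter (Printf.printf "%g ") a;
    print_newline ()
  in
  show (scatter1d template ~indices ~updates); (* 0 20 0 30 *)
  show (scatter1d ~mode:`Add template ~indices ~updates) (* 0 30 0 30 *)
```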
Window Operations
Sliding-window extraction and its inverse. Used to implement convolution as unfold + reshape + matmul and pooling as unfold + reduce.
val unfold :
('a, 'b) t ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
unfold t ~kernel_size ~stride ~dilation ~padding extracts sliding windows from the last K spatial dimensions, where K = Array.length kernel_size.
Input shape (leading..., spatial...) produces (leading..., prod(kernel_size), L) where L is the number of windows. All dimensions before the last K are preserved as-is.
Frontend guarantees: all array parameters have length K. Values are positive. Input has at least K dimensions.
Backend must: allocate and return the result tensor.
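The text does not give a formula for L. The sketch below assumes the usual convolution arithmetic: along each spatial dimension the window count is (n + before + after - dilation * (kernel - 1) - 1) / stride + 1, and L is the product over dimensions. num_windows is a hypothetical helper, and input holds only the K spatial sizes.

```ocaml
(* Number of sliding windows L, under the assumed standard formula. *)
let num_windows ~input ~kernel_size ~stride ~dilation ~padding =
  let per_dim =
    Array.mapi
      (fun d n ->
        let before, after = padding.(d) in
        let effective_kernel = (dilation.(d) * (kernel_size.(d) - 1)) + 1 in
        ((n + before + after - effective_kernel) / stride.(d)) + 1)
      input
  in
  Array.fold_left ( * ) 1 per_dim

let () =
  (* A 5x5 input with a 3x3 kernel, stride 1, no dilation, no padding
     yields 3 * 3 = 9 windows. *)
  let l =
    num_windows ~input:[| 5; 5 |] ~kernel_size:[| 3; 3 |] ~stride:[| 1; 1 |]
      ~dilation:[| 1; 1 |] ~padding:[| (0, 0); (0, 0) |]
  in
  Printf.printf "%d\n" l
```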
val fold :
('a, 'b) t ->
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
fold t ~output_size ~kernel_size ~stride ~dilation ~padding combines sliding windows (inverse of unfold). Overlapping values are summed.
Input shape (leading..., prod(kernel_size), L) produces (leading..., output_size...).
Frontend guarantees: parameters are consistent with a valid unfold configuration.
Backend must: allocate and return the result tensor.
Matrix Operations
matmul ~out a b computes matrix multiplication a × b.
For 2D inputs: standard matrix multiply. For higher dimensions: batched multiply on the last two dimensions, with broadcasting via strides.
Frontend guarantees: a's last dim equals b's second-to-last dim. out is C-contiguous with the correct output shape.
Backend must: write the result to out. May use BLAS for performance. a and b may be non-contiguous.
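A naive triple loop is a useful reference when checking a backend's 2D result. The sketch below works on nested float arrays and returns a fresh matrix instead of writing into ~out; it is a hypothetical helper, not the backend signature.

```ocaml
(* Naive 2-D matrix multiply: (m x k) times (k x n) gives (m x n). *)
let matmul a b =
  let m = Array.length a and k = Array.length b in
  let n = if k = 0 then 0 else Array.length b.(0) in
  Array.init m (fun i ->
      Array.init n (fun j ->
          let acc = ref 0. in
          for p = 0 to k - 1 do
            acc := !acc +. (a.(i).(p) *. b.(p).(j))
          done;
          !acc))

let () =
  let c = matmul [| [| 1.; 2. |]; [| 3.; 4. |] |] [| [| 5.; 6. |]; [| 7.; 8. |] |] in
  Array.iter
    (fun row ->
      Array.iter (Printf.printf "%g ") row;
      print_newline ())
    c
```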
Fourier Transforms
Frontend guarantees: axes contains valid, non-negative axis indices. Input tensors have compatible complex or real dtypes.
val fft :
?out:(Stdlib.Complex.t, 'b) t ->
(Stdlib.Complex.t, 'b) t ->
axes:int array ->
(Stdlib.Complex.t, 'b) t
fft ?out t ~axes computes the forward DFT along axes.
val ifft :
?out:(Stdlib.Complex.t, 'b) t ->
(Stdlib.Complex.t, 'b) t ->
axes:int array ->
(Stdlib.Complex.t, 'b) t
ifft ?out t ~axes computes the inverse DFT along axes.
val rfft :
?out:(Stdlib.Complex.t, 'b) t ->
(float, 'a) t ->
dtype:(Stdlib.Complex.t, 'b) Dtype.t ->
axes:int array ->
(Stdlib.Complex.t, 'b) t
rfft ?out t ~dtype ~axes computes the real-input DFT along axes.
Exploits conjugate symmetry to return only the non-redundant half of the spectrum along the last transformed axis.
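Assuming the usual convention that a real input of length n along the last transformed axis yields n / 2 + 1 complex outputs (the text does not state the exact size), the output shape could be computed as in this sketch. rfft_shape is a hypothetical helper and assumes axes is non-empty.

```ocaml
(* Output shape of a real-input DFT under the assumed n / 2 + 1 convention:
   only the last transformed axis shrinks. *)
let rfft_shape shape ~axes =
  let out = Array.copy shape in
  let last = axes.(Array.length axes - 1) in
  out.(last) <- (shape.(last) / 2) + 1;
  out

let () =
  let s = rfft_shape [| 4; 8 |] ~axes:[| 0; 1 |] in
  Printf.printf "%dx%d\n" s.(0) s.(1) (* 4x5 *)
```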
val irfft :
?out:(float, 'b) t ->
?s:int array ->
(Stdlib.Complex.t, 'a) t ->
dtype:(float, 'b) Dtype.t ->
axes:int array ->
(float, 'b) t
irfft ?out ?s t ~dtype ~axes computes the inverse real-input DFT along axes.
Takes conjugate-symmetric complex input, returns real output. s specifies output sizes along the transformed axes; None infers sizes from the input.
Linear Algebra
All linalg operations support batching: the last two dimensions are the matrix dimensions, earlier dimensions are batch dimensions.
Frontend guarantees: input matrices have compatible shapes (square where required, matching dimensions for solves).
Backend must: allocate and return result tensors. Typically delegates to LAPACK.
cholesky ~upper t computes the Cholesky factorization of a positive-definite matrix. Returns L (lower) or U (upper) such that A = L·Lᵀ or A = Uᵀ·U.
qr ~reduced t returns (Q, R) where Q is orthogonal and R is upper triangular. reduced = true returns economy-size factorization.
val svd :
full_matrices:bool ->
('a, 'b) t ->
('a, 'b) t * (float, Dtype.float64_elt) t * ('a, 'b) t
svd ~full_matrices t returns (U, S, Vᴴ). S is a 1D float64 vector of singular values in descending order. full_matrices = false returns thin SVD.
val eig :
vectors:bool ->
('a, 'b) t ->
(Stdlib.Complex.t, Dtype.complex64_elt) t
* (Stdlib.Complex.t, Dtype.complex64_elt) t option
eig ~vectors t computes eigenvalues (and optionally eigenvectors) of a square matrix. Returns complex64 results.
val eigh :
vectors:bool ->
('a, 'b) t ->
(float, Dtype.float64_elt) t * ('a, 'b) t option
eigh ~vectors t computes eigenvalues (and optionally eigenvectors) of a symmetric/Hermitian matrix. Eigenvalues are float64.
val triangular_solve :
upper:bool ->
transpose:bool ->
unit_diag:bool ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t
triangular_solve ~upper ~transpose ~unit_diag a b solves A·x = b or Aᵀ·x = b where A is triangular.
upper: A is upper triangular. transpose: solve Aᵀ·x = b. unit_diag: assume diagonal is all ones.
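For the upper-triangular, non-transposed case with a single right-hand side, the solve reduces to back substitution. The sketch below uses a hypothetical solve_upper over nested float arrays to show the effect of unit_diag; a production backend would more likely call into LAPACK.

```ocaml
(* Back substitution for an upper-triangular system A x = b.
   With unit_diag, the diagonal of A is treated as all ones and never read. *)
let solve_upper ?(unit_diag = false) (a : float array array) (b : float array) =
  let n = Array.length b in
  let x = Array.copy b in
  for i = n - 1 downto 0 do
    for j = i + 1 to n - 1 do
      x.(i) <- x.(i) -. (a.(i).(j) *. x.(j))
    done;
    if not unit_diag then x.(i) <- x.(i) /. a.(i).(i)
  done;
  x

let () =
  let a = [| [| 2.; 1. |]; [| 0.; 4. |] |] in
  let x = solve_upper a [| 4.; 8. |] in
  Printf.printf "%g %g\n" x.(0) x.(1) (* 1 2 *)
```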