Module type Backend_intf.S
Backend interface for Nx tensor operations.
This module type defines the contract between Nx's frontend and its pluggable backends. Backends may execute operations eagerly (C backend), raise effects for JIT compilation (Rune), build computation graphs, or implement other execution strategies.
Design Philosophy
Operations exist at the level of C standard library functions: every operation that maps to a C stdlib call is a backend primitive, avoiding the overhead of composing multiple operations in eager mode. Rune's JIT pipeline can decompose these into lower primitives when building computation graphs.
Frontend/Backend Contract
The frontend is responsible for:
- Broadcasting inputs to matching shapes before calling binary operations.
- Promoting dtypes to compatible types before calling operations.
- Validating parameters (axes in range, shapes compatible, etc.).
- Allocating output tensors with the correct shape and dtype.
The backend can assume all inputs are well-formed. It is responsible for:
- Executing the operation correctly for all supported dtypes.
- Handling strided (non-contiguous) inputs via the view metadata.
- Returning tensors with correct view metadata.
Conventions
- Binary, unary, reduction, and other compute operations write results to a caller-provided ~out buffer for memory reuse. The frontend controls all allocation.
- Movement operations manipulate view metadata (shape, strides, offset) without copying data when possible.
- Operations that must allocate by nature (copy, contiguous, pad, scatter) return new tensor handles.
Types
'a is the OCaml element type (e.g., float, int32). 'b is a phantom type that tags the dtype for type safety.
Backend execution context.
Carries backend-specific state such as memory pools, device handles, command queues, or computation graphs.
Tensor Properties
view t returns the strided view metadata describing t's logical layout (shape, strides, offset) over its underlying buffer.
val to_host : ('a, 'b) t -> ('a, 'b) Nx_buffer.t
to_host t returns t's data as a flat, C-contiguous host buffer.
Use view to interpret the logical structure. CPU backends may return a direct reference (zero-copy); GPU backends copy from device to host.
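As a minimal sketch of how a flat host buffer can be read through view metadata, the following computes the flat position of a multi-dimensional index. The view record and both helper functions are hypothetical stand-ins written for this illustration, not the Nx view type, and strides are assumed to be counted in elements.

```ocaml
(* Hypothetical stand-in for strided view metadata; not the Nx view type.
   Strides and offset are counted in elements, not bytes. *)
type view = { shape : int array; strides : int array; offset : int }

(* Flat buffer position of the element at a multi-dimensional index. *)
let linear_index v idx =
  let pos = ref v.offset in
  Array.iteri (fun d i -> pos := !pos + (i * v.strides.(d))) idx;
  !pos

(* C-contiguous (row-major) strides for a shape. *)
let c_strides shape =
  let n = Array.length shape in
  let s = Array.make n 1 in
  for d = n - 2 downto 0 do
    s.(d) <- s.(d + 1) * shape.(d + 1)
  done;
  s

let () =
  let buf = [| 0.; 1.; 2.; 3.; 4.; 5. |] in
  let v = { shape = [| 2; 3 |]; strides = c_strides [| 2; 3 |]; offset = 0 } in
  (* Element (1, 2) of a 2x3 row-major tensor sits at flat position 5. *)
  Printf.printf "%g\n" buf.(linear_index v [| 1; 2 |])
```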
Tensor Creation
buffer ctx dtype shape allocates an uninitialized tensor.
Contents are undefined. Used internally by the frontend to pre-allocate ~out buffers before calling operations.
Backend must: return a tensor with the given shape and dtype whose view is C-contiguous.
full ctx dtype shape value creates a tensor where every element is value.
For scalars, shape is [||]. Subsumes zeros, ones, and constant fill.
Backend must: return a C-contiguous tensor of the given shape and dtype with all elements set to value.
val from_host : context -> ('a, 'b) Nx_buffer.t -> ('a, 'b) t
from_host ctx buf creates a tensor from a flat, C-contiguous host buffer.
CPU backends may share the buffer directly (zero-copy). GPU backends copy from host to device.
Frontend guarantees: buf is C-contiguous.
Element-wise Binary Operations
Frontend guarantees: out, a, and b have identical shapes (after broadcasting) and compatible dtypes (after promotion). out is C-contiguous and pre-allocated with the correct shape.
Backend must: write exactly numel elements to out, respecting the strides of a and b (which may be non-contiguous or broadcast).
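The contract above can be sketched with a toy element-wise add that writes into a C-contiguous out while reading a and b through their strides. The tensor record and the position helper are hypothetical and exist only for this illustration; a stride of 0 is assumed to mark a broadcast dimension.

```ocaml
(* Hypothetical tensor representation for illustration only. *)
type tensor = {
  data : float array;
  shape : int array;
  strides : int array; (* in elements; 0 marks a broadcast dimension *)
  offset : int;
}

let numel shape = Array.fold_left ( * ) 1 shape

(* Flat position in [t.data] of the [k]-th element in row-major logical order. *)
let position t k =
  let pos = ref t.offset and rem = ref k in
  for d = Array.length t.shape - 1 downto 0 do
    pos := !pos + ((!rem mod t.shape.(d)) * t.strides.(d));
    rem := !rem / t.shape.(d)
  done;
  !pos

(* [out] is assumed C-contiguous with offset 0, as the frontend guarantees. *)
let add ~out a b =
  for k = 0 to numel out.shape - 1 do
    out.data.(k) <- a.data.(position a k) +. b.data.(position b k)
  done

(* A contiguous 2x3 tensor plus a length-3 row broadcast over two rows. *)
let example () =
  let shape = [| 2; 3 |] in
  let a = { data = [| 1.; 2.; 3.; 4.; 5.; 6. |]; shape; strides = [| 3; 1 |]; offset = 0 } in
  let b = { data = [| 10.; 20.; 30. |]; shape; strides = [| 0; 1 |]; offset = 0 } in
  let out = { data = Array.make 6 0.; shape; strides = [| 3; 1 |]; offset = 0 } in
  add ~out a b;
  out.data

let () = Array.iter (Printf.printf "%g ") (example ())
```

The broadcast operand b is never materialized: its first stride is 0, so both logical rows read the same three buffer elements.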
Arithmetic
add ~out a b computes out.{i} <- a.{i} + b.{i}.
sub ~out a b computes out.{i} <- a.{i} - b.{i}.
mul ~out a b computes out.{i} <- a.{i} * b.{i}.
div ~out a b computes out.{i} <- a.{i} / b.{i}.
Integer dtypes use truncation toward zero (C division). Floating-point dtypes use IEEE 754 division.
mod_ ~out a b computes the remainder of a / b.
Integers use C's % operator (truncated division). Floats use fmod. The sign of the result follows the dividend a.
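OCaml's own integer operators and Float.rem follow the same truncated-division rules as C, so they can serve as a quick reference for the div and mod_ semantics described above. The wrapper names below are illustrative only.

```ocaml
(* OCaml's / and mod on integers truncate toward zero, like C's / and %. *)
let int_div a b = a / b
let int_mod a b = a mod b

(* Float.rem has the semantics of C's fmod. *)
let float_mod a b = Float.rem a b

let () =
  (* The remainder takes the sign of the dividend in every case. *)
  Printf.printf "%d %d\n" (int_div (-7) 2) (int_mod (-7) 2);
  Printf.printf "%d %d\n" (int_div 7 (-2)) (int_mod 7 (-2));
  Printf.printf "%g\n" (float_mod (-7.5) 2.)
```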
pow ~out base exponent computes out.{i} <- base.{i} ^ exponent.{i}.
atan2 ~out y x computes out.{i} <- atan2(y.{i}, x.{i}).
Returns the angle in radians in [-π, π], handling all quadrants. With C's atan2, -π is returned only when y is negative zero and x is negative (or negative zero).
Comparison
Comparison operations produce boolean tensors.
Frontend guarantees: out is a (bool, bool_elt) tensor with the same shape as a and b.
val cmpeq : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmpeq ~out a b computes out.{i} <- (a.{i} = b.{i}).
val cmpne : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmpne ~out a b computes out.{i} <- (a.{i} <> b.{i}).
val cmplt : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmplt ~out a b computes out.{i} <- (a.{i} < b.{i}).
val cmple : out:(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit
cmple ~out a b computes out.{i} <- (a.{i} <= b.{i}).
Min/Max
max ~out a b computes out.{i} <- max(a.{i}, b.{i}).
min ~out a b computes out.{i} <- min(a.{i}, b.{i}).
Bitwise
Operate on the binary representation of integer and boolean dtypes. For booleans, these are equivalent to logical AND/OR/XOR.
Element-wise Unary Operations
Frontend guarantees: out and x have the same shape and dtype. out is C-contiguous.
Backend must: write exactly numel elements to out, respecting the strides of x.
Arithmetic
sign ~out x computes the sign function: -1 for negative, 0 for zero, 1 for positive. Returns NaN for floating-point NaN inputs.
Exponential and Logarithm
Trigonometric
All inputs are in radians.
asin ~out x computes out.{i} <- arcsin(x.{i}).
Returns values in [-π/2, π/2].
acos ~out x computes out.{i} <- arccos(x.{i}).
Returns values in [0, π].
atan ~out x computes out.{i} <- arctan(x.{i}).
Returns values in [-π/2, π/2].
Hyperbolic
Rounding
For integer dtypes, all rounding operations are the identity.
round ~out x rounds to nearest integer, half away from zero (C's round).
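OCaml's Float.round uses the same half-away-from-zero rule as C's round, which differs from the half-to-even rule used by some other libraries. A short illustration (the wrapper name is only for this example):

```ocaml
(* Float.round rounds to nearest, with ties away from zero. *)
let round_half_away x = Float.round x

let () =
  List.iter
    (fun x -> Printf.printf "%g -> %g\n" x (round_half_away x))
    [ 0.5; 1.5; 2.5; -0.5; -2.5 ]
```

Under half-to-even rounding, 0.5 and 2.5 would instead round to 0 and 2.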
Special Functions
erf ~out x computes the error function erf(x) = 2/√π ∫₀ˣ e^(-t²) dt.
Ternary Operations
val where :
out:('a, 'b) t ->
(bool, Dtype.bool_elt) t ->
('a, 'b) t ->
('a, 'b) t ->
unit
where ~out cond if_true if_false selects elements: if_true.{i} where cond.{i} is true, if_false.{i} otherwise.
Frontend guarantees: all four tensors have identical shapes. cond is boolean. out, if_true, if_false share the same dtype.
Reduction Operations
Reductions aggregate values along one or more axes.
Frontend guarantees: axes contains valid, non-negative, deduplicated axis indices. out is pre-allocated with the correct shape: reduced axes are either removed or kept as size-1 dimensions depending on keepdims.
reduce_sum ~out ~axes ~keepdims x sums elements along axes.
reduce_prod ~out ~axes ~keepdims x multiplies elements along axes.
reduce_max ~out ~axes ~keepdims x finds maximum along axes.
reduce_min ~out ~axes ~keepdims x finds minimum along axes.
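The output shape the frontend pre-allocates for a reduction can be sketched as follows. The function reduced_shape is a hypothetical helper for this illustration, not part of the interface.

```ocaml
(* Shape of a reduction result: reduced axes become size 1 when [keepdims]
   is true and are removed otherwise. *)
let reduced_shape ~axes ~keepdims shape =
  let is_reduced d = Array.mem d axes in
  if keepdims then Array.mapi (fun d n -> if is_reduced d then 1 else n) shape
  else
    shape |> Array.to_list
    |> List.filteri (fun d _ -> not (is_reduced d))
    |> Array.of_list

let () =
  let show s =
    print_endline (String.concat "x" (Array.to_list (Array.map string_of_int s)))
  in
  show (reduced_shape ~axes:[| 1 |] ~keepdims:false [| 2; 3; 4 |]);
  show (reduced_shape ~axes:[| 1 |] ~keepdims:true [| 2; 3; 4 |])
```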
val argmax :
out:(int32, Dtype.int32_elt) t ->
axis:int ->
keepdims:bool ->
('a, 'b) t ->
unit
argmax ~out ~axis ~keepdims x writes int32 indices of maximum values along axis to out. For ties, returns the first occurrence.
Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.
val argmin :
out:(int32, Dtype.int32_elt) t ->
axis:int ->
keepdims:bool ->
('a, 'b) t ->
unit
argmin ~out ~axis ~keepdims x writes int32 indices of minimum values along axis to out. For ties, returns the first occurrence.
Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.
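The first-occurrence tie rule for argmax and argmin falls out of a strict comparison during a left-to-right scan. The sketch below works on a single 1D slice; the helper names are illustrative, and NaN handling, which the text above does not specify, is left out.

```ocaml
(* Keeps the earliest index because a later equal value is not strictly
   better. Assumes a non-empty input. *)
let arg_extreme better xs =
  let best = ref 0 in
  for i = 1 to Array.length xs - 1 do
    if better xs.(i) xs.(!best) then best := i
  done;
  Int32.of_int !best

let argmax xs = arg_extreme ( > ) xs
let argmin xs = arg_extreme ( < ) xs

let () =
  Printf.printf "%ld %ld\n"
    (argmax [| 1.; 3.; 3.; 2. |])
    (argmin [| 2.; 1.; 1.; 3. |])
```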
val associative_scan :
out:('a, 'b) t ->
axis:int ->
op:[ `Sum | `Prod | `Max | `Min ] ->
('a, 'b) t ->
unit
associative_scan ~out ~axis ~op x computes an inclusive prefix scan along axis. `Sum for cumulative sum, `Prod for cumulative product, `Max/`Min for running max/min.
Frontend guarantees: axis is valid and non-negative. out has the same shape as x.
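An inclusive scan means that out at position i already includes x at position i. A sequential sketch over a single 1D slice (the function names are hypothetical, and a real backend may parallelize the scan):

```ocaml
type op = [ `Sum | `Prod | `Max | `Min ]

let combine (op : op) a b =
  match op with
  | `Sum -> a +. b
  | `Prod -> a *. b
  | `Max -> Float.max a b
  | `Min -> Float.min a b

(* out.(i) combines x.(0) through x.(i), inclusive. *)
let inclusive_scan op xs =
  let out = Array.copy xs in
  for i = 1 to Array.length xs - 1 do
    out.(i) <- combine op out.(i - 1) xs.(i)
  done;
  out

let () =
  Array.iter (Printf.printf "%g ") (inclusive_scan `Sum [| 1.; 2.; 3.; 4. |])
```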
Sort Operations
Frontend guarantees: axis is valid and non-negative. out is pre-allocated with the correct shape and dtype.
sort ~out ~axis ~descending x sorts elements along axis. NaN values are placed at the end regardless of sort direction.
Frontend guarantees: out has the same shape and dtype as x.
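One way to place NaN values last in both directions is to handle NaN in the comparator before applying the direction. This is a sketch on a 1D float array; sort_nan_last is a hypothetical name, and the use of a stable sort is an assumption of this example rather than something the text above requires.

```ocaml
(* NaN compares greater than every number regardless of [descending],
   so it always ends up at the end of the result. *)
let sort_nan_last ~descending xs =
  let cmp a b =
    match (Float.is_nan a, Float.is_nan b) with
    | true, true -> 0
    | true, false -> 1
    | false, true -> -1
    | false, false ->
        if descending then Float.compare b a else Float.compare a b
  in
  let out = Array.copy xs in
  Array.stable_sort cmp out;
  out

let () =
  Array.iter (Printf.printf "%g ")
    (sort_nan_last ~descending:true [| 2.; nan; 3.; 1. |])
```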
val argsort :
out:(int32, Dtype.int32_elt) t ->
axis:int ->
descending:bool ->
('a, 'b) t ->
unit
argsort ~out ~axis ~descending x writes int32 indices that would sort elements along axis to out.
Frontend guarantees: out has the same shape as x with int32 dtype.
Movement Operations
Movement operations manipulate view metadata (shape, strides, offset) without copying data when possible. They return new tensor handles sharing the underlying buffer.
Frontend guarantees: all parameters are validated (axes in range, shapes compatible, bounds within limits).
Backend must: return a tensor with the correct view metadata. May share the underlying buffer (zero-copy) or allocate if necessary.
expand t shape broadcasts dimensions of size 1 to match shape by setting their stride to 0. Non-singleton dimensions must already match. Zero-copy.
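The stride-0 trick behind expand can be sketched on a bare view record. The record and function below are hypothetical stand-ins, and equal rank between the view and the target shape is assumed.

```ocaml
(* Hypothetical stand-in for view metadata; strides are in elements. *)
type view = { shape : int array; strides : int array; offset : int }

(* Size-1 dimensions that grow get stride 0, so every index along them
   reads the same buffer element. Other strides are kept as they are. *)
let expand v target =
  let strides =
    Array.mapi
      (fun d n -> if v.shape.(d) = 1 && n <> 1 then 0 else v.strides.(d))
      target
  in
  { v with shape = Array.copy target; strides }

let () =
  let row = { shape = [| 1; 3 |]; strides = [| 3; 1 |]; offset = 0 } in
  let e = expand row [| 4; 3 |] in
  Printf.printf "strides: %d %d\n" e.strides.(0) e.strides.(1)
```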
reshape t shape changes the logical shape, preserving element count.
Zero-copy when t is C-contiguous or the new shape is compatible with the current strides. Otherwise the backend may copy.
permute t axes reorders dimensions according to axes, which must be a permutation of [0, ..., ndim-1]. Zero-copy.
shrink t ranges extracts a contiguous slice. ranges.(i) is (start, stop) with exclusive stop. Zero-copy (adjusts offset and shape).
flip t axes reverses dimensions where axes.(i) = true by negating strides. Zero-copy.
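The zero-copy movement operations permute, shrink, and flip can each be expressed as a pure update of view metadata. The following is a sketch on a hypothetical view record, not the Nx implementation; strides and offset are assumed to be in elements.

```ocaml
type view = { shape : int array; strides : int array; offset : int }

(* Reorder shape and strides together; the buffer is untouched. *)
let permute v axes =
  { v with
    shape = Array.map (fun a -> v.shape.(a)) axes;
    strides = Array.map (fun a -> v.strides.(a)) axes }

(* Advance the offset to each range start and shorten each dimension. *)
let shrink v ranges =
  let offset = ref v.offset in
  Array.iteri
    (fun d (start, _) -> offset := !offset + (start * v.strides.(d)))
    ranges;
  { v with
    shape = Array.map (fun (start, stop) -> stop - start) ranges;
    offset = !offset }

(* Start at the last element of each flipped dimension and walk backwards. *)
let flip v axes =
  let offset = ref v.offset in
  let strides =
    Array.mapi
      (fun d s ->
        if axes.(d) then begin
          offset := !offset + ((v.shape.(d) - 1) * s);
          -s
        end
        else s)
      v.strides
  in
  { v with strides; offset = !offset }

let () =
  let base = { shape = [| 2; 3 |]; strides = [| 3; 1 |]; offset = 0 } in
  let f = flip base [| false; true |] in
  Printf.printf "offset %d, strides %d %d\n" f.offset f.strides.(0) f.strides.(1)
```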
pad t padding fill_value extends t with fill_value. padding.(i) is (before, after) for dimension i.
Backend must: allocate a new buffer and copy data.
cat ~out tensors ~axis concatenates tensors along axis into out.
Frontend guarantees: all tensors have the same shape except along axis. axis is valid. The list is non-empty. out is pre-allocated with the correct concatenated shape.
Type Conversion and Memory
cast ~out x converts elements of x to the dtype of out.
Float-to-int truncates toward zero. Int-to-float may lose precision for large values.
Frontend guarantees: out is pre-allocated with the correct shape and target dtype.
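Both conversion caveats can be observed with OCaml's own conversions, which are used here only as an analogy for the cast semantics described above. The helper names are illustrative.

```ocaml
(* int_of_float truncates toward zero, matching the float-to-int rule. *)
let float_to_int x = int_of_float x

(* A 64-bit float has a 53-bit significand, so integers beyond 2^53 may
   not survive a round trip through float. *)
let int64_roundtrips n = Int64.of_float (Int64.to_float n) = n

let () =
  Printf.printf "%d %d\n" (float_to_int 2.9) (float_to_int (-2.9));
  let big = Int64.shift_left 1L 53 in
  Printf.printf "%b %b\n" (int64_roundtrips big)
    (int64_roundtrips (Int64.add big 1L))
```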
contiguous t returns a C-contiguous version of t.
May return t unchanged if already C-contiguous. Otherwise allocates and copies.
Backend must: return a C-contiguous tensor with the same data.
copy t creates an independent copy with its own buffer.
Backend must: always allocate a new buffer, even if t is already contiguous.
assign dst src copies elements from src into dst in-place.
Frontend guarantees: dst and src have matching shapes and dtypes.
Backend must: write src's data into dst's buffer, respecting both tensors' strides.
Random Number Generation
val threefry :
out:(int32, Dtype.int32_elt) t ->
(int32, Dtype.int32_elt) t ->
(int32, Dtype.int32_elt) t ->
unit
threefry ~out key counter applies the Threefry-2x32 hash function.
Frontend guarantees: key and counter are int32 tensors with compatible shapes. out is pre-allocated with the same shape as counter.
Indexed Access Operations
val gather :
out:('a, 'b) t ->
('a, 'b) t ->
(int32, Dtype.int32_elt) t ->
axis:int ->
unit
gather ~out data indices ~axis selects elements from data along axis using indices and writes them to out.
Frontend guarantees: rank data = rank indices. axis is valid. Index values are in range for data's size along axis. out has the same shape as indices and the same dtype as data.
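The shape guarantees above are consistent with take-along-axis semantics, which is assumed (not confirmed by the text) in the following 2D sketch for axis = 1: out.(i).(j) = data.(i).(indices.(i).(j)). Nested OCaml arrays stand in for tensors and gather_axis1 is a hypothetical name.

```ocaml
(* Assumed semantics for axis = 1 on 2D inputs:
   out.(i).(j) = data.(i).(indices.(i).(j)).
   The result has the shape of [indices] and the element type of [data]. *)
let gather_axis1 data indices =
  Array.mapi
    (fun i row -> Array.map (fun k -> data.(i).(Int32.to_int k)) row)
    indices

let () =
  let data = [| [| 10; 20; 30 |]; [| 40; 50; 60 |] |] in
  let indices = [| [| 2l; 0l |]; [| 1l; 1l |] |] in
  Array.iter
    (fun row ->
      Array.iter (Printf.printf "%d ") row;
      print_newline ())
    (gather_axis1 data indices)
```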
val scatter :
?mode:[ `Set | `Add ] ->
?unique_indices:bool ->
('a, 'b) t ->
indices:(int32, Dtype.int32_elt) t ->
updates:('a, 'b) t ->
axis:int ->
('a, 'b) t
scatter ?mode ?unique_indices template ~indices ~updates ~axis places updates into a tensor shaped like template along axis.
`Set (default) uses the last update for duplicate indices. `Add accumulates. unique_indices = true hints that indices are unique.
Frontend guarantees: rank indices = rank updates. axis is valid. template has the desired output shape.
Backend must: allocate and return the result tensor, initialized from template's data.
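The difference between the two modes shows up only with duplicate indices. A 1D sketch on plain float arrays, with the axis and unique_indices parameters left out for brevity:

```ocaml
(* The result starts as a copy of [template]. With `Set the last update
   for an index wins; with `Add updates accumulate onto the existing value. *)
let scatter ?(mode = `Set) template ~indices ~updates =
  let out = Array.copy template in
  Array.iteri
    (fun i k ->
      let k = Int32.to_int k in
      match mode with
      | `Set -> out.(k) <- updates.(i)
      | `Add -> out.(k) <- out.(k) +. updates.(i))
    indices;
  out

let () =
  let template = [| 0.; 0.; 0. |] in
  let indices = [| 1l; 1l; 2l |] and updates = [| 5.; 7.; 9. |] in
  Array.iter (Printf.printf "%g ") (scatter template ~indices ~updates);
  print_newline ();
  Array.iter (Printf.printf "%g ") (scatter ~mode:`Add template ~indices ~updates)
```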
Window Operations
Sliding-window extraction and its inverse. Used to implement convolution as unfold + reshape + matmul and pooling as unfold + reduce.
val unfold :
('a, 'b) t ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
unfold t ~kernel_size ~stride ~dilation ~padding extracts sliding windows from the last K spatial dimensions, where K = Array.length kernel_size.
Input shape (leading..., spatial...) produces (leading..., prod(kernel_size), L) where L is the number of windows. All dimensions before the last K are preserved as-is.
Frontend guarantees: all array parameters have length K. Values are positive. Input has at least K dimensions.
Backend must: allocate and return the result tensor.
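The text above does not spell out how L is computed. The sketch below assumes the conventional sliding-window formula, floor((size + pad_before + pad_after - dilation * (kernel - 1) - 1) / stride) + 1 per spatial dimension, with L the product over the K dimensions. Both helper functions are hypothetical.

```ocaml
(* Number of window positions along one spatial dimension, under the
   conventional formula assumed in this example. *)
let num_windows ~size ~kernel ~stride ~dilation ~padding:(before, after) =
  let effective_kernel = (dilation * (kernel - 1)) + 1 in
  ((size + before + after - effective_kernel) / stride) + 1

(* Output shape (leading..., prod(kernel_size), L). *)
let unfold_shape shape ~kernel_size ~stride ~dilation ~padding =
  let k = Array.length kernel_size in
  let rank = Array.length shape in
  let windows = ref 1 in
  for i = 0 to k - 1 do
    windows :=
      !windows
      * num_windows
          ~size:(shape.(rank - k + i))
          ~kernel:(kernel_size.(i))
          ~stride:(stride.(i))
          ~dilation:(dilation.(i))
          ~padding:(padding.(i))
  done;
  Array.concat
    [ Array.sub shape 0 (rank - k);
      [| Array.fold_left ( * ) 1 kernel_size; !windows |] ]

let () =
  let s =
    unfold_shape [| 1; 3; 5; 5 |] ~kernel_size:[| 3; 3 |] ~stride:[| 1; 1 |]
      ~dilation:[| 1; 1 |]
      ~padding:[| (0, 0); (0, 0) |]
  in
  Array.iter (Printf.printf "%d ") s
```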
val fold :
('a, 'b) t ->
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
fold t ~output_size ~kernel_size ~stride ~dilation ~padding combines sliding windows (inverse of unfold). Overlapping values are summed.
Input shape (leading..., prod(kernel_size), L) produces (leading..., output_size...).
Frontend guarantees: parameters are consistent with a valid unfold configuration.
Backend must: allocate and return the result tensor.
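The summation of overlapping values can be seen in a 1D sketch. To keep it short, the example assumes dilation 1 and no padding, and stores windows as a nested array in the (prod(kernel_size), L) layout; fold_1d is a hypothetical name.

```ocaml
(* windows.(k).(w) is element k of window w. Window w covers output
   positions w * stride through w * stride + kernel - 1, and positions
   shared by several windows accumulate their contributions. *)
let fold_1d windows ~output_size ~stride =
  let out = Array.make output_size 0. in
  Array.iteri
    (fun k row ->
      Array.iteri
        (fun w x ->
          let pos = (w * stride) + k in
          out.(pos) <- out.(pos) +. x)
        row)
    windows;
  out

let () =
  (* Three windows of size 2 with stride 1 over an output of length 4. *)
  let windows = [| [| 1.; 1.; 1. |]; [| 1.; 1.; 1. |] |] in
  Array.iter (Printf.printf "%g ") (fold_1d windows ~output_size:4 ~stride:1)
```

The two interior positions are each covered by two windows, so they receive a sum of 2 while the edges receive 1.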
Matrix Operations
matmul ~out a b computes matrix multiplication a × b.
For 2D inputs: standard matrix multiply. For higher dimensions: batched multiply on the last two dimensions, with broadcasting via strides.
Frontend guarantees: a's last dim equals b's second-to-last dim. out is C-contiguous with the correct output shape.
Backend must: write the result to out. May use BLAS for performance. a and b may be non-contiguous.
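For reference, the 2D case reduces to the textbook triple loop. The sketch below is a naive illustration on flat row-major float arrays with explicit dimensions, writing into a caller-provided out; it does not handle batching or strided inputs and is not how a BLAS-backed backend would implement it.

```ocaml
(* out (m x n) <- a (m x k) times b (k x n), all flat and row-major. *)
let matmul ~out ~m ~k ~n a b =
  for i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let acc = ref 0. in
      for p = 0 to k - 1 do
        acc := !acc +. (a.((i * k) + p) *. b.((p * n) + j))
      done;
      out.((i * n) + j) <- !acc
    done
  done

let () =
  let out = Array.make 4 0. in
  matmul ~out ~m:2 ~k:2 ~n:2 [| 1.; 2.; 3.; 4. |] [| 5.; 6.; 7.; 8. |];
  Array.iter (Printf.printf "%g ") out
```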
Fourier Transforms
Frontend guarantees: axes contains valid, non-negative axis indices. Input tensors have compatible complex or real dtypes.
val fft :
?out:(Stdlib.Complex.t, 'b) t ->
(Stdlib.Complex.t, 'b) t ->
axes:int array ->
(Stdlib.Complex.t, 'b) t
fft ?out t ~axes computes the forward DFT along axes.
val ifft :
?out:(Stdlib.Complex.t, 'b) t ->
(Stdlib.Complex.t, 'b) t ->
axes:int array ->
(Stdlib.Complex.t, 'b) t
ifft ?out t ~axes computes the inverse DFT along axes.
val rfft :
?out:(Stdlib.Complex.t, 'b) t ->
(float, 'a) t ->
dtype:(Stdlib.Complex.t, 'b) Dtype.t ->
axes:int array ->
(Stdlib.Complex.t, 'b) t
rfft ?out t ~dtype ~axes computes the real-input DFT along axes.
Exploits conjugate symmetry to return only the non-redundant half of the spectrum along the last transformed axis.
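The conjugate symmetry can be checked with a naive O(n²) DFT: for real input of length n, bin n - k is the complex conjugate of bin k, so only the first n / 2 + 1 bins carry independent information. The sketch assumes the usual negative-exponent convention for the forward transform; dft and rdft are hypothetical names, not the backend functions.

```ocaml
(* Naive forward DFT of a real sequence. *)
let dft xs =
  let n = Array.length xs in
  Array.init n (fun k ->
      let acc = ref Complex.zero in
      Array.iteri
        (fun t x ->
          let angle =
            -2. *. Float.pi *. float_of_int (k * t) /. float_of_int n
          in
          let term = Complex.mul { Complex.re = x; im = 0. } (Complex.polar 1. angle) in
          acc := Complex.add !acc term)
        xs;
      !acc)

(* Keep only the non-redundant half: n / 2 + 1 bins. *)
let rdft xs = Array.sub (dft xs) 0 ((Array.length xs / 2) + 1)

let () =
  Array.iter
    (fun (c : Complex.t) -> Printf.printf "(%.3f, %.3f) " c.re c.im)
    (rdft [| 1.; 2.; 3.; 4. |])
```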
val irfft :
?out:(float, 'b) t ->
?s:int array ->
(Stdlib.Complex.t, 'a) t ->
dtype:(float, 'b) Dtype.t ->
axes:int array ->
(float, 'b) t
irfft ?out ?s t ~dtype ~axes computes the inverse real-input DFT along axes.
Takes conjugate-symmetric complex input, returns real output. s specifies output sizes along the transformed axes; None infers sizes from the input.
Linear Algebra
All linalg operations support batching: the last two dimensions are the matrix dimensions, earlier dimensions are batch dimensions.
Frontend guarantees: input matrices have compatible shapes (square where required, matching dimensions for solves).
Backend must: allocate and return result tensors. Typically delegates to LAPACK.
cholesky ~upper t computes the Cholesky factorization of a positive-definite matrix. Returns L (lower) or U (upper) such that A = L·Lᵀ or A = Uᵀ·U.
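For orientation, the lower-triangular case for a single real matrix can be written as the textbook Cholesky-Banachiewicz loop. This is an unbatched sketch on nested float arrays with no check for positive-definiteness; a real backend would typically call LAPACK instead.

```ocaml
(* Returns lower-triangular L with A = L * transpose L, for a symmetric
   positive-definite A given as nested arrays. *)
let cholesky_lower a =
  let n = Array.length a in
  let l = Array.make_matrix n n 0. in
  for i = 0 to n - 1 do
    for j = 0 to i do
      let s = ref a.(i).(j) in
      for k = 0 to j - 1 do
        s := !s -. (l.(i).(k) *. l.(j).(k))
      done;
      l.(i).(j) <- (if i = j then sqrt !s else !s /. l.(j).(j))
    done
  done;
  l

let () =
  let a =
    [| [| 4.; 12.; -16. |]; [| 12.; 37.; -43. |]; [| -16.; -43.; 98. |] |]
  in
  Array.iter
    (fun row ->
      Array.iter (Printf.printf "%g ") row;
      print_newline ())
    (cholesky_lower a)
```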
qr ~reduced t returns (Q, R) where Q is orthogonal and R is upper triangular. reduced = true returns economy-size factorization.
val svd :
full_matrices:bool ->
('a, 'b) t ->
('a, 'b) t * (float, Dtype.float64_elt) t * ('a, 'b) t
svd ~full_matrices t returns (U, S, Vᴴ). S is a 1D float64 vector of singular values in descending order. full_matrices = false returns thin SVD.
val eig :
vectors:bool ->
('a, 'b) t ->
(Stdlib.Complex.t, Dtype.complex64_elt) t
* (Stdlib.Complex.t, Dtype.complex64_elt) t option
eig ~vectors t computes eigenvalues (and optionally eigenvectors) of a square matrix. Returns complex64 results.
val eigh :
vectors:bool ->
('a, 'b) t ->
(float, Dtype.float64_elt) t * ('a, 'b) t option
eigh ~vectors t computes eigenvalues (and optionally eigenvectors) of a symmetric/Hermitian matrix. Eigenvalues are float64.
val triangular_solve :
upper:bool ->
transpose:bool ->
unit_diag:bool ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t
triangular_solve ~upper ~transpose ~unit_diag a b solves A·x = b or Aᵀ·x = b where A is triangular.
upper: A is upper triangular. transpose: solve Aᵀ·x = b. unit_diag: assume diagonal is all ones.
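The upper and unit_diag flags correspond to the direction of substitution and to skipping the division by the diagonal. The sketch below solves A·x = b for a single matrix and a single right-hand-side vector on nested float arrays; the transpose flag and batching are left out, and the function is illustrative rather than the backend's implementation.

```ocaml
(* Forward substitution for lower-triangular A, back substitution for
   upper-triangular A. With [unit_diag], the diagonal is never read. *)
let triangular_solve ~upper ~unit_diag a b =
  let n = Array.length b in
  let x = Array.copy b in
  let solve_row i lo hi =
    for j = lo to hi do
      x.(i) <- x.(i) -. (a.(i).(j) *. x.(j))
    done;
    if not unit_diag then x.(i) <- x.(i) /. a.(i).(i)
  in
  if upper then
    for i = n - 1 downto 0 do
      solve_row i (i + 1) (n - 1)
    done
  else
    for i = 0 to n - 1 do
      solve_row i 0 (i - 1)
    done;
  x

let () =
  let l = [| [| 2.; 0. |]; [| 1.; 3. |] |] in
  Array.iter (Printf.printf "%g ")
    (triangular_solve ~upper:false ~unit_diag:false l [| 4.; 11. |])
```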