Module Nx_backend

include Nx_core.Backend_intf.S

Types

type ('a, 'b) t

'a is the OCaml element type (e.g., float, int32). 'b is a phantom type that tags the dtype for type safety.

type context

Backend execution context.

Carries backend-specific state such as memory pools, device handles, command queues, or computation graphs.

Tensor Properties

val view : ('a, 'b) t -> Nx_core.View.t

view t returns the strided view metadata describing t's logical layout (shape, strides, offset) over its underlying buffer.
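As a rough illustration of how strided metadata maps a logical index to a buffer position, the sketch below uses a hypothetical view record (not the real Nx_core.View.t) and assumes strides are counted in elements rather than bytes.

```ocaml
(* Hypothetical stand-in for view metadata; not the real Nx_core.View.t. *)
type view = { shape : int array; strides : int array; offset : int }

(* Strides of a C-contiguous (row-major) layout, counted in elements. *)
let c_contiguous_strides shape =
  let n = Array.length shape in
  let strides = Array.make n 1 in
  for i = n - 2 downto 0 do
    strides.(i) <- strides.(i + 1) * shape.(i + 1)
  done;
  strides

(* Buffer position of the element at multi-index [idx]. *)
let linear_index v idx =
  let pos = ref v.offset in
  Array.iteri (fun d i -> pos := !pos + (i * v.strides.(d))) idx;
  !pos
```

With shape [|2; 3; 4|], the contiguous strides are [|12; 4; 1|], so index [|1; 2; 3|] lands at position 23 of the flat buffer.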

val dtype : ('a, 'b) t -> ('a, 'b) Nx_core.Dtype.t

dtype t returns the element type of t.

val context : ('a, 'b) t -> context

context t returns the execution context that owns t.

val to_host : ('a, 'b) t -> ('a, 'b) Nx_buffer.t

to_host t returns t's data as a flat, C-contiguous host buffer.

Use view to interpret the logical structure. CPU backends may return a direct reference (zero-copy); GPU backends copy from device to host.

Tensor Creation

val buffer : context -> ('a, 'b) Nx_core.Dtype.t -> int array -> ('a, 'b) t

buffer ctx dtype shape allocates an uninitialized tensor.

Contents are undefined. Used internally by the frontend to pre-allocate ~out buffers before calling operations.

Backend must: return a tensor with the given shape and dtype whose view is C-contiguous.

val full : context -> ('a, 'b) Nx_core.Dtype.t -> int array -> 'a -> ('a, 'b) t

full ctx dtype shape value creates a tensor where every element is value.

For scalars, shape is [||]. Subsumes zeros, ones, and constant fill.

Backend must: return a C-contiguous tensor of the given shape and dtype with all elements set to value.

val from_host : context -> ('a, 'b) Nx_buffer.t -> ('a, 'b) t

from_host ctx buf creates a tensor from a flat, C-contiguous host buffer.

CPU backends may share the buffer directly (zero-copy). GPU backends copy from host to device.

Frontend guarantees: buf is C-contiguous.

Element-wise Binary Operations

Frontend guarantees: out, a, and b have identical shapes (after broadcasting) and compatible dtypes (after promotion). out is C-contiguous and pre-allocated with the correct shape.

Backend must: write exactly numel elements to out, respecting the strides of a and b (which may be non-contiguous or broadcast).

Arithmetic

val add : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

add ~out a b computes out.{i} <- a.{i} + b.{i}.

val sub : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

sub ~out a b computes out.{i} <- a.{i} - b.{i}.

val mul : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

mul ~out a b computes out.{i} <- a.{i} * b.{i}.

val div : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

div ~out a b computes out.{i} <- a.{i} / b.{i}.

Integer dtypes use truncation toward zero (C division). Floating-point dtypes use IEEE 754 division.

val mod_ : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

mod_ ~out a b computes the remainder of a / b.

Integers use C's % operator (truncated division). Floats use fmod. The sign of the result follows the dividend a.
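OCaml's own integer / and mod, and Float.rem, follow the same truncated-division convention, so they make a convenient reference for the semantics of div and mod_. The names below are illustrative only.

```ocaml
(* OCaml's integer / and mod truncate toward zero, like C's / and %. *)
let int_quot = (-7) / 2
let int_rem = (-7) mod 2

(* Float.rem corresponds to C's fmod. *)
let float_rem = Float.rem (-7.5) 2.0

(* For truncated division, a = (a / b) * b + (a mod b) whenever b <> 0. *)
let identity_holds a b = a = (a / b * b) + (a mod b)
```

Here int_quot is -3 (not -4), int_rem is -1, and float_rem is -1.5: in each case the remainder takes the sign of the dividend.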

val pow : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

pow ~out base exponent computes out.{i} <- base.{i} ^ exponent.{i}.

val atan2 : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

atan2 ~out y x computes out.{i} <- atan2(y.{i}, x.{i}).

Returns the angle in radians in (-π, π], handling all quadrants.

Comparison

Comparison operations produce boolean tensors.

Frontend guarantees: out is a (bool, bool_elt) tensor with the same shape as a and b.

val cmpeq : out:(bool, Nx_core.Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit

cmpeq ~out a b computes out.{i} <- (a.{i} = b.{i}).

val cmpne : out:(bool, Nx_core.Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit

cmpne ~out a b computes out.{i} <- (a.{i} <> b.{i}).

val cmplt : out:(bool, Nx_core.Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit

cmplt ~out a b computes out.{i} <- (a.{i} < b.{i}).

val cmple : out:(bool, Nx_core.Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit

cmple ~out a b computes out.{i} <- (a.{i} <= b.{i}).

Min/Max

val max : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

max ~out a b computes out.{i} <- max(a.{i}, b.{i}).

val min : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

min ~out a b computes out.{i} <- min(a.{i}, b.{i}).

Bitwise

These operations act on the binary representation of integer and boolean dtypes. For booleans, they are equivalent to logical XOR/OR/AND.

val xor : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

xor ~out a b computes bitwise XOR.

val or_ : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

or_ ~out a b computes bitwise OR.

val and_ : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

and_ ~out a b computes bitwise AND.

Element-wise Unary Operations

Frontend guarantees: out and x have the same shape and dtype. out is C-contiguous.

Backend must: write exactly numel elements to out, respecting the strides of x.

Arithmetic

val neg : out:('a, 'b) t -> ('a, 'b) t -> unit

neg ~out x computes out.{i} <- -x.{i}.

val recip : out:('a, 'b) t -> ('a, 'b) t -> unit

recip ~out x computes out.{i} <- 1 / x.{i}.

val abs : out:('a, 'b) t -> ('a, 'b) t -> unit

abs ~out x computes out.{i} <- |x.{i}|.

val sqrt : out:('a, 'b) t -> ('a, 'b) t -> unit

sqrt ~out x computes out.{i} <- √(x.{i}).

val sign : out:('a, 'b) t -> ('a, 'b) t -> unit

sign ~out x computes the sign function: -1 for negative, 0 for zero, 1 for positive. Returns NaN for floating-point NaN inputs.
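A scalar sketch of the floating-point case, using a hypothetical helper sign_float. The treatment of -0.0 (mapped to 0.0 here) is an assumption; the description above does not specify it.

```ocaml
(* Sign of a float: -1, 0, or 1, with NaN passed through unchanged. *)
let sign_float x =
  if Float.is_nan x then Float.nan
  else if x > 0.0 then 1.0
  else if x < 0.0 then -1.0
  else 0.0
```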

Exponential and Logarithm

val exp : out:('a, 'b) t -> ('a, 'b) t -> unit

exp ~out x computes out.{i} <- e^(x.{i}).

val log : out:('a, 'b) t -> ('a, 'b) t -> unit

log ~out x computes out.{i} <- ln(x.{i}).

Trigonometric

All inputs are in radians.

val sin : out:('a, 'b) t -> ('a, 'b) t -> unit

sin ~out x computes out.{i} <- sin(x.{i}).

val cos : out:('a, 'b) t -> ('a, 'b) t -> unit

cos ~out x computes out.{i} <- cos(x.{i}).

val tan : out:('a, 'b) t -> ('a, 'b) t -> unit

tan ~out x computes out.{i} <- tan(x.{i}).

val asin : out:('a, 'b) t -> ('a, 'b) t -> unit

asin ~out x computes out.{i} <- arcsin(x.{i}).

Returns values in [-π/2, π/2].

val acos : out:('a, 'b) t -> ('a, 'b) t -> unit

acos ~out x computes out.{i} <- arccos(x.{i}).

Returns values in [0, π].

val atan : out:('a, 'b) t -> ('a, 'b) t -> unit

atan ~out x computes out.{i} <- arctan(x.{i}).

Returns values in [-π/2, π/2].

Hyperbolic

val sinh : out:('a, 'b) t -> ('a, 'b) t -> unit

sinh ~out x computes out.{i} <- sinh(x.{i}).

val cosh : out:('a, 'b) t -> ('a, 'b) t -> unit

cosh ~out x computes out.{i} <- cosh(x.{i}).

val tanh : out:('a, 'b) t -> ('a, 'b) t -> unit

tanh ~out x computes out.{i} <- tanh(x.{i}).

Rounding

For integer dtypes, all rounding operations are the identity.

val trunc : out:('a, 'b) t -> ('a, 'b) t -> unit

trunc ~out x rounds toward zero.

val ceil : out:('a, 'b) t -> ('a, 'b) t -> unit

ceil ~out x rounds toward positive infinity.

val floor : out:('a, 'b) t -> ('a, 'b) t -> unit

floor ~out x rounds toward negative infinity.

val round : out:('a, 'b) t -> ('a, 'b) t -> unit

round ~out x rounds to nearest integer, half away from zero (C's round).
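The four modes differ mainly on negative values and on exact halves. The standard library's Float functions use the same conventions, so they can serve as a scalar reference; the sample values are arbitrary.

```ocaml
let xs = [| -2.5; -1.2; 0.5; 1.5; 2.7 |]

let truncated = Array.map Float.trunc xs (* toward zero *)
let ceiled = Array.map Float.ceil xs     (* toward +infinity *)
let floored = Array.map Float.floor xs   (* toward -infinity *)
let rounded = Array.map Float.round xs   (* nearest, halves away from zero *)
```

Note that -2.5 rounds to -3 and 0.5 rounds to 1 under round, whereas round-half-to-even would give -2 and 0.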

Special Functions

val erf : out:('a, 'b) t -> ('a, 'b) t -> unit

erf ~out x computes the error function erf(x) = 2/√π ∫₀ˣ e^(-t²) dt.

Ternary Operations

val where : out:('a, 'b) t -> (bool, Nx_core.Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> unit

where ~out cond if_true if_false selects elements: if_true.{i} where cond.{i} is true, if_false.{i} otherwise.

Frontend guarantees: all four tensors have identical shapes. cond is boolean. out, if_true, if_false share the same dtype.
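On flat arrays, the selection can be sketched as follows. The helper select is hypothetical and ignores strides.

```ocaml
(* Element-wise selection: take from [if_true] where [cond] holds,
   from [if_false] elsewhere. All three arrays have the same length. *)
let select cond if_true if_false =
  Array.init (Array.length cond) (fun i ->
      if cond.(i) then if_true.(i) else if_false.(i))
```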

Reduction Operations

Reductions aggregate values along one or more axes.

Frontend guarantees: axes contains valid, non-negative, deduplicated axis indices. out is pre-allocated with the correct shape: reduced axes are either removed or kept as size-1 dimensions depending on keepdims.
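The shape rule for out can be sketched with a small helper. reduced_shape is a hypothetical name, not part of this interface.

```ocaml
(* Shape of the reduction output: reduced axes are dropped, or kept
   as size-1 dimensions when [keepdims] is true. *)
let reduced_shape ~axes ~keepdims shape =
  shape |> Array.to_list
  |> List.mapi (fun d n -> (d, n))
  |> List.filter_map (fun (d, n) ->
         if not (Array.mem d axes) then Some n
         else if keepdims then Some 1
         else None)
  |> Array.of_list
```

Reducing a [|2; 3; 4|] tensor over axis 1 gives [|2; 4|], or [|2; 1; 4|] with keepdims.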

val reduce_sum : out:('a, 'b) t -> axes:int array -> keepdims:bool -> ('a, 'b) t -> unit

reduce_sum ~out ~axes ~keepdims x sums elements along axes.

val reduce_prod : out:('a, 'b) t -> axes:int array -> keepdims:bool -> ('a, 'b) t -> unit

reduce_prod ~out ~axes ~keepdims x multiplies elements along axes.

val reduce_max : out:('a, 'b) t -> axes:int array -> keepdims:bool -> ('a, 'b) t -> unit

reduce_max ~out ~axes ~keepdims x finds maximum along axes.

val reduce_min : out:('a, 'b) t -> axes:int array -> keepdims:bool -> ('a, 'b) t -> unit

reduce_min ~out ~axes ~keepdims x finds minimum along axes.

val argmax : out:(int32, Nx_core.Dtype.int32_elt) t -> axis:int -> keepdims:bool -> ('a, 'b) t -> unit

argmax ~out ~axis ~keepdims x writes int32 indices of maximum values along axis to out. For ties, returns the first occurrence.

Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.
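The tie-breaking rule falls out of using a strict comparison while scanning left to right. A one-dimensional sketch, assuming a non-empty input (argmax_1d is a hypothetical helper):

```ocaml
(* Index of the maximum of a non-empty array, as an int32. *)
let argmax_1d xs =
  let best = ref 0 in
  for i = 1 to Array.length xs - 1 do
    (* Strict comparison keeps the earliest index on ties. *)
    if xs.(i) > xs.(!best) then best := i
  done;
  Int32.of_int !best
```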

val argmin : out:(int32, Nx_core.Dtype.int32_elt) t -> axis:int -> keepdims:bool -> ('a, 'b) t -> unit

argmin ~out ~axis ~keepdims x writes int32 indices of minimum values along axis to out. For ties, returns the first occurrence.

Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.

val associative_scan : out:('a, 'b) t -> axis:int -> op:[ `Sum | `Prod | `Max | `Min ] -> ('a, 'b) t -> unit

associative_scan ~out ~axis ~op x computes an inclusive prefix scan along axis. `Sum for cumulative sum, `Prod for cumulative product, `Max/`Min for running max/min.

Frontend guarantees: axis is valid and non-negative. out has the same shape as x.
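An inclusive scan keeps the current element in each prefix, so out.{0} equals x.{0}. A one-dimensional sketch with hypothetical helper names:

```ocaml
(* Inclusive prefix scan: out.(i) = op (... (op xs.(0) xs.(1)) ...) xs.(i). *)
let inclusive_scan op xs =
  let out = Array.copy xs in
  for i = 1 to Array.length xs - 1 do
    out.(i) <- op out.(i - 1) xs.(i)
  done;
  out

let cumsum xs = inclusive_scan ( + ) xs
let cumprod xs = inclusive_scan ( * ) xs
let running_max xs = inclusive_scan max xs
```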

Sort Operations

Frontend guarantees: axis is valid and non-negative. out is pre-allocated with the correct shape and dtype.

val sort : out:('a, 'b) t -> axis:int -> descending:bool -> ('a, 'b) t -> unit

sort ~out ~axis ~descending x sorts elements along axis. NaN values are placed at the end regardless of sort direction.

Frontend guarantees: out has the same shape and dtype as x.
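One way to get NaN-last ordering in both directions is to handle NaN in the comparator before applying the direction. This is a sketch of the ordering only, not the backend's sorting algorithm; both helpers are hypothetical.

```ocaml
(* Order floats so that NaN sorts after every other value,
   whether the remaining values are ascending or descending. *)
let compare_nan_last ~descending a b =
  match (Float.is_nan a, Float.is_nan b) with
  | true, true -> 0
  | true, false -> 1
  | false, true -> -1
  | false, false ->
      if descending then Float.compare b a else Float.compare a b

let sort_1d ~descending xs =
  let out = Array.copy xs in
  Array.stable_sort (compare_nan_last ~descending) out;
  out
```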

val argsort : out:(int32, Nx_core.Dtype.int32_elt) t -> axis:int -> descending:bool -> ('a, 'b) t -> unit

argsort ~out ~axis ~descending x writes int32 indices that would sort elements along axis to out.

Frontend guarantees: out has the same shape as x with int32 dtype.

Movement Operations

Movement operations manipulate view metadata (shape, strides, offset) without copying data when possible. They return new tensor handles sharing the underlying buffer.

Frontend guarantees: all parameters are validated (axes in range, shapes compatible, bounds within limits).

Backend must: return a tensor with the correct view metadata. May share the underlying buffer (zero-copy) or allocate if necessary.

val expand : ('a, 'b) t -> int array -> ('a, 'b) t

expand t shape broadcasts dimensions of size 1 to match shape by setting their stride to 0. Non-singleton dimensions must already match. Zero-copy.
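The stride-0 trick can be sketched on raw stride arrays. expand_strides is a hypothetical helper that assumes shape and target have the same rank and that strides are in elements.

```ocaml
(* Strides after broadcasting [shape] to [target]: size-1 dimensions
   that grow get stride 0, all other dimensions keep their stride. *)
let expand_strides ~shape ~strides ~target =
  Array.mapi
    (fun d n ->
      if shape.(d) = n then strides.(d)
      else if shape.(d) = 1 then 0
      else invalid_arg "expand: non-singleton dimension mismatch")
    target
```

With stride 0, every index along the expanded dimension reads the same buffer element, which is why no data is copied.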

val reshape : ('a, 'b) t -> int array -> ('a, 'b) t

reshape t shape changes the logical shape, preserving element count.

Zero-copy when t is C-contiguous or the reshape is compatible with the current strides. May copy if t is non-contiguous.

val permute : ('a, 'b) t -> int array -> ('a, 'b) t

permute t axes reorders dimensions according to axes, which must be a permutation of [0, ..., ndim-1]. Zero-copy.
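Because only metadata changes, a permutation amounts to reindexing the shape and stride arrays. A sketch with a hypothetical helper:

```ocaml
(* New (shape, strides) after reordering dimensions by [axes]:
   output dimension i takes its extent and stride from input dimension axes.(i). *)
let permute_view ~shape ~strides axes =
  ( Array.map (fun a -> shape.(a)) axes,
    Array.map (fun a -> strides.(a)) axes )
```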

val shrink : ('a, 'b) t -> (int * int) array -> ('a, 'b) t

shrink t ranges extracts a contiguous slice. ranges.(i) is (start, stop) with exclusive stop. Zero-copy (adjusts offset and shape).
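The offset and shape adjustment can be sketched as follows, assuming strides in elements. shrink_view is a hypothetical helper; strides are left unchanged by the operation.

```ocaml
(* New (shape, offset) after slicing each dimension to [start, stop). *)
let shrink_view ~strides ~offset ranges =
  let new_shape = Array.map (fun (start, stop) -> stop - start) ranges in
  let new_offset = ref offset in
  Array.iteri
    (fun d (start, _) -> new_offset := !new_offset + (start * strides.(d)))
    ranges;
  (new_shape, !new_offset)
```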

val flip : ('a, 'b) t -> bool array -> ('a, 'b) t

flip t axes reverses dimensions where axes.(i) = true by negating strides. Zero-copy.
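Negating a stride alone would walk off the front of the buffer, so a flipped view presumably also moves its offset to the last element along that dimension. The sketch below makes that assumption explicit; flip_view is a hypothetical helper and strides are in elements.

```ocaml
(* New (strides, offset) after reversing the dimensions flagged in [axes].
   Assumption: the offset moves to the last element of each flipped dimension. *)
let flip_view ~shape ~strides ~offset axes =
  let new_strides = Array.copy strides in
  let new_offset = ref offset in
  Array.iteri
    (fun d flipped ->
      if flipped then begin
        new_offset := !new_offset + ((shape.(d) - 1) * strides.(d));
        new_strides.(d) <- (- strides.(d))
      end)
    axes;
  (new_strides, !new_offset)
```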

val pad : ('a, 'b) t -> (int * int) array -> 'a -> ('a, 'b) t

pad t padding fill_value extends t with fill_value. padding.(i) is (before, after) for dimension i.

Backend must: allocate a new buffer and copy data.
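In one dimension, padding reduces to allocating a filled buffer and copying the input into the middle. pad_1d is a hypothetical helper:

```ocaml
(* Pad a 1D array with [before] leading and [after] trailing fill values. *)
let pad_1d xs (before, after) fill_value =
  let n = Array.length xs in
  let out = Array.make (before + n + after) fill_value in
  Array.blit xs 0 out before n;
  out
```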

val cat : out:('a, 'b) t -> ('a, 'b) t list -> axis:int -> unit

cat ~out tensors ~axis concatenates tensors along axis into out.

Frontend guarantees: all tensors have the same shape except along axis. axis is valid. The list is non-empty. out is pre-allocated with the correct concatenated shape.

Type Conversion and Memory

val cast : out:('c, 'd) t -> ('a, 'b) t -> unit

cast ~out x converts elements of x to the dtype of out.

Float-to-int truncates toward zero. Int-to-float may lose precision for large values.

Frontend guarantees: out is pre-allocated with the correct shape and target dtype.
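Both caveats can be observed with the standard library's scalar conversions, which behave the same way. The names below are illustrative.

```ocaml
(* Float-to-int conversion discards the fractional part (toward zero). *)
let cast_truncated = Int32.of_float (-2.7)

(* 2^53 + 1 has no exact float64 representation, so a round trip
   through float changes the value. *)
let big = Int64.add (Int64.shift_left 1L 53) 1L
let round_trip = Int64.of_float (Int64.to_float big)
let lost_precision = round_trip <> big
```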

val contiguous : ('a, 'b) t -> ('a, 'b) t

contiguous t returns a C-contiguous version of t.

May return t unchanged if already C-contiguous. Otherwise allocates and copies.

Backend must: return a C-contiguous tensor with the same data.

val copy : ('a, 'b) t -> ('a, 'b) t

copy t creates an independent copy with its own buffer.

Backend must: always allocate a new buffer, even if t is already contiguous.

val assign : ('a, 'b) t -> ('a, 'b) t -> unit

assign dst src copies elements from src into dst in-place.

Frontend guarantees: dst and src have matching shapes and dtypes.

Backend must: write src's data into dst's buffer, respecting both tensors' strides.

Random Number Generation

val threefry : out:(int32, Nx_core.Dtype.int32_elt) t -> (int32, Nx_core.Dtype.int32_elt) t -> (int32, Nx_core.Dtype.int32_elt) t -> unit

threefry ~out key counter applies the Threefry-2x32 hash function.

Frontend guarantees: key and counter are int32 tensors with compatible shapes. out is pre-allocated with the same shape as counter.

Indexed Access Operations

val gather : out:('a, 'b) t -> ('a, 'b) t -> (int32, Nx_core.Dtype.int32_elt) t -> axis:int -> unit

gather ~out data indices ~axis selects elements from data along axis using indices and writes them to out.

Frontend guarantees: rank data = rank indices. axis is valid. Index values are in range for data's size along axis. out has the same shape as indices and the same dtype as data.

val scatter : ?mode:[ `Set | `Add ] -> ?unique_indices:bool -> ('a, 'b) t -> indices:(int32, Nx_core.Dtype.int32_elt) t -> updates:('a, 'b) t -> axis:int -> ('a, 'b) t

scatter ?mode ?unique_indices template ~indices ~updates ~axis places updates into a tensor shaped like template along axis.

`Set (default) uses the last update for duplicate indices. `Add accumulates. unique_indices = true hints that indices are unique.

Frontend guarantees: rank indices = rank updates. axis is valid. template has the desired output shape.

Backend must: allocate and return the result tensor, initialized from template's data.
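The difference between the two modes shows up only when indices repeat. A one-dimensional integer sketch (scatter_1d is hypothetical, ignores axis and unique_indices, and applies updates in order):

```ocaml
(* Start from a copy of [template], then apply each update in turn. *)
let scatter_1d ?(mode = `Set) template ~indices ~updates =
  let out = Array.copy template in
  Array.iteri
    (fun k i ->
      match mode with
      | `Set -> out.(i) <- updates.(k)               (* last write wins *)
      | `Add -> out.(i) <- out.(i) + updates.(k))    (* accumulate *)
    indices;
  out
```

With indices [|0; 2; 0|] and updates [|5; 6; 7|], position 0 ends up as 7 under `Set and as template.(0) + 12 under `Add.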

Window Operations

Sliding-window extraction and its inverse. Used to implement convolution as unfold + reshape + matmul and pooling as unfold + reduce.

val unfold : ('a, 'b) t -> kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t

unfold t ~kernel_size ~stride ~dilation ~padding extracts sliding windows from the last K spatial dimensions, where K = Array.length kernel_size.

Input shape (leading..., spatial...) produces (leading..., prod(kernel_size), L) where L is the number of windows. All dimensions before the last K are preserved as-is.

Frontend guarantees: all array parameters have length K. Values are positive. Input has at least K dimensions.

Backend must: allocate and return the result tensor.
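The window count L is not spelled out above. The sketch below assumes the conventional im2col formula, floor((size + before + after - dilation * (kernel - 1) - 1) / stride) + 1 per spatial dimension; both helpers are hypothetical and assume a valid configuration (non-negative numerator).

```ocaml
(* Number of window positions along one spatial dimension
   (assumed formula, see lead-in). *)
let windows_along ~size ~kernel ~stride ~dilation ~padding:(before, after) =
  let effective = (dilation * (kernel - 1)) + 1 in
  ((size + before + after - effective) / stride) + 1

(* Output shape (leading..., prod(kernel_size), L). *)
let unfold_shape ~leading ~spatial ~kernel_size ~stride ~dilation ~padding =
  let l = ref 1 in
  for d = 0 to Array.length kernel_size - 1 do
    l :=
      !l
      * windows_along ~size:spatial.(d) ~kernel:kernel_size.(d)
          ~stride:stride.(d) ~dilation:dilation.(d) ~padding:padding.(d)
  done;
  Array.append leading [| Array.fold_left ( * ) 1 kernel_size; !l |]
```

Under this formula, a (2, 3, 5, 5) input with a 3×3 kernel, unit stride and dilation, and no padding yields shape (2, 3, 9, 9).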

val fold : ('a, 'b) t -> output_size:int array -> kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t

fold t ~output_size ~kernel_size ~stride ~dilation ~padding combines sliding windows (inverse of unfold). Overlapping values are summed.

Input shape (leading..., prod(kernel_size), L) produces (leading..., output_size...).

Frontend guarantees: parameters are consistent with a valid unfold configuration.

Backend must: allocate and return the result tensor.
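Summing overlaps means fold is only a true inverse of unfold when windows do not overlap. A one-dimensional sketch, assuming dilation 1 and no padding (fold_1d is a hypothetical helper):

```ocaml
(* cols.(j).(w) holds kernel position j of window w.
   Each value is added back at output position w * stride + j. *)
let fold_1d ~output_size ~stride cols =
  let out = Array.make output_size 0.0 in
  Array.iteri
    (fun j row ->
      Array.iteri
        (fun w v ->
          let pos = (w * stride) + j in
          out.(pos) <- out.(pos) +. v)
        row)
    cols;
  out
```

Unfolding [|1.; 2.; 3.; 4.|] with kernel size 2 and stride 1 and folding the result back gives [|1.; 4.; 6.; 4.|]: the two interior elements each belong to two windows and are counted twice.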

Matrix Operations

val matmul : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit

matmul ~out a b computes matrix multiplication a × b.

For 2D inputs: standard matrix multiply. For higher dimensions: batched multiply on the last two dimensions, with broadcasting via strides.

Frontend guarantees: a's last dim equals b's second-to-last dim. out is C-contiguous with the correct output shape.

Backend must: write the result to out. May use BLAS for performance. a and b may be non-contiguous.

Fourier Transforms

Frontend guarantees: axes contains valid, non-negative axis indices. Input tensors have compatible complex or real dtypes.

val fft : ?out:(Stdlib.Complex.t, 'b) t -> (Stdlib.Complex.t, 'b) t -> axes:int array -> (Stdlib.Complex.t, 'b) t

fft ?out t ~axes computes the forward DFT along axes.

val ifft : ?out:(Stdlib.Complex.t, 'b) t -> (Stdlib.Complex.t, 'b) t -> axes:int array -> (Stdlib.Complex.t, 'b) t

ifft ?out t ~axes computes the inverse DFT along axes.

val rfft : ?out:(Stdlib.Complex.t, 'b) t -> (float, 'a) t -> dtype:(Stdlib.Complex.t, 'b) Nx_core.Dtype.t -> axes:int array -> (Stdlib.Complex.t, 'b) t

rfft ?out t ~dtype ~axes computes the real-input DFT along axes.

Exploits conjugate symmetry to return only the non-redundant half of the spectrum along the last transformed axis.
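For a real signal of length n, bins k and n - k are complex conjugates, so only bins 0 through n/2 carry independent information. The naive O(n²) sketch below assumes the usual forward sign convention exp(-2πi·k·t/n) and an output length of n/2 + 1; neither is stated above, and rfft_naive is a hypothetical helper, not the backend's algorithm.

```ocaml
(* Naive real-input DFT returning bins 0 .. n/2 (assumed convention). *)
let rfft_naive xs =
  let n = Array.length xs in
  Array.init ((n / 2) + 1) (fun k ->
      let acc = ref Complex.zero in
      Array.iteri
        (fun t x ->
          let angle =
            -2.0 *. Float.pi *. float_of_int (k * t) /. float_of_int n
          in
          acc :=
            Complex.add !acc
              { Complex.re = x *. cos angle; im = x *. sin angle })
        xs;
      !acc)
```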

val irfft : ?out:(float, 'b) t -> ?s:int array -> (Stdlib.Complex.t, 'a) t -> dtype:(float, 'b) Nx_core.Dtype.t -> axes:int array -> (float, 'b) t

irfft ?out ?s t ~dtype ~axes computes the inverse real-input DFT along axes.

Takes conjugate-symmetric complex input, returns real output. s specifies output sizes along the transformed axes; None infers sizes from the input.

Linear Algebra

All linalg operations support batching: the last two dimensions are the matrix dimensions, earlier dimensions are batch dimensions.

Frontend guarantees: input matrices have compatible shapes (square where required, matching dimensions for solves).

Backend must: allocate and return result tensors. Typically delegates to LAPACK.

val cholesky : upper:bool -> ('a, 'b) t -> ('a, 'b) t

cholesky ~upper t computes the Cholesky factorization of a positive-definite matrix. Returns L (lower) or U (upper) such that A = L·Lᵀ or A = Uᵀ·U.

Raises Failure if the matrix is not positive-definite.
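For intuition, the lower-triangular case on a single real matrix can be sketched with the textbook algorithm. This is not the backend's implementation (which typically delegates to LAPACK); cholesky_lower is a hypothetical helper that reads only the lower triangle of its input.

```ocaml
(* Textbook Cholesky: returns L with a = L * transpose L.
   Fails when a pivot is not strictly positive. *)
let cholesky_lower a =
  let n = Array.length a in
  let l = Array.make_matrix n n 0.0 in
  for i = 0 to n - 1 do
    for j = 0 to i do
      let s = ref a.(i).(j) in
      for k = 0 to j - 1 do
        s := !s -. (l.(i).(k) *. l.(j).(k))
      done;
      if i = j then begin
        if !s <= 0.0 then failwith "cholesky: matrix is not positive-definite";
        l.(i).(j) <- sqrt !s
      end
      else l.(i).(j) <- !s /. l.(j).(j)
    done
  done;
  l
```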

val qr : reduced:bool -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t

qr ~reduced t returns (Q, R) where Q is orthogonal and R is upper triangular. reduced = true returns economy-size factorization.

val svd : full_matrices:bool -> ('a, 'b) t -> ('a, 'b) t * (float, Nx_core.Dtype.float64_elt) t * ('a, 'b) t

svd ~full_matrices t returns (U, S, Vᴴ). S is a 1D float64 vector of singular values in descending order. full_matrices = false returns thin SVD.

val eig : vectors:bool -> ('a, 'b) t -> (Stdlib.Complex.t, Nx_core.Dtype.complex64_elt) t * (Stdlib.Complex.t, Nx_core.Dtype.complex64_elt) t option

eig ~vectors t computes eigenvalues (and optionally eigenvectors) of a square matrix. Returns complex64 results.

val eigh : vectors:bool -> ('a, 'b) t -> (float, Nx_core.Dtype.float64_elt) t * ('a, 'b) t option

eigh ~vectors t computes eigenvalues (and optionally eigenvectors) of a symmetric/Hermitian matrix. Eigenvalues are float64.

val triangular_solve : upper:bool -> transpose:bool -> unit_diag:bool -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

triangular_solve ~upper ~transpose ~unit_diag a b solves A·x = b or Aᵀ·x = b where A is triangular.

upper: A is upper triangular. transpose: solve Aᵀ·x = b. unit_diag: assume diagonal is all ones.
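The upper-triangular, non-transposed case is plain back substitution. A sketch for a single right-hand-side vector (solve_upper is a hypothetical helper; with unit_diag the diagonal entries of a are never read):

```ocaml
(* Solve a * x = b for upper-triangular a by back substitution. *)
let solve_upper ~unit_diag a b =
  let n = Array.length b in
  let x = Array.make n 0.0 in
  for i = n - 1 downto 0 do
    let s = ref b.(i) in
    for j = i + 1 to n - 1 do
      s := !s -. (a.(i).(j) *. x.(j))
    done;
    x.(i) <- (if unit_diag then !s else !s /. a.(i).(i))
  done;
  x
```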

val create_context : unit -> context