Module Nx_core.Make_frontend
Frontend functor parameterized by a backend implementation.
Parameters
module B : Backend_intf.S
Signature
module B = B
type ('a, 'b) t = ('a, 'b) B.t
type context = B.context
type float16_elt = Nx_buffer.float16_elt
type float32_elt = Nx_buffer.float32_elt
type float64_elt = Nx_buffer.float64_elt
type bfloat16_elt = Nx_buffer.bfloat16_elt
type float8_e4m3_elt = Nx_buffer.float8_e4m3_elt
type float8_e5m2_elt = Nx_buffer.float8_e5m2_elt
type int4_elt = Nx_buffer.int4_signed_elt
type uint4_elt = Nx_buffer.int4_unsigned_elt
type int8_elt = Nx_buffer.int8_signed_elt
type uint8_elt = Nx_buffer.int8_unsigned_elt
type int16_elt = Nx_buffer.int16_signed_elt
type uint16_elt = Nx_buffer.int16_unsigned_elt
type int32_elt = Nx_buffer.int32_elt
type uint32_elt = Nx_buffer.uint32_elt
type int64_elt = Nx_buffer.int64_elt
type uint64_elt = Nx_buffer.uint64_elt
type complex32_elt = Nx_buffer.complex32_elt
type complex64_elt = Nx_buffer.complex64_elt
type bool_elt = Nx_buffer.bool_elt
type ('a, 'b) dtype = ('a, 'b) Dtype.t =
  | Float16 : (float, float16_elt) dtype
  | Float32 : (float, float32_elt) dtype
  | Float64 : (float, float64_elt) dtype
  | BFloat16 : (float, bfloat16_elt) dtype
  | Float8_e4m3 : (float, float8_e4m3_elt) dtype
  | Float8_e5m2 : (float, float8_e5m2_elt) dtype
  | Int4 : (int, int4_elt) dtype
  | UInt4 : (int, uint4_elt) dtype
  | Int8 : (int, int8_elt) dtype
  | UInt8 : (int, uint8_elt) dtype
  | Int16 : (int, int16_elt) dtype
  | UInt16 : (int, uint16_elt) dtype
  | Int32 : (int32, int32_elt) dtype
  | UInt32 : (int32, uint32_elt) dtype
  | Int64 : (int64, int64_elt) dtype
  | UInt64 : (int64, uint64_elt) dtype
  | Complex64 : (Stdlib.Complex.t, complex32_elt) dtype
  | Complex128 : (Stdlib.Complex.t, complex64_elt) dtype
  | Bool : (bool, bool_elt) dtype
type float16_t = (float, float16_elt) t
type float32_t = (float, float32_elt) t
type float64_t = (float, float64_elt) t
type uint16_t = (int, uint16_elt) t
type uint32_t = (int32, uint32_elt) t
type uint64_t = (int64, uint64_elt) t
type complex64_t = (Stdlib.Complex.t, complex32_elt) t
type complex128_t = (Stdlib.Complex.t, complex64_elt) t
val float16 : (float, float16_elt) dtype
val float32 : (float, float32_elt) dtype
val float64 : (float, float64_elt) dtype
val bfloat16 : (float, bfloat16_elt) dtype
val float8_e4m3 : (float, float8_e4m3_elt) dtype
val float8_e5m2 : (float, float8_e5m2_elt) dtype
val uint16 : (int, uint16_elt) dtype
val uint32 : (int32, uint32_elt) dtype
val uint64 : (int64, uint64_elt) dtype
val complex64 : (Stdlib.Complex.t, complex32_elt) dtype
val complex128 : (Stdlib.Complex.t, complex64_elt) dtype
val data : ('a, 'b) B.t -> ('a, 'b) Nx_buffer.t
val shape : ('a, 'b) B.t -> int array
val itemsize : ('a, 'b) B.t -> int
val strides : ('a, 'b) B.t -> int array
val stride : int -> ('a, 'b) B.t -> int
val dims : ('a, 'b) B.t -> int array
val dim : int -> ('a, 'b) B.t -> int
val ndim : ('a, 'b) B.t -> int
val size : ('a, 'b) B.t -> int
val numel : ('a, 'b) B.t -> int
val nbytes : ('a, 'b) B.t -> int
val offset : ('a, 'b) B.t -> int
val is_c_contiguous : ('a, 'b) B.t -> bool
module IntSet : sig ... end
val power_of_two : 'a 'b. ('a, 'b) Dtype.t -> int -> 'a
val ensure_float_dtype : string -> ('a, 'b) B.t -> unit
val ensure_int_dtype : string -> ('a, 'b) B.t -> unit
val resolve_axis : ?ndim_opt:int -> ('a, 'b) B.t -> int option -> int array
val resolve_single_axis : ?ndim_opt:int -> ('a, 'b) B.t -> int -> int
val to_buffer : ('a, 'b) B.t -> ('a, 'b) Nx_buffer.t
val to_bigarray :
('c, 'd) B.t ->
('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t
val of_buffer :
B.context ->
shape:Shape.t ->
('a, 'b) Nx_buffer.t ->
('a, 'b) B.t
val to_array : ('a, 'b) B.t -> 'a array
val cmpop :
?out:(bool, Dtype.bool_elt) B.t ->
(out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> 'c) ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val cmpeq :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val cmpne :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val cmplt :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val cmple :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val cmpgt :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val cmpge :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val less :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val less_equal :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val greater :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val greater_equal :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val equal :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val not_equal :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val equal_s :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
'a ->
(bool, Dtype.bool_elt) B.t
val not_equal_s :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
'a ->
(bool, Dtype.bool_elt) B.t
val less_s :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
'a ->
(bool, Dtype.bool_elt) B.t
val greater_s :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
'a ->
(bool, Dtype.bool_elt) B.t
val less_equal_s :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
'a ->
(bool, Dtype.bool_elt) B.t
val greater_equal_s :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
'a ->
(bool, Dtype.bool_elt) B.t
val isinf :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val isnan :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val isfinite :
?out:(bool, Dtype.bool_elt) B.t ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val all :
?out:(bool, Dtype.bool_elt) B.t ->
?axes:int list ->
?keepdims:bool ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val any :
?out:(bool, Dtype.bool_elt) B.t ->
?axes:int list ->
?keepdims:bool ->
('a, 'b) B.t ->
(bool, Dtype.bool_elt) B.t
val array_equal : ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val squeeze : ?axes:IntSet.elt list -> ('a, 'b) B.t -> ('a, 'b) B.t
val unsqueeze : ?axes:IntSet.elt list -> ('a, 'b) B.t -> ('a, 'b) B.t
val squeeze_axis : IntSet.elt -> ('a, 'b) B.t -> ('a, 'b) B.t
val unsqueeze_axis : IntSet.elt -> ('a, 'b) B.t -> ('a, 'b) B.t
val expand_dims : IntSet.elt list -> ('a, 'b) B.t -> ('a, 'b) B.t
val check_dtypes_match : op:string -> ('a, 'b) B.t list -> unit
val stack : ?axis:IntSet.elt -> ('a, 'b) B.t list -> ('a, 'b) B.t
val take_along_axis :
axis:int ->
(int32, Dtype.int32_elt) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t
val unsafe_get : int list -> ('a, 'b) B.t -> 'a
val unsafe_set : int list -> 'a -> ('a, 'b) B.t -> unit
val item : int list -> ('a, 'b) B.t -> 'a
val set_item : int list -> 'a -> ('a, 'b) B.t -> unit
val put_along_axis :
axis:int ->
indices:(int32, Dtype.int32_elt) B.t ->
values:('a, 'b) B.t ->
('a, 'b) B.t ->
unit
val sort :
?descending:bool ->
?axis:int ->
('a, 'b) t ->
('a, 'b) t * (int32, Dtype.int32_elt) B.t
val argsort :
?descending:bool ->
?axis:int ->
('a, 'b) t ->
(int32, Dtype.int32_elt) B.t
val argmax :
?axis:int ->
?keepdims:bool ->
('a, 'b) B.t ->
(int32, Dtype.int32_elt) B.t
val argmin :
?axis:int ->
?keepdims:bool ->
('a, 'b) t ->
(int32, Dtype.int32_elt) t
val bernoulli : B.context -> p:float -> Shape.t -> (bool, Dtype.bool_elt) B.t
val permutation : B.context -> int -> (int32, Dtype.int32_elt) B.t
val categorical :
B.context ->
?axis:int ->
?shape:int array ->
('a, 'b) t ->
(int32, Dtype.int32_elt) t
val tensordot :
?axes:(IntSet.elt list * IntSet.elt list) ->
('a, 'b) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t
module Einsum : sig ... end
val cross :
?out:('a, 'b) B.t ->
?axis:IntSet.elt ->
('a, 'b) B.t ->
('a, 'b) B.t ->
('a, 'b) B.t
val check_square : op:string -> ('a, 'b) B.t -> unit
val check_float_or_complex : op:string -> ('a, 'b) t -> unit
val check_real : op:string -> ('a, 'b) t -> unit
val svd :
?full_matrices:bool ->
('a, 'b) t ->
('a, 'b) B.t * (float, Dtype.float64_elt) B.t * ('a, 'b) B.t
val svdvals : ('a, 'b) t -> (float, Dtype.float64_elt) B.t
val eig :
('a, 'b) B.t ->
(Stdlib.Complex.t, Dtype.complex64_elt) B.t
* (Stdlib.Complex.t, Dtype.complex64_elt) B.t
val eigh :
?uplo:'a ->
('b, 'c) B.t ->
(float, Dtype.float64_elt) B.t * ('b, 'c) B.t
val eigvals : ('a, 'b) B.t -> (Stdlib.Complex.t, Dtype.complex64_elt) B.t
val eigvalsh : ?uplo:'a -> ('b, 'c) B.t -> (float, Dtype.float64_elt) B.t
val slogdet :
('a, 'b) B.t ->
(float, Dtype.float32_elt) B.t * (float, Dtype.float32_elt) t
val matrix_rank :
?tol:float ->
?rtol:float ->
?hermitian:bool ->
('a, 'b) t ->
int
val fft_norm_scale :
[< `Backward | `Forward | `Ortho ] ->
int list ->
('a, 'b) B.t ->
float
val ifft_norm_scale :
[< `Backward | `Forward | `Ortho ] ->
int list ->
('a, 'b) B.t ->
float
val rfft :
?out:(Stdlib.Complex.t, Dtype.complex64_elt) B.t ->
?axis:int ->
?n:int ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(float, 'a) B.t ->
(Stdlib.Complex.t, Dtype.complex64_elt) t
val irfft :
?out:(float, Dtype.float64_elt) B.t ->
?axis:int ->
?n:int ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(Stdlib.Complex.t, 'a) B.t ->
(float, Dtype.float64_elt) B.t
val check_fft2 : op:string -> ('a, 'b) B.t -> int list option -> int list
val rfft2 :
?out:(Stdlib.Complex.t, Dtype.complex64_elt) B.t ->
?axes:int list ->
?s:int list ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(float, 'a) B.t ->
(Stdlib.Complex.t, Dtype.complex64_elt) t
val irfft2 :
?out:(float, Dtype.float64_elt) B.t ->
?axes:int list ->
?s:int list ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(Stdlib.Complex.t, 'a) B.t ->
(float, Dtype.float64_elt) B.t
val rfftn :
?out:(Stdlib.Complex.t, Dtype.complex64_elt) B.t ->
?axes:int list ->
?s:int list ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(float, 'a) B.t ->
(Stdlib.Complex.t, Dtype.complex64_elt) t
val irfftn :
?out:(float, Dtype.float64_elt) B.t ->
?axes:int list ->
?s:int list ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(Stdlib.Complex.t, 'a) B.t ->
(float, Dtype.float64_elt) B.t
val hfft :
?axis:int ->
?n:int ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(Stdlib.Complex.t, 'a) B.t ->
(float, Dtype.float64_elt) B.t
val ihfft :
?axis:int ->
?n:int ->
?norm:[< `Backward | `Forward | `Ortho > `Backward ] ->
(float, 'a) B.t ->
(Stdlib.Complex.t, Dtype.complex64_elt) t
val fftfreq : B.context -> ?d:float -> int -> (float, Dtype.float64_elt) B.t
val rfftfreq : B.context -> ?d:float -> int -> (float, Dtype.float64_elt) B.t
val logsumexp :
?out:('a, 'b) B.t ->
?axes:IntSet.elt list ->
?keepdims:bool ->
('a, 'b) B.t ->
('a, 'b) B.t
val logmeanexp :
?out:('a, 'b) B.t ->
?axes:IntSet.elt list ->
?keepdims:bool ->
('a, 'b) B.t ->
('a, 'b) B.t
val one_hot : num_classes:int -> ('a, 'b) B.t -> (int, Dtype.uint8_elt) t
val pp_data : Stdlib.Format.formatter -> ('a, 'b) t -> unit
val data_to_string : ('a, 'b) t -> string
val print_data : ('a, 'b) t -> unit
val pp_dtype : Stdlib.Format.formatter -> ('a, 'b) Dtype.t -> unit
val dtype_to_string : ('a, 'b) Dtype.t -> string
val pp : Stdlib.Format.formatter -> ('a, 'b) B.t -> unit
val print : ('a, 'b) B.t -> unit
val to_string : ('a, 'b) B.t -> string
val iter_item : ('a -> 'b) -> ('a, 'c) B.t -> unit
val fold_item : ('a -> 'b -> 'a) -> 'a -> ('b, 'c) B.t -> 'a
module Infix : sig ... end