Module Nx_core.Make_frontend

Frontend functor parameterized by a backend implementation.

Parameters

module B : Backend_intf.S

Signature

module B = B
val err : string -> ('a, unit, string, 'b) Stdlib.format4 -> 'a
type ('a, 'b) t = ('a, 'b) B.t
type context = B.context
type float16_elt = Nx_buffer.float16_elt
type float32_elt = Nx_buffer.float32_elt
type float64_elt = Nx_buffer.float64_elt
type bfloat16_elt = Nx_buffer.bfloat16_elt
type float8_e4m3_elt = Nx_buffer.float8_e4m3_elt
type float8_e5m2_elt = Nx_buffer.float8_e5m2_elt
type int32_elt = Nx_buffer.int32_elt
type uint32_elt = Nx_buffer.uint32_elt
type int64_elt = Nx_buffer.int64_elt
type uint64_elt = Nx_buffer.uint64_elt
type complex32_elt = Nx_buffer.complex32_elt
type complex64_elt = Nx_buffer.complex64_elt
type bool_elt = Nx_buffer.bool_elt
type ('a, 'b) dtype = ('a, 'b) Dtype.t =
  | Float16 : (float, float16_elt) dtype
  | Float32 : (float, float32_elt) dtype
  | Float64 : (float, float64_elt) dtype
  | BFloat16 : (float, bfloat16_elt) dtype
  | Float8_e4m3 : (float, float8_e4m3_elt) dtype
  | Float8_e5m2 : (float, float8_e5m2_elt) dtype
  | Int4 : (int, int4_elt) dtype
  | UInt4 : (int, uint4_elt) dtype
  | Int8 : (int, int8_elt) dtype
  | UInt8 : (int, uint8_elt) dtype
  | Int16 : (int, int16_elt) dtype
  | UInt16 : (int, uint16_elt) dtype
  | Int32 : (int32, int32_elt) dtype
  | UInt32 : (int32, uint32_elt) dtype
  | Int64 : (int64, int64_elt) dtype
  | UInt64 : (int64, uint64_elt) dtype
  | Complex64 : (Stdlib.Complex.t, complex32_elt) dtype
  | Complex128 : (Stdlib.Complex.t, complex64_elt) dtype
  | Bool : (bool, bool_elt) dtype
type float16_t = (float, float16_elt) t
type float32_t = (float, float32_elt) t
type float64_t = (float, float64_elt) t
type int8_t = (int, int8_elt) t
type uint8_t = (int, uint8_elt) t
type int16_t = (int, int16_elt) t
type uint16_t = (int, uint16_elt) t
type int32_t = (int32, int32_elt) t
type int64_t = (int64, int64_elt) t
type uint32_t = (int32, uint32_elt) t
type uint64_t = (int64, uint64_elt) t
type complex64_t = (Stdlib.Complex.t, complex32_elt) t
type complex128_t = (Stdlib.Complex.t, complex64_elt) t
type bool_t = (bool, bool_elt) t
val float16 : (float, float16_elt) dtype
val float32 : (float, float32_elt) dtype
val float64 : (float, float64_elt) dtype
val bfloat16 : (float, bfloat16_elt) dtype
val float8_e4m3 : (float, float8_e4m3_elt) dtype
val float8_e5m2 : (float, float8_e5m2_elt) dtype
val int4 : (int, int4_elt) dtype
val uint4 : (int, uint4_elt) dtype
val int8 : (int, int8_elt) dtype
val uint8 : (int, uint8_elt) dtype
val int16 : (int, int16_elt) dtype
val uint16 : (int, uint16_elt) dtype
val int32 : (int32, int32_elt) dtype
val uint32 : (int32, uint32_elt) dtype
val int64 : (int64, int64_elt) dtype
val uint64 : (int64, uint64_elt) dtype
val complex64 : (Stdlib.Complex.t, complex32_elt) dtype
val complex128 : (Stdlib.Complex.t, complex64_elt) dtype
val bool : (bool, bool_elt) dtype
type index =
  | I of int
  | L of int list
  | R of int * int
  | Rs of int * int * int
  | A
  | M of (bool, bool_elt) t
  | N
val data : ('a, 'b) B.t -> ('a, 'b) Nx_buffer.t
val shape : ('a, 'b) B.t -> int array
val dtype : ('a, 'b) B.t -> ('a, 'b) Dtype.t
val itemsize : ('a, 'b) B.t -> int
val strides : ('a, 'b) B.t -> int array
val stride : int -> ('a, 'b) B.t -> int
val dims : ('a, 'b) B.t -> int array
val dim : int -> ('a, 'b) B.t -> int
val ndim : ('a, 'b) B.t -> int
val size : ('a, 'b) B.t -> int
val numel : ('a, 'b) B.t -> int
val nbytes : ('a, 'b) B.t -> int
val offset : ('a, 'b) B.t -> int
val is_c_contiguous : ('a, 'b) B.t -> bool
val array_prod : int array -> int
module IntSet : sig ... end
val power_of_two : 'a 'b. ('a, 'b) Dtype.t -> int -> 'a
val ensure_float_dtype : string -> ('a, 'b) B.t -> unit
val ensure_int_dtype : string -> ('a, 'b) B.t -> unit
val resolve_axis : ?ndim_opt:int -> ('a, 'b) B.t -> int option -> int array
val resolve_single_axis : ?ndim_opt:int -> ('a, 'b) B.t -> int -> int
val normalize_and_dedup_axes : op:string -> int -> int list -> int list
val reduction_element_count : int array -> ?axes:int list -> unit -> int
val copy_to_out : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val reshape : Shape.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val broadcast_shapes : Shape.t -> Shape.t -> int array
val broadcast_to : Shape.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val broadcasted : ?reverse:bool -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t * ('a, 'b) B.t
val expand : int array -> ('a, 'b) B.t -> ('a, 'b) B.t
val cast : ('c, 'd) Dtype.t -> ('a, 'b) t -> ('c, 'd) t
val astype : ('a, 'b) Dtype.t -> ('c, 'd) t -> ('a, 'b) t
val contiguous : ('a, 'b) B.t -> ('a, 'b) B.t
val copy : ('a, 'b) B.t -> ('a, 'b) B.t
val blit : ('a, 'b) B.t -> ('a, 'b) B.t -> unit
val create : B.context -> ('a, 'b) Dtype.t -> int array -> 'a array -> ('a, 'b) B.t
val init : B.context -> ('a, 'b) Dtype.t -> Shape.t -> (int array -> 'a) -> ('a, 'b) B.t
val scalar : B.context -> ('a, 'b) Dtype.t -> 'a -> ('a, 'b) B.t
val scalar_like : ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val fill : 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val empty : B.context -> ('a, 'b) Dtype.t -> int array -> ('a, 'b) B.t
val zeros : B.context -> ('a, 'b) Dtype.t -> int array -> ('a, 'b) B.t
val ones : B.context -> ('a, 'b) Dtype.t -> int array -> ('a, 'b) B.t
val full : B.context -> ('a, 'b) Dtype.t -> int array -> 'a -> ('a, 'b) B.t
val create_like : ('a, 'b) B.t -> (B.context -> ('a, 'b) Dtype.t -> int array -> 'c) -> 'c
val empty_like : ('a, 'b) B.t -> ('a, 'b) B.t
val full_like : ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val zeros_like : ('a, 'b) B.t -> ('a, 'b) B.t
val ones_like : ('a, 'b) B.t -> ('a, 'b) B.t
val to_buffer : ('a, 'b) B.t -> ('a, 'b) Nx_buffer.t
val to_bigarray : ('c, 'd) B.t -> ('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t
val of_buffer : B.context -> shape:Shape.t -> ('a, 'b) Nx_buffer.t -> ('a, 'b) B.t
val of_bigarray : B.context -> 'c -> ('a, 'b) B.t
val to_array : ('a, 'b) B.t -> 'a array
val binop : ?out:('a, 'b) B.t -> (out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> 'c) -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val cmpop : ?out:(bool, Dtype.bool_elt) B.t -> (out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> 'c) -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val add : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val add_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val radd_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val sub : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val sub_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val rsub_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val mul : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val mul_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val rmul_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val div : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val div_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val rdiv_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val pow : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val pow_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val rpow_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val maximum : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val maximum_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val rmaximum_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val minimum : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val minimum_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val rminimum_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val mod_ : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val mod_s : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val rmod_s : ?out:('a, 'b) B.t -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val bitwise_xor : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val bitwise_or : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val bitwise_and : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val logical_and : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val logical_or : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val logical_xor : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val logical_not : ?out:('a, 'b) B.t -> ('a, 'b) t -> ('a, 'b) t
val cmpeq : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val cmpne : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val cmplt : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val cmple : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val cmpgt : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val cmpge : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val less : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val less_equal : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val greater : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val greater_equal : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val equal : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val not_equal : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val equal_s : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> 'a -> (bool, Dtype.bool_elt) B.t
val not_equal_s : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> 'a -> (bool, Dtype.bool_elt) B.t
val less_s : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> 'a -> (bool, Dtype.bool_elt) B.t
val greater_s : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> 'a -> (bool, Dtype.bool_elt) B.t
val less_equal_s : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> 'a -> (bool, Dtype.bool_elt) B.t
val greater_equal_s : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> 'a -> (bool, Dtype.bool_elt) B.t
val unaryop : ?out:('a, 'b) B.t -> (out:('a, 'b) B.t -> ('a, 'b) B.t -> 'c) -> ('a, 'b) B.t -> ('a, 'b) B.t
val neg : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val bitwise_not : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val invert : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val sin : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val cos : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val sqrt : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val recip : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val log : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val exp : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val abs : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val log2 : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val exp2 : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val tan : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val square : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val sign : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val relu : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val sigmoid : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val rsqrt : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val asin : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val acos : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val atan : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val sinh : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val cosh : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val tanh : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val asinh : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val acosh : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val atanh : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val trunc : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val ceil : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val floor : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val round : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val isinf : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val isnan : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val isfinite : ?out:(bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val lerp : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val lerp_scalar_weight : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> 'a -> ('a, 'b) B.t
val shift_op : op:string -> apply:(?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t) -> ?out:('a, 'b) B.t -> ('a, 'b) B.t -> int -> ('a, 'b) B.t
val lshift : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> int -> ('a, 'b) B.t
val rshift : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> int -> ('a, 'b) B.t
val clamp : ?out:('a, 'b) B.t -> ?min:'a -> ?max:'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val clip : ?out:('a, 'b) B.t -> ?min:'a -> ?max:'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val where : ?out:('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val atan2 : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val hypot : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val reduce_output_shape : int array -> int array -> bool -> int array
val reduce_op : ?out:('a, 'b) B.t -> (out:('a, 'b) B.t -> axes:int array -> keepdims:bool -> ('a, 'b) B.t -> 'c) -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val sum : ?out:('a, 'b) B.t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val max : ?out:('a, 'b) B.t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val min : ?out:('a, 'b) B.t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val prod : ?out:('a, 'b) B.t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val associative_scan : axis:int -> [ `Max | `Min | `Prod | `Sum ] -> ('a, 'b) B.t -> ('a, 'b) B.t
val cumulative_scan : ?axis:int -> [ `Max | `Min | `Prod | `Sum ] -> ('a, 'b) B.t -> ('a, 'b) B.t
val cumsum : ?axis:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val cumprod : ?axis:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val cummax : ?axis:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val cummin : ?axis:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val mean : ?out:('a, 'b) B.t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val var : ?out:('a, 'b) B.t -> ?axes:int list -> ?keepdims:bool -> ?ddof:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val std : ?out:('a, 'b) B.t -> ?axes:int list -> ?keepdims:bool -> ?ddof:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val all : ?out:(bool, Dtype.bool_elt) B.t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val any : ?out:(bool, Dtype.bool_elt) B.t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val array_equal : ('a, 'b) B.t -> ('a, 'b) B.t -> (bool, Dtype.bool_elt) B.t
val pad : (int * int) array -> 'a -> ('a, 'b) B.t -> ('a, 'b) B.t
val shrink : (int * int) array -> ('a, 'b) B.t -> ('a, 'b) B.t
val flatten : ?start_dim:int -> ?end_dim:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val unflatten : int -> int array -> ('a, 'b) B.t -> ('a, 'b) B.t
val ravel : ('a, 'b) B.t -> ('a, 'b) B.t
val squeeze : ?axes:IntSet.elt list -> ('a, 'b) B.t -> ('a, 'b) B.t
val unsqueeze : ?axes:IntSet.elt list -> ('a, 'b) B.t -> ('a, 'b) B.t
val squeeze_axis : IntSet.elt -> ('a, 'b) B.t -> ('a, 'b) B.t
val unsqueeze_axis : IntSet.elt -> ('a, 'b) B.t -> ('a, 'b) B.t
val expand_dims : IntSet.elt list -> ('a, 'b) B.t -> ('a, 'b) B.t
val transpose : ?axes:int list -> ('a, 'b) B.t -> ('a, 'b) B.t
val flip : ?axes:int list -> ('a, 'b) B.t -> ('a, 'b) B.t
val moveaxis : int -> int -> ('a, 'b) B.t -> ('a, 'b) B.t
val swapaxes : int -> int -> ('a, 'b) B.t -> ('a, 'b) B.t
val cat_tensors : axis:int -> ('a, 'b) B.t list -> ('a, 'b) B.t
val roll : ?axis:int -> int -> ('a, 'b) B.t -> ('a, 'b) B.t
val tile : int array -> ('a, 'b) B.t -> ('a, 'b) B.t
val repeat : ?axis:int -> int -> ('a, 'b) B.t -> ('a, 'b) B.t
val check_dtypes_match : op:string -> ('a, 'b) B.t list -> unit
val concatenate : ?axis:int -> ('a, 'b) B.t list -> ('a, 'b) B.t
val stack : ?axis:IntSet.elt -> ('a, 'b) B.t list -> ('a, 'b) B.t
val ensure_ndim : int -> ('a, 'b) B.t -> ('a, 'b) B.t
val vstack : ('a, 'b) B.t list -> ('a, 'b) B.t
val hstack : ('a, 'b) B.t list -> ('a, 'b) B.t
val dstack : ('a, 'b) B.t list -> ('a, 'b) B.t
val broadcast_arrays : ('a, 'b) B.t list -> ('a, 'b) B.t list
val eye : B.context -> ?m:int -> ?k:int -> ('a, 'b) Dtype.t -> int -> ('a, 'b) B.t
val identity : B.context -> ('a, 'b) Dtype.t -> int -> ('a, 'b) B.t
val diag : ?k:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val arange : B.context -> ('a, 'b) Dtype.t -> int -> int -> int -> ('a, 'b) B.t
val arange_f : B.context -> (float, 'a) Dtype.t -> float -> float -> float -> (float, 'a) B.t
val linspace : B.context -> ('a, 'b) Dtype.t -> ?endpoint:bool -> float -> float -> int -> ('a, 'b) B.t
val logspace : B.context -> (float, 'a) Dtype.t -> ?endpoint:bool -> ?base:float -> float -> float -> int -> (float, 'a) B.t
val geomspace : B.context -> (float, 'a) Dtype.t -> ?endpoint:bool -> float -> float -> int -> (float, 'a) B.t
val meshgrid : ?indexing:[< `ij | `xy > `xy ] -> ('a, 'b) B.t -> ('c, 'd) B.t -> ('a, 'b) B.t * ('c, 'd) B.t
val triangular_mask : op:string -> cmp: ((int32, int32_elt) B.t -> (int32, int32_elt) B.t -> (bool, Dtype.bool_elt) B.t) -> ?k:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val tril : ?k:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val triu : ?k:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val apply_index_mode : mode:[< `clip | `raise | `wrap ] -> n:int -> B.context -> (int32, int32_elt) B.t -> (int32, int32_elt) B.t
val take : ?axis:int -> ?mode:[< `clip | `raise | `wrap > `raise ] -> (int32, int32_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val take_along_axis : axis:int -> (int32, Dtype.int32_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val normalize_index : int -> int -> int
val normalize_and_check_index : op:string -> int -> int -> int
type dim_op =
  | View of {
      start : int;
      stop : int;
      step : int;
      dim_len : int;
    }
  | Squeeze of {
      idx : int;
    }
  | Gather of int array
  | New_axis
val normalize_slice_spec : int -> index -> dim_op
val slice_internal : index list -> ('a, 'b) B.t -> ('a, 'b) B.t
val set_slice_internal : index list -> ('a, 'b) B.t -> ('a, 'b) B.t -> unit
val get : int list -> ('a, 'b) B.t -> ('a, 'b) B.t
val set : int list -> ('a, 'b) B.t -> ('a, 'b) B.t -> unit
val unsafe_get : int list -> ('a, 'b) B.t -> 'a
val unsafe_set : int list -> 'a -> ('a, 'b) B.t -> unit
val slice : index list -> ('a, 'b) B.t -> ('a, 'b) B.t
val set_slice : index list -> ('a, 'b) B.t -> ('a, 'b) B.t -> unit
val item : int list -> ('a, 'b) B.t -> 'a
val set_item : int list -> 'a -> ('a, 'b) B.t -> unit
val put : ?axis:int -> indices:(int32, int32_elt) B.t -> values:('a, 'b) B.t -> ?mode:[< `clip | `raise | `wrap > `raise ] -> ('a, 'b) B.t -> unit
val index_put : indices:(int32, int32_elt) B.t array -> values:('a, 'b) B.t -> ?mode:[< `clip | `raise | `wrap > `raise ] -> ('a, 'b) B.t -> unit
val put_along_axis : axis:int -> indices:(int32, Dtype.int32_elt) B.t -> values:('a, 'b) B.t -> ('a, 'b) B.t -> unit
val nonzero_indices_only : (bool, bool_elt) t -> (int32, int32_elt) B.t array
val compress : ?axis:int -> condition:(bool, bool_elt) t -> ('a, 'b) B.t -> ('a, 'b) B.t
val extract : condition:(bool, bool_elt) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val nonzero : ('a, 'b) t -> (int32, int32_elt) B.t array
val argwhere : ('a, 'b) t -> (int32, int32_elt) B.t
val array_split : axis:int -> [< `Count of int | `Indices of int list ] -> ('a, 'b) B.t -> ('a, 'b) B.t list
val split : axis:int -> int -> ('a, 'b) B.t -> ('a, 'b) B.t list
val sort : ?descending:bool -> ?axis:int -> ('a, 'b) t -> ('a, 'b) t * (int32, Dtype.int32_elt) B.t
val argsort : ?descending:bool -> ?axis:int -> ('a, 'b) t -> (int32, Dtype.int32_elt) B.t
val argmax : ?axis:int -> ?keepdims:bool -> ('a, 'b) B.t -> (int32, Dtype.int32_elt) B.t
val argmin : ?axis:int -> ?keepdims:bool -> ('a, 'b) t -> (int32, Dtype.int32_elt) t
val validate_random_float_params : string -> ('a, 'b) Dtype.t -> Shape.t -> unit
val rand : B.context -> ('a, 'b) Dtype.t -> Shape.t -> ('a, 'b) B.t
val randn : B.context -> ('a, 'b) Dtype.t -> Shape.t -> ('a, 'b) B.t
val randint : B.context -> ('a, 'b) Dtype.t -> ?high:int -> Shape.t -> int -> ('a, 'b) t
val bernoulli : B.context -> p:float -> Shape.t -> (bool, Dtype.bool_elt) B.t
val permutation : B.context -> int -> (int32, Dtype.int32_elt) B.t
val shuffle : B.context -> ('a, 'b) B.t -> ('a, 'b) B.t
val categorical : B.context -> ?axis:int -> ?shape:int array -> ('a, 'b) t -> (int32, Dtype.int32_elt) t
val truncated_normal : B.context -> ('a, 'b) Dtype.t -> lower:float -> upper:float -> Shape.t -> ('a, 'b) B.t
val matmul_output_shape : ('a, 'b) B.t -> ('c, 'd) B.t -> int array
val matmul_with_alloc : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val dot : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val matmul : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val diagonal : ?offset:int -> ?axis1:int -> ?axis2:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val matrix_transpose : ('a, 'b) B.t -> ('a, 'b) B.t
val extract_complex_part : op:string -> field:(Stdlib.Complex.t -> float) -> ('a, 'b) t -> ('c, 'd) t
val complex : real:('a, 'b) t -> imag:('a, 'b) t -> 'c
val real : ('a, 'b) t -> ('c, 'd) t
val imag : ('a, 'b) t -> ('c, 'd) t
val conjugate : ('a, 'b) t -> ('a, 'b) t
val vdot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) B.t
val vecdot : ?axis:int -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val inner : ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val outer : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val tensordot : ?axes:(IntSet.elt list * IntSet.elt list) -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
module Einsum : sig ... end
val einsum : string -> ('a, 'b) B.t array -> ('a, 'b) B.t
val kron : ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val multi_dot : ('a, 'b) B.t array -> ('a, 'b) B.t
val cross : ?out:('a, 'b) B.t -> ?axis:IntSet.elt -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val check_square : op:string -> ('a, 'b) B.t -> unit
val check_float_or_complex : op:string -> ('a, 'b) t -> unit
val check_real : op:string -> ('a, 'b) t -> unit
val cholesky : ?upper:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val qr : ?mode:[< `Complete | `Reduced ] -> ('a, 'b) t -> ('a, 'b) B.t * ('a, 'b) B.t
val svd : ?full_matrices:bool -> ('a, 'b) t -> ('a, 'b) B.t * (float, Dtype.float64_elt) B.t * ('a, 'b) B.t
val svdvals : ('a, 'b) t -> (float, Dtype.float64_elt) B.t
val eig : ('a, 'b) B.t -> (Stdlib.Complex.t, Dtype.complex64_elt) B.t * (Stdlib.Complex.t, Dtype.complex64_elt) B.t
val eigh : ?uplo:'a -> ('b, 'c) B.t -> (float, Dtype.float64_elt) B.t * ('b, 'c) B.t
val eigvals : ('a, 'b) B.t -> (Stdlib.Complex.t, Dtype.complex64_elt) B.t
val eigvalsh : ?uplo:'a -> ('b, 'c) B.t -> (float, Dtype.float64_elt) B.t
val norm : ?ord: [> `Fro | `Inf | `NegInf | `NegOne | `NegTwo | `Nuc | `One | `P of float | `Two ] -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) B.t
val slogdet : ('a, 'b) B.t -> (float, Dtype.float32_elt) B.t * (float, Dtype.float32_elt) t
val det : ('a, 'b) B.t -> ('a, 'b) B.t
val matrix_rank : ?tol:float -> ?rtol:float -> ?hermitian:bool -> ('a, 'b) t -> int
val trace : ?out:('a, 'b) B.t -> ?offset:int -> ('a, 'b) B.t -> ('a, 'b) B.t
val solve : ('a, 'b) B.t -> ('a, 'b) t -> ('a, 'b) B.t
val pinv : ?rtol:float -> ?hermitian:bool -> ('a, 'b) t -> ('a, 'b) B.t
val lstsq : ?rcond:float -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) B.t * ('a, 'b) B.t * int * (float, Dtype.float64_elt) B.t
val inv : ('a, 'b) B.t -> ('a, 'b) B.t
val matrix_power : ('a, 'b) B.t -> int -> ('a, 'b) B.t
val cond : ?p:[> `Inf | `One | `Two ] -> ('a, 'b) B.t -> ('a, 'b) t
val tensorsolve : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) B.t
val tensorinv : ?ind:int -> ('a, 'b) t -> ('a, 'b) B.t
type fft_norm = [
  | `Backward
  | `Forward
  | `Ortho
]
val pad_or_truncate_for_fft : ('a, 'b) B.t -> int list -> int list option -> ('a, 'b) B.t
val fft_norm_scale : [< `Backward | `Forward | `Ortho ] -> int list -> ('a, 'b) B.t -> float
val ifft_norm_scale : [< `Backward | `Forward | `Ortho ] -> int list -> ('a, 'b) B.t -> float
val apply_fft_scale : ?out:(Stdlib.Complex.t, 'a) B.t -> float -> (Stdlib.Complex.t, 'a) t -> (Stdlib.Complex.t, 'a) t
val fft : ?out:(Stdlib.Complex.t, 'a) B.t -> ?axis:int -> ?n:int -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) t -> (Stdlib.Complex.t, 'a) t
val ifft : ?out:(Stdlib.Complex.t, 'a) B.t -> ?axis:int -> ?n:int -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) t -> (Stdlib.Complex.t, 'a) t
val rfft : ?out:(Stdlib.Complex.t, Dtype.complex64_elt) B.t -> ?axis:int -> ?n:int -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (float, 'a) B.t -> (Stdlib.Complex.t, Dtype.complex64_elt) t
val irfft : ?out:(float, Dtype.float64_elt) B.t -> ?axis:int -> ?n:int -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (float, Dtype.float64_elt) B.t
val check_fft2 : op:string -> ('a, 'b) B.t -> int list option -> int list
val fft2 : ?out:(Stdlib.Complex.t, 'a) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (Stdlib.Complex.t, 'a) t
val ifft2 : ?out:(Stdlib.Complex.t, 'a) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (Stdlib.Complex.t, 'a) t
val fftn : ?out:(Stdlib.Complex.t, 'a) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (Stdlib.Complex.t, 'a) t
val ifftn : ?out:(Stdlib.Complex.t, 'a) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (Stdlib.Complex.t, 'a) t
val rfft2 : ?out:(Stdlib.Complex.t, Dtype.complex64_elt) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (float, 'a) B.t -> (Stdlib.Complex.t, Dtype.complex64_elt) t
val irfft2 : ?out:(float, Dtype.float64_elt) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (float, Dtype.float64_elt) B.t
val rfftn : ?out:(Stdlib.Complex.t, Dtype.complex64_elt) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (float, 'a) B.t -> (Stdlib.Complex.t, Dtype.complex64_elt) t
val irfftn : ?out:(float, Dtype.float64_elt) B.t -> ?axes:int list -> ?s:int list -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (float, Dtype.float64_elt) B.t
val hfft : ?axis:int -> ?n:int -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (Stdlib.Complex.t, 'a) B.t -> (float, Dtype.float64_elt) B.t
val ihfft : ?axis:int -> ?n:int -> ?norm:[< `Backward | `Forward | `Ortho > `Backward ] -> (float, 'a) B.t -> (Stdlib.Complex.t, Dtype.complex64_elt) t
val fftfreq : B.context -> ?d:float -> int -> (float, Dtype.float64_elt) B.t
val rfftfreq : B.context -> ?d:float -> int -> (float, Dtype.float64_elt) B.t
val fftshift : ?axes:int list -> ('a, 'b) B.t -> ('a, 'b) B.t
val ifftshift : ?axes:int list -> ('a, 'b) B.t -> ('a, 'b) B.t
val softmax : ?out:('a, 'b) B.t -> ?axes:int list -> ?scale:float -> ('a, 'b) B.t -> ('a, 'b) B.t
val log_softmax : ?out:('a, 'b) B.t -> ?axes:int list -> ?scale:float -> ('a, 'b) B.t -> ('a, 'b) B.t
val logsumexp : ?out:('a, 'b) B.t -> ?axes:IntSet.elt list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val logmeanexp : ?out:('a, 'b) B.t -> ?axes:IntSet.elt list -> ?keepdims:bool -> ('a, 'b) B.t -> ('a, 'b) B.t
val standardize : ?out:('a, 'b) B.t -> ?axes:int list -> ?mean:('a, 'b) B.t -> ?variance:('a, 'b) B.t -> ?epsilon:float -> ('a, 'b) B.t -> ('a, 'b) B.t
val erf : ?out:('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val extract_patches : kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) B.t -> ('a, 'b) B.t
val combine_patches : output_size:int array -> kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) B.t -> ('a, 'b) B.t
val correlate_padding : mode:[< `Full | `Same | `Valid ] -> 'a -> int array -> (int * int) array
val correlate : ?padding:[< `Full | `Same | `Valid > `Valid ] -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val convolve : ?padding:[< `Full | `Same | `Valid > `Valid ] -> ('a, 'b) B.t -> ('a, 'b) B.t -> ('a, 'b) B.t
val sliding_filter : reduce_fn:(('a, 'b) B.t -> axes:int list -> keepdims:bool -> ('c, 'd) B.t) -> kernel_size:int array -> ?stride:int array -> ('a, 'b) B.t -> ('c, 'd) B.t
val maximum_filter : kernel_size:int array -> ?stride:int array -> ('a, 'b) B.t -> ('a, 'b) B.t
val minimum_filter : kernel_size:int array -> ?stride:int array -> ('a, 'b) B.t -> ('a, 'b) B.t
val uniform_filter : kernel_size:int array -> ?stride:int array -> ('a, 'b) B.t -> ('a, 'b) B.t
val one_hot : num_classes:int -> ('a, 'b) B.t -> (int, Dtype.uint8_elt) t
val pp_data : Stdlib.Format.formatter -> ('a, 'b) t -> unit
val format_to_string : (Stdlib.Format.formatter -> 'a -> 'b) -> 'a -> string
val print_with_formatter : (Stdlib.Format.formatter -> 'a -> 'b) -> 'a -> unit
val data_to_string : ('a, 'b) t -> string
val print_data : ('a, 'b) t -> unit
val pp_dtype : Stdlib.Format.formatter -> ('a, 'b) Dtype.t -> unit
val dtype_to_string : ('a, 'b) Dtype.t -> string
val shape_to_string : int array -> string
val pp_shape : Stdlib.Format.formatter -> int array -> unit
val pp : Stdlib.Format.formatter -> ('a, 'b) B.t -> unit
val print : ('a, 'b) B.t -> unit
val to_string : ('a, 'b) B.t -> string
val map_item : ('a -> 'a) -> ('a, 'b) B.t -> ('a, 'b) B.t
val iter_item : ('a -> 'b) -> ('a, 'c) B.t -> unit
val fold_item : ('a -> 'b -> 'a) -> 'a -> ('b, 'c) B.t -> 'a
val map : (('a, 'b) B.t -> ('a, 'b) B.t) -> ('a, 'b) B.t -> ('a, 'b) B.t
val iter : (('a, 'b) B.t -> 'c) -> ('a, 'b) B.t -> unit
val fold : ('a -> ('b, 'c) B.t -> 'a) -> 'a -> ('b, 'c) B.t -> 'a
module Infix : sig ... end