Module Ptree.Tensor

dtype t is t's dtype.

val shape : tensor -> int array

shape t is t's shape.

val numel : tensor -> int

numel t is the number of elements in t.
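For a dense tensor, the element count is the product of the dimensions reported by shape. The sketch below states that relationship in plain OCaml. It assumes shape lists every dimension with no padding or strides, and numel_of_shape is a hypothetical helper, not part of Ptree.Tensor.

```ocaml
(* Hypothetical helper, not part of Ptree.Tensor: the element count of a
   dense tensor is the product of its dimensions. An empty shape (a
   scalar) yields 1; any zero-sized dimension yields 0. *)
let numel_of_shape (shape : int array) : int =
  Array.fold_left ( * ) 1 shape
```

So a tensor of shape [|2; 3; 4|] would have 24 elements, and a scalar (shape [||]) would have 1.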

val to_typed : ('a, 'l) Nx.dtype -> tensor -> ('a, 'l) Nx.t option

to_typed dtype t is Some x if t's dtype is dtype, where x is t viewed as a typed tensor, and None otherwise.
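A common way to implement a checked downcast like to_typed is a GADT dtype witness combined with a type-equality proof. The sketch below models that idea only. The dtype, typed, tensor and eq types, the F32/I32 constructors and dtype_equal are hypothetical stand-ins, not the actual Nx or Ptree definitions, and the real implementation may work differently.

```ocaml
(* Hypothetical stand-ins for ('a, 'l) Nx.dtype, ('a, 'l) Nx.t and tensor.
   These are NOT the Nx or Ptree definitions; they only model the idea. *)
type float32_elt = F32_elt
type int32_elt = I32_elt

(* A dtype is a GADT witness tying the OCaml element type 'a to a tag 'l. *)
type (_, _) dtype =
  | F32 : (float, float32_elt) dtype
  | I32 : (int32, int32_elt) dtype

type ('a, 'l) typed = { dtype : ('a, 'l) dtype; data : 'a array }

(* A dtype-erased tensor: the element type is hidden existentially. *)
type tensor = T : ('a, 'l) typed -> tensor

(* Type-equality witness. *)
type (_, _) eq = Refl : ('a, 'a) eq

(* Compare two dtype witnesses; on a match, recover a type equality. *)
let dtype_equal :
    type a l b m. (a, l) dtype -> (b, m) dtype -> (a * l, b * m) eq option =
 fun d1 d2 ->
  match (d1, d2) with
  | F32, F32 -> Some Refl
  | I32, I32 -> Some Refl
  | _ -> None

(* Some x when the hidden dtype equals the requested one, None otherwise. *)
let to_typed : type a l. (a, l) dtype -> tensor -> (a, l) typed option =
 fun d t ->
  match t with
  | T x -> (
      match dtype_equal d x.dtype with
      | Some Refl -> Some x
      | None -> None)
```

Matching on Some Refl is what lets the type checker accept x at the caller's requested type; without the equality proof, the hidden element type could not escape the T constructor.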

val to_typed_exn : ('a, 'l) Nx.dtype -> tensor -> ('a, 'l) Nx.t

to_typed_exn dtype t is the typed tensor in t.

Raises Invalid_argument if t's dtype is not dtype.
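As documented, to_typed_exn behaves like to_typed with the None case turned into Invalid_argument. The sketch below shows that wrapping pattern generically. exn_of_option, the toy record, toy_to_float32 and the error message are hypothetical illustrations, not the Ptree API.

```ocaml
(* Hypothetical helper: turn an option-returning accessor into one that
   raises Invalid_argument on None, as to_typed_exn does on mismatch. *)
let exn_of_option (name : string) (lookup : 'a -> 'b option) (x : 'a) : 'b =
  match lookup x with
  | Some v -> v
  | None -> invalid_arg (name ^ ": dtype mismatch")

(* Toy stand-in for a dtype-erased tensor: a dtype tag plus data. *)
type toy = { dtype_name : string; data : float array }

(* Toy analogue of (to_typed float32): succeeds only for the float32 tag. *)
let toy_to_float32 (t : toy) : float array option =
  if t.dtype_name = "float32" then Some t.data else None

let toy_to_float32_exn = exn_of_option "to_typed_exn" toy_to_float32
```

Callers that can recover from a mismatch would typically prefer the option-returning form; the _exn form suits code where a wrong dtype is a programming error.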