Nx_datasets
Load common machine learning datasets and generate synthetic data for testing.
Overview
Nx_datasets provides two categories of data:
- Real datasets: Downloaded and cached locally (MNIST, CIFAR-10, Iris, etc.)
- Synthetic generators: Create data on-the-fly for testing (blobs, moons, regression problems)
Real datasets are automatically downloaded on first use and cached in your platform's cache directory.
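The download-once behavior follows a familiar check-the-cache-then-fetch pattern. The stdlib-only sketch below illustrates that pattern; it is not Nx_datasets' implementation, and the `cached_fetch` function, the `fetch` callback, and the one-file-per-name layout are hypothetical.

```ocaml
(* Hypothetical sketch of a download-once cache; not the Nx_datasets API. *)

(* Read a whole file as a string. *)
let read_file path =
  let ic = open_in_bin path in
  let len = in_channel_length ic in
  let contents = really_input_string ic len in
  close_in ic;
  contents

(* Write [contents] to [path], replacing any existing file. *)
let write_file path contents =
  let oc = open_out_bin path in
  output_string oc contents;
  close_out oc

(* Return the bytes for [name], calling [fetch] (standing in for a real
   download) only when no cached copy exists under [cache_dir]. *)
let cached_fetch ~cache_dir ~name ~fetch =
  let path = Filename.concat cache_dir name in
  if Sys.file_exists path then read_file path
  else begin
    let data = fetch () in
    write_file path data;
    data
  end
```

Calling `cached_fetch` twice with the same `name` invokes `fetch` only once; the second call is served from disk.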
Available Datasets Reference
Real Datasets
Dataset | Function | Samples | Features | Task |
---|---|---|---|---|
MNIST | load_mnist | 70,000 | 28×28×1 | Classification |
Fashion-MNIST | load_fashion_mnist | 70,000 | 28×28×1 | Classification |
CIFAR-10 | load_cifar10 | 60,000 | 32×32×3 | Classification |
Iris | load_iris | 150 | 4 | Classification |
Breast Cancer | load_breast_cancer | 569 | 30 | Classification |
Diabetes | load_diabetes | 442 | 10 | Regression |
California Housing | load_california_housing | 20,640 | 8 | Regression |
Airline Passengers | load_airline_passengers | 144 | 1 | Time Series |
Synthetic Generators
Generator | Function | Purpose | Key Parameters |
---|---|---|---|
Gaussian Blobs | make_blobs | Clustering | centers, cluster_std |
Two Moons | make_moons | Non-linear classification | noise, n_samples |
Concentric Circles | make_circles | Non-linear classification | noise, factor |
Classification | make_classification | Controlled features | n_informative, n_redundant |
Regression | make_regression | Linear relationships | noise, n_features |
Friedman | make_friedman1/2/3 | Non-linear regression | n_samples |
Swiss Roll | make_swiss_roll | Manifold learning | n_samples |
S-Curve | make_s_curve | Manifold learning | n_samples |
Loading Real Datasets
Image Datasets
MNIST
Classic handwritten digits dataset:
let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_mnist () in
Printf.printf "Train: %s, Test: %s\n"
(Nx.shape_to_string x_train)
(Nx.shape_to_string x_test)
(* Train: [60000, 28, 28, 1], Test: [10000, 28, 28, 1] *)
Images are uint8 arrays with pixel values from 0 to 255. Labels are the digit classes 0-9.
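Models usually expect floating-point inputs, so a common first step is scaling 0-255 pixel values into [0, 1]. The sketch below does this on a plain OCaml `int array` standing in for one flattened image; `normalize_pixels` is a hypothetical helper, and with Nx tensors you would express the same cast-and-divide with Nx's own operations.

```ocaml
(* Hypothetical helper: scale uint8 pixel intensities (0-255) into [0, 1].
   A plain int array stands in for one flattened 28x28 image. *)
let normalize_pixels (pixels : int array) : float array =
  Array.map
    (fun p ->
      if p < 0 || p > 255 then invalid_arg "normalize_pixels: value outside 0-255";
      Float.of_int p /. 255.0)
    pixels
```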
Fashion-MNIST
Clothing classification with the same format as MNIST:
let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_fashion_mnist ()
(* 10 classes: T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot *)
CIFAR-10
Color images in 10 categories:
let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_cifar10 ()
(* x_train shape: [50000, 32, 32, 3] *)
(* Classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck *)
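The shape [50000, 32, 32, 3] is channels-last: image, row, column, then RGB channel. Assuming the data sits in a contiguous row-major buffer (the usual layout, but an assumption here), the offset of a single channel value is computed as below; `flat_index` is a hypothetical helper.

```ocaml
(* Offset of element (n, h, w, c) in a contiguous row-major buffer of
   shape [n_images; height; width; channels] (channels-last, like CIFAR-10). *)
let flat_index ~height ~width ~channels ~n ~h ~w ~c =
  (((n * height) + h) * width + w) * channels + c
```

Each image therefore occupies 32 × 32 × 3 = 3072 consecutive values, and neighbouring pixels in a row are 3 values apart.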
Tabular Datasets
Iris
Classic flower classification:
let x, y = Nx_datasets.load_iris ()
(* x shape: [150, 4] - sepal length/width, petal length/width *)
(* y shape: [150, 1] - 0=setosa, 1=versicolor, 2=virginica *)
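Integer class labels such as Iris's 0/1/2 are often one-hot encoded before training a classifier with a softmax output. A minimal stdlib sketch on a plain `int array` of labels rather than an Nx tensor (the `one_hot` name is hypothetical):

```ocaml
(* Hypothetical helper: turn integer class labels into one-hot rows.
   one_hot ~n_classes:3 [|0; 2|] = [| [|1.;0.;0.|]; [|0.;0.;1.|] |] *)
let one_hot ~n_classes (labels : int array) : float array array =
  Array.map
    (fun label ->
      if label < 0 || label >= n_classes then invalid_arg "one_hot: label out of range";
      Array.init n_classes (fun k -> if k = label then 1.0 else 0.0))
    labels
```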
Breast Cancer
Binary classification for cancer diagnosis:
let x, y = Nx_datasets.load_breast_cancer ()
(* x shape: [569, 30] - 30 features per sample *)
(* y shape: [569, 1] - 0=malignant, 1=benign *)
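The 30 Breast Cancer features live on very different scales, so standardizing each column to zero mean and unit variance is a common preprocessing step. This sketch works on a row-major `float array array`; `standardize_columns` is a hypothetical helper, and it uses the population (divide-by-n) standard deviation.

```ocaml
(* Hypothetical helper: z-score each column of a non-empty row-major matrix.
   Columns with zero variance are only centered, to avoid dividing by zero. *)
let standardize_columns (x : float array array) : float array array =
  let n = Array.length x in
  let d = Array.length x.(0) in
  let nf = Float.of_int n in
  let mean =
    Array.init d (fun j ->
        Array.fold_left (fun acc row -> acc +. row.(j)) 0.0 x /. nf)
  in
  let std =
    Array.init d (fun j ->
        let var =
          Array.fold_left
            (fun acc row ->
              let diff = row.(j) -. mean.(j) in
              acc +. (diff *. diff))
            0.0 x
          /. nf
        in
        sqrt var)
  in
  Array.map
    (fun row ->
      Array.mapi
        (fun j v ->
          let centered = v -. mean.(j) in
          if std.(j) > 0.0 then centered /. std.(j) else centered)
        row)
    x
```

In practice the mean and standard deviation should be computed on the training split only and then reused for the test split.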
Regression Datasets
(* Diabetes regression *)
let x, y = Nx_datasets.load_diabetes ()
(* x: [442, 10], y: [442, 1] - diabetes progression *)
(* California housing prices *)
let x, y = Nx_datasets.load_california_housing ()
(* x: [20640, 8], y: [20640, 1] - median house values *)
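These loaders return a single (x, y) pair, so a held-out test set has to be carved out separately. A minimal sketch of one way to do it: shuffle the row indices with a seeded Fisher-Yates shuffle and split them. `train_test_indices` is hypothetical; the resulting index arrays would then be used to select rows from `x` and `y`.

```ocaml
(* Hypothetical helper: shuffle the indices 0..n-1 with a seeded
   Fisher-Yates shuffle, then split off the first [n_test] as the test set. *)
let train_test_indices ~seed ~test_fraction n =
  let st = Random.State.make [| seed |] in
  let idx = Array.init n (fun i -> i) in
  for i = n - 1 downto 1 do
    let j = Random.State.int st (i + 1) in
    let tmp = idx.(i) in
    idx.(i) <- idx.(j);
    idx.(j) <- tmp
  done;
  let n_test = int_of_float (Float.round (test_fraction *. Float.of_int n)) in
  let test = Array.sub idx 0 n_test in
  let train = Array.sub idx n_test (n - n_test) in
  (train, test)
```

For the 442 diabetes samples, `~test_fraction:0.2` yields 88 test and 354 training indices.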
Time Series
let passengers = Nx_datasets.load_airline_passengers ()
(* Monthly airline passenger counts 1949-1960 *)
(* shape: [144] *)
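For forecasting, a univariate series like this is usually reshaped into supervised pairs: each input is a window of consecutive values and the target is the value that follows. A stdlib sketch on a plain `float array` (the `make_windows` name is hypothetical):

```ocaml
(* Hypothetical helper: build (inputs, targets) pairs from a series.
   With window = 3, [|a;b;c;d;e|] gives inputs [|a;b;c|], [|b;c;d|]
   and targets d, e. *)
let make_windows ~window (series : float array) =
  let n = Array.length series - window in
  if n <= 0 then ([||], [||])
  else
    let inputs = Array.init n (fun i -> Array.sub series i window) in
    let targets = Array.init n (fun i -> series.(i + window)) in
    (inputs, targets)
```

With 144 monthly values and a 12-month window, this produces 132 training pairs.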
Generating Synthetic Data
Classification Datasets
Gaussian Blobs
Generate isotropic Gaussian blobs for clustering:
let x, y = Nx_datasets.make_blobs
~n_samples:300
~centers:(`N 3)
~cluster_std:0.5
()
(* 3 well-separated clusters *)
To place the clusters at exact positions, pass the centers as a tensor with one row per center:
let centers = Nx.of_array Nx.float32 ~shape:[|3; 2|]
[|-10.; -10.; 0.; 0.; 10.; 10.|]
let x, y = Nx_datasets.make_blobs ~centers:(`Array centers) ()
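Conceptually, each blob is an isotropic Gaussian: a point is its center plus independent normal noise with standard deviation `cluster_std` in every dimension. The stdlib sketch below draws normals with the Box-Muller transform and assigns points to centers round-robin. It illustrates the idea only; `sample_blobs` and the round-robin assignment are assumptions, not Nx_datasets' implementation.

```ocaml
(* Uniform draw strictly greater than 0, so that [log] below stays finite. *)
let rec positive_uniform st =
  let u = Random.State.float st 1.0 in
  if u > 0.0 then u else positive_uniform st

(* One standard normal sample via the Box-Muller transform. *)
let standard_normal st =
  let u1 = positive_uniform st in
  let u2 = Random.State.float st 1.0 in
  sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

(* Hypothetical sketch: sample [n_samples] points around [centers]
   (one row per center), returning the points and their center labels. *)
let sample_blobs ~seed ~n_samples ~cluster_std (centers : float array array) =
  let st = Random.State.make [| seed |] in
  let k = Array.length centers in
  let labels = Array.init n_samples (fun i -> i mod k) in
  let points =
    Array.map
      (fun label ->
        Array.map (fun c -> c +. (cluster_std *. standard_normal st)) centers.(label))
      labels
  in
  (points, labels)
```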
Two Moons
Binary classification with interleaving half circles:
let x, y = Nx_datasets.make_moons
~n_samples:200
~noise:0.1
()
(* Ideal for testing non-linear classifiers *)
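The two half circles have a simple parametric form. This noise-free sketch follows the construction used by scikit-learn's `make_moons` (whether Nx_datasets matches it exactly, for example in sample ordering, is an assumption): the upper moon is (cos t, sin t) and the lower moon is (1 - cos t, 0.5 - sin t) for t in [0, π]. A `noise` parameter would add Gaussian jitter to every coordinate; `two_moons` is a hypothetical name.

```ocaml
(* Hypothetical noise-free two-moons sketch:
   label 0 = upper moon, label 1 = lower, shifted moon. *)
let two_moons n_samples =
  let n_out = n_samples / 2 in
  let n_in = n_samples - n_out in
  let angle i n =
    if n <= 1 then 0.0 else Float.pi *. Float.of_int i /. Float.of_int (n - 1)
  in
  let outer = Array.init n_out (fun i -> let t = angle i n_out in [| cos t; sin t |]) in
  let inner =
    Array.init n_in (fun i ->
        let t = angle i n_in in
        [| 1.0 -. cos t; 0.5 -. sin t |])
  in
  let points = Array.append outer inner in
  let labels = Array.init n_samples (fun i -> if i < n_out then 0 else 1) in
  (points, labels)
```

No straight line separates the two classes, which is why this dataset is a quick check for non-linear decision boundaries.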
Concentric Circles
Nested circles for non-linear separation:
let x, y = Nx_datasets.make_circles
~n_samples:200
~noise:0.05
~factor:0.5 (* Inner circle radius ratio *)
()
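The role of `factor` is easiest to see in a noise-free construction. The sketch below mirrors scikit-learn's `make_circles` (assumed, not confirmed, to match Nx_datasets): evenly spaced points on a unit outer circle and on an inner circle of radius `factor`. `concentric_circles` is a hypothetical name.

```ocaml
(* Hypothetical noise-free sketch: an outer unit circle (label 0) and an
   inner circle of radius [factor] (label 1), with evenly spaced points. *)
let concentric_circles ~factor n_samples =
  let n_out = n_samples / 2 in
  let n_in = n_samples - n_out in
  let ring n radius =
    Array.init n (fun i ->
        let t = 2.0 *. Float.pi *. Float.of_int i /. Float.of_int n in
        [| radius *. cos t; radius *. sin t |])
  in
  let points = Array.append (ring n_out 1.0) (ring n_in factor) in
  let labels = Array.init n_samples (fun i -> if i < n_out then 0 else 1) in
  (points, labels)
```

A factor close to 1 makes the circles nearly touch, which makes the problem harder.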
Complex Classification
Control how many features are informative and how many are redundant:
let x, y = Nx_datasets.make_classification
~n_samples:1000
~n_features:20
~n_informative:15 (* Useful features *)
~n_redundant:5 (* Linear combinations *)
~n_classes:3
~n_clusters_per_class:2
()
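Redundant features are linear combinations of the informative ones, so they add no new information. The sketch below shows only that step for a single sample: each redundant feature is a weighted sum of the informative values under a mixing matrix B. A generator would draw B at random; here it is fixed for clarity, and `redundant_features` is a hypothetical name.

```ocaml
(* Hypothetical sketch: redundant.(k) = sum over j of informative.(j) *. b.(j).(k),
   i.e. the row vector [informative] times an (n_informative x n_redundant)
   matrix [b] (assumed non-empty). *)
let redundant_features (informative : float array) (b : float array array) =
  let n_redundant = Array.length b.(0) in
  Array.init n_redundant (fun k ->
      let acc = ref 0.0 in
      Array.iteri (fun j v -> acc := !acc +. (v *. b.(j).(k))) informative;
      !acc)
```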
Regression Datasets
Linear Regression
Generate a linear regression problem with a chosen number of informative features and noise level:
let x, y, coef_opt = Nx_datasets.make_regression
~n_samples:100
~n_features:5
~n_informative:3 (* Only 3 features affect output *)
~noise:10.0 (* Gaussian noise std dev *)
~coef:true (* Return true coefficients *)
()
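With `~n_informative:3`, only three of the five features influence the target. A minimal sketch of that generative model, y = X·coef + noise with the non-informative coefficients set to zero. The coefficient values and the `linear_target` name are hypothetical; a real generator would draw the coefficients itself, and `~coef:true` returns them.

```ocaml
(* Hypothetical sketch: targets of a linear model; features whose
   coefficient is 0.0 have no effect. [noise] draws one noise term per row. *)
let linear_target ~coef ~noise (x : float array array) =
  Array.map
    (fun row ->
      let dot = ref 0.0 in
      Array.iteri (fun j v -> dot := !dot +. (v *. coef.(j))) row;
      !dot +. noise ())
    x
```

For example, with `coef = [|3.0; -2.0; 0.5; 0.0; 0.0|]` the last two columns can hold any values without changing `y`.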
Friedman Benchmarks
Standard non-linear regression problems:
(* Friedman #1: y = 10*sin(π*x1*x2) + 20*(x3-0.5)² + 10*x4 + 5*x5 + noise *)
let x, y = Nx_datasets.make_friedman1 ~n_samples:100 ()
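The Friedman #1 target in the comment can be evaluated directly. A noise-free sketch in plain OCaml (`friedman1_target` is a hypothetical name). In the standard benchmark the inputs are drawn uniformly from [0, 1], and any features beyond the fifth do not affect the target.

```ocaml
(* Noise-free Friedman #1 target; only x.(0)..x.(4) are used. *)
let friedman1_target (x : float array) =
  (10.0 *. sin (Float.pi *. x.(0) *. x.(1)))
  +. (20.0 *. ((x.(2) -. 0.5) ** 2.0))
  +. (10.0 *. x.(3))
  +. (5.0 *. x.(4))
```

At x = (0.5, 0.5, 0.5, 0.5, 0.5) this gives 10·sin(π/4) + 0 + 5 + 2.5 ≈ 14.57.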
Manifold Data
Swiss Roll
3D manifold for dimensionality reduction:
let x, color = Nx_datasets.make_swiss_roll ~n_samples:1000 ()
(* x shape: [1000, 3], color: [1000] - position along roll *)
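For intuition, here is the usual Swiss-roll parametrization, as used by scikit-learn's `make_swiss_roll` (that Nx_datasets uses the same constants is an assumption). Draw t uniformly from [1.5π, 4.5π] and a height h from [0, 21], and the point is (t cos t, h, t sin t). The position t is what the `color` output encodes. `swiss_roll` is a hypothetical, noise-free sampler.

```ocaml
(* Hypothetical noise-free Swiss roll sampler: returns the 3D points and
   the position t along the roll, which plays the role of [color]. *)
let swiss_roll ~seed n_samples =
  let st = Random.State.make [| seed |] in
  let ts =
    Array.init n_samples (fun _ ->
        1.5 *. Float.pi *. (1.0 +. (2.0 *. Random.State.float st 1.0)))
  in
  let points =
    Array.map
      (fun t ->
        let h = 21.0 *. Random.State.float st 1.0 in
        [| t *. cos t; h; t *. sin t |])
      ts
  in
  (points, ts)
```

Because each point's distance from the roll's axis equals t, a good embedding should "unroll" the data so that t varies smoothly.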
S-Curve
An S-shaped 3D manifold, also used for testing dimensionality reduction:
let x, color = Nx_datasets.make_s_curve ~n_samples:1000 ()
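The S-curve has a similar parametrization. Following scikit-learn's `make_s_curve` as an assumed reference, t is uniform in [-1.5π, 1.5π], h is uniform in [0, 2], and the point is (sin t, h, sign(t)·(cos t - 1)). `s_curve` is a hypothetical, noise-free sampler that also returns t.

```ocaml
(* Hypothetical noise-free S-curve sampler; t is the position along the S. *)
let s_curve ~seed n_samples =
  let st = Random.State.make [| seed |] in
  let ts =
    Array.init n_samples (fun _ ->
        3.0 *. Float.pi *. (Random.State.float st 1.0 -. 0.5))
  in
  let points =
    Array.map
      (fun t ->
        let h = 2.0 *. Random.State.float st 1.0 in
        let sign = if t > 0.0 then 1.0 else if t < 0.0 then -1.0 else 0.0 in
        [| sin t; h; sign *. (cos t -. 1.0) |])
      ts
  in
  (points, ts)
```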