Diffstat (limited to 'tensorflow/contrib/learn/python/learn/datasets/__init__.py')
-rw-r--r--  tensorflow/contrib/learn/python/learn/datasets/__init__.py | 41
1 file changed, 40 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py
index 80a0af5f52..a3521b4109 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
-"""Module includes reference datasets and utilities to load datasets."""
+"""Dataset utilities and synthetic/reference datasets."""
from __future__ import absolute_import
from __future__ import division
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.datasets import mnist
+from tensorflow.contrib.learn.python.learn.datasets import synthetic
from tensorflow.contrib.learn.python.learn.datasets import text_datasets
# Export load_iris and load_boston.
@@ -43,6 +44,12 @@ DATASETS = {
'dbpedia': text_datasets.load_dbpedia,
}
+# Registry of synthetic dataset generators.
+SYNTHETIC = {
+  # Each generator returns ['data', 'target'] as a `base.Dataset`.
+  'circles': synthetic.circles,
+  'spirals': synthetic.spirals
+}
def load_dataset(name, size='small', test_with_fake_data=False):
"""Loads dataset by name.
@@ -64,3 +71,35 @@ def load_dataset(name, size='small', test_with_fake_data=False):
return DATASETS[name](size, test_with_fake_data)
else:
return DATASETS[name]()
+
+
+def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs):
+  """Creates a binary synthetic dataset.
+
+  Args:
+    name: str, name of the dataset to generate.
+    n_samples: int, number of data points to generate.
+    noise: float or None, standard deviation of the Gaussian noise added.
+    seed: int or None, seed for the noise generator.
+
+  Returns:
+    Shuffled features and labels for the given synthetic dataset, as a
+    `base.Dataset`.
+
+  Raises:
+    ValueError: If `name` is not a known synthetic dataset.
+
+  Note:
+    - This is a generic wrapper around the synthetic generators; individual
+      generators may accept additional parameters (see their documentation).
+    - The `noise` parameter uses `numpy.random.normal`, so results also
+      depend on numpy's global random seed.
+
+  TODO:
+    - Support multiclass datasets.
+    - Add a shuffling routine. Synthetic datasets are currently reshuffled
+      to avoid train/test correlation, which hurts reproducibility.
+  """
+  if name not in SYNTHETIC:
+    raise ValueError('Synthetic dataset not found or not implemented: %s' % name)
+  return SYNTHETIC[name](*args, n_samples=n_samples, noise=noise, seed=seed, **kwargs)
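
Similarly, a sketch of the new `make_dataset` entry point added by this diff. The keyword values below are illustrative; `synthetic.circles` is assumed to accept the forwarded `n_samples`, `noise`, and `seed` arguments, and per the SYNTHETIC registry comment it returns a `base.Dataset`.

    from tensorflow.contrib.learn.python.learn import datasets

    # Two noisy concentric circles with binary labels; the call is
    # forwarded to synthetic.circles via the SYNTHETIC registry.
    circles = datasets.make_dataset('circles', n_samples=200, noise=0.1, seed=42)
    features, labels = circles.data, circles.target

    # Unregistered names raise ValueError:
    # datasets.make_dataset('moons')
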