diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/datasets/__init__.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/datasets/__init__.py | 41 |
1 files changed, 40 insertions, 1 deletions
diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py index 80a0af5f52..a3521b4109 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py +++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""Module includes reference datasets and utilities to load datasets.""" +"""Dataset utilities and synthetic/reference datasets.""" from __future__ import absolute_import from __future__ import division @@ -26,6 +26,7 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.datasets import mnist +from tensorflow.contrib.learn.python.learn.datasets import synthetic from tensorflow.contrib.learn.python.learn.datasets import text_datasets # Export load_iris and load_boston. @@ -43,6 +44,12 @@ DATASETS = { 'dbpedia': text_datasets.load_dbpedia, } +# List of all synthetic datasets +SYNTHETIC = { + # All of these will return ['data', 'target'] -> base.Dataset + 'circles': synthetic.circles, + 'spirals': synthetic.spirals +} def load_dataset(name, size='small', test_with_fake_data=False): """Loads dataset by name. 
@@ -64,3 +71,35 @@ def load_dataset(name, size='small', test_with_fake_data=False): return DATASETS[name](size, test_with_fake_data) else: return DATASETS[name]() + + +def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs): + """Creates binary synthetic datasets. + + Args: + name: str, name of the dataset to generate + n_samples: int, number of datapoints to generate + noise: float or None, standard deviation of the Gaussian noise added + seed: int or None, seed for noise + + Returns: + Shuffled features and labels for given synthetic dataset of type `base.Dataset` + + Raises: + ValueError: Raised if `name` not found + + Note: + - This is a generic synthetic data generator - individual generators might have more parameters! + See documentation for individual parameters + - Note that the `noise` parameter uses `numpy.random.normal` and depends on `numpy`'s seed + + TODO: + - Support multiclass datasets + - Need shuffling routine. Currently synthetic datasets are reshuffled to avoid train/test correlation, + but that hurts reproducibility + """ + # seed = kwargs.pop('seed', None) + if name not in SYNTHETIC: + raise ValueError('Synthetic dataset not found or not implemented: %s' % name) + else: + return SYNTHETIC[name](n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs) |