Diffstat (limited to 'tensorflow/contrib/learn/python')
-rw-r--r--  tensorflow/contrib/learn/python/learn/datasets/base.py               16
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py        70
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator_test.py   86
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/linear.py            3
-rw-r--r--  tensorflow/contrib/learn/python/learn/evaluable.py                   10
-rw-r--r--  tensorflow/contrib/learn/python/learn/graph_actions.py                2
-rw-r--r--  tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py       497
-rw-r--r--  tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py  250
-rw-r--r--  tensorflow/contrib/learn/python/learn/trainable.py                   17
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/__init__.py               1
10 files changed, 639 insertions, 313 deletions
diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py
index cdff6baf83..71978d4394 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/base.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/base.py
@@ -186,8 +186,8 @@ def _is_retriable(e):
@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
-def urlretrieve_with_retry(url, filename):
- urllib.request.urlretrieve(url, filename)
+def urlretrieve_with_retry(url, filename=None):
+ return urllib.request.urlretrieve(url, filename)
def maybe_download(filename, work_directory, source_url):
@@ -205,11 +205,9 @@ def maybe_download(filename, work_directory, source_url):
gfile.MakeDirs(work_directory)
filepath = os.path.join(work_directory, filename)
if not gfile.Exists(filepath):
- with tempfile.NamedTemporaryFile() as tmpfile:
- temp_file_name = tmpfile.name
- urlretrieve_with_retry(source_url, temp_file_name)
- gfile.Copy(temp_file_name, filepath)
- with gfile.GFile(filepath) as f:
- size = f.size()
- print('Successfully downloaded', filename, size, 'bytes.')
+ temp_file_name, _ = urlretrieve_with_retry(source_url)
+ gfile.Copy(temp_file_name, filepath)
+ with gfile.GFile(filepath) as f:
+ size = f.size()
+ print('Successfully downloaded', filename, size, 'bytes.')
return filepath
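
Reviewer note: with this change, urlretrieve_with_retry defers to urllib's own temporary file when filename is None and returns its (path, headers) tuple, which maybe_download then copies into place. A minimal usage sketch, assuming a placeholder URL and work directory (not taken from the diff):

from tensorflow.contrib.learn.python.learn.datasets import base

# maybe_download returns early if the file already exists under work_directory;
# otherwise it downloads with retries and copies the temp file to the final path.
csv_path = base.maybe_download(
    filename='example.csv',
    work_directory='/tmp/learn_datasets',
    source_url='http://example.com/example.csv')
print(csv_path)  # -> /tmp/learn_datasets/example.csv
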
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 91d900395b..2ec5a0659a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -330,8 +330,8 @@ class BaseEstimator(
# Features and labels TensorSignature objects.
# TODO(wicke): Rename these to something more descriptive
- self._features_info = None
- self._labels_info = None
+ self._features_info = {}
+ self._labels_info = {}
self._graph = None
@@ -641,28 +641,29 @@ class BaseEstimator(
return tensor_signature.create_example_parser_from_signatures(
self._features_info, examples_batch)
- def _check_inputs(self, features, labels):
- if self._features_info is not None:
- logging.debug('Given features: %s, required signatures: %s.',
- str(features), str(self._features_info))
- if not tensor_signature.tensors_compatible(features, self._features_info):
- raise ValueError('Features are incompatible with given information. '
+ def _check_inputs(self, features, labels, mode):
+ if mode in self._features_info:
+ logging.debug('Given features for mode %s: %s, required signatures: %s.',
+ mode, str(features), str(self._features_info[mode]))
+
+ if not tensor_signature.tensors_compatible(features, self._features_info[mode]):
+ raise ValueError('Features for mode %s are incompatible with given information. '
'Given features: %s, required signatures: %s.' %
- (str(features), str(self._features_info)))
+ (mode, str(features), str(self._features_info[mode])))
else:
- self._features_info = tensor_signature.create_signatures(features)
- logging.debug('Setting feature info to %s.', str(self._features_info))
+ self._features_info[mode] = tensor_signature.create_signatures(features)
+ logging.debug('Setting feature info for mode %s to %s.', mode, str(self._features_info[mode]))
if labels is not None:
- if self._labels_info is not None:
+ if mode in self._labels_info:
logging.debug('Given labels: %s, required signatures: %s.',
str(labels), str(self._labels_info))
- if not tensor_signature.tensors_compatible(labels, self._labels_info):
- raise ValueError('Labels are incompatible with given information. '
+ if not tensor_signature.tensors_compatible(labels, self._labels_info[mode]):
+ raise ValueError('Labels for mode %s are incompatible with given information. '
'Given labels: %s, required signatures: %s.' %
- (str(labels), str(self._labels_info)))
+ (mode, str(labels), str(self._labels_info[mode])))
else:
- self._labels_info = tensor_signature.create_signatures(labels)
- logging.debug('Setting labels info to %s', str(self._labels_info))
+ self._labels_info[mode] = tensor_signature.create_signatures(labels)
+ logging.debug('Setting labels info for mode %s to %s', mode, str(self._labels_info[mode]))
def _train_model(self,
input_fn,
@@ -699,8 +700,7 @@ class BaseEstimator(
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
- self._check_inputs(features, labels)
-
+ self._check_inputs(features, labels, model_fn_lib.ModeKeys.TRAIN)
# The default return type of _get_train_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_train_ops returns a
@@ -800,8 +800,7 @@ class BaseEstimator(
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
- self._check_inputs(features, labels)
-
+ self._check_inputs(features, labels, model_fn_lib.ModeKeys.EVAL)
# The default return type of _get_eval_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_eval_ops returns an
@@ -835,6 +834,29 @@ class BaseEstimator(
return result[0]
return result
+ def _set_infer_mode_feature_signature(self, features):
+ for mode in list(self._features_info.keys()):
+ if tensor_signature.tensors_compatible(features, self._features_info[mode]):
+ self._features_info[model_fn_lib.ModeKeys.INFER] = self._features_info[mode]
+ if mode in self._labels_info:
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = (
+ self._labels_info[mode])
+ else:
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = None
+ break
+
+ if model_fn_lib.ModeKeys.INFER not in self._features_info:
+ logging.warning('Features for mode %s are incompatible with neither train mode nor eval mode.'
+ ' Given features: %s' % (model_fn_lib.ModeKeys.INFER, str(features)))
+ for mode in list(self._features_info.keys()):
+ logging.warning('Whereas %s mode signatures: %s' % (mode, str(self._features_info[mode])))
+ self._check_inputs(features, None, model_fn_lib.ModeKeys.INFER)
+ if model_fn_lib.ModeKeys.TRAIN in self._labels_info:
+ logging.warning('Setting labels info for mode infer equal to that of labels info for train mode')
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = self._labels_info[model_fn_lib.ModeKeys.TRAIN]
+ else:
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = {}
+
def _infer_model(
self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
# Check that model has been trained.
@@ -1134,8 +1156,10 @@ class Estimator(BaseEstimator):
Returns:
`ModelFnOps` object.
"""
+
+ self._set_infer_mode_feature_signature(features)
labels = tensor_signature.create_placeholders_from_signatures(
- self._labels_info)
+ self._labels_info[model_fn_lib.ModeKeys.INFER])
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)
@experimental
@@ -1239,7 +1263,7 @@ class Estimator(BaseEstimator):
return export_dir
-# For time of deprecation x,y from Estimator allow direct access.
+# For time of deprecation x,y from Estimator allow direct access
# pylint: disable=protected-access
class SKCompat(sklearn.BaseEstimator):
"""Scikit learn wrapper for TensorFlow Learn Estimator."""
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 5ebc299b57..3405005327 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -91,7 +91,18 @@ def boston_eval_fn():
0)
+def extract(data, key):
+ if isinstance(data, dict):
+ assert key in data
+ return data[key]
+ else:
+ return data
+
+
def linear_model_params_fn(features, labels, mode, params):
+ features = extract(features, 'input')
+ labels = extract(labels, 'labels')
+
assert mode in (
tf.contrib.learn.ModeKeys.TRAIN,
tf.contrib.learn.ModeKeys.EVAL,
@@ -106,6 +117,8 @@ def linear_model_params_fn(features, labels, mode, params):
def linear_model_fn(features, labels, mode):
+ features = extract(features, 'input')
+ labels = extract(labels, 'labels')
assert mode in (
tf.contrib.learn.ModeKeys.TRAIN,
tf.contrib.learn.ModeKeys.EVAL,
@@ -140,8 +153,8 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode):
def logistic_model_no_mode_fn(features, labels):
- if isinstance(labels, dict):
- labels = labels['labels']
+ features = extract(features, 'input')
+ labels = extract(labels, 'labels')
labels = tf.one_hot(labels, 3, 1, 0)
prediction, loss = (
tf.contrib.learn.models.logistic_regression_zero_init(features, labels)
@@ -346,6 +359,34 @@ class EstimatorTest(tf.test.TestCase):
with self.assertRaises(tf.contrib.learn.NotFittedError):
est.predict(x=boston.data)
+ def testContinueTrainingDictionaryInput(self):
+ boston = tf.contrib.learn.datasets.load_boston()
+ output_dir = tempfile.mkdtemp()
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
+ model_dir=output_dir)
+ boston_input = {'input': boston.data}
+ float64_target = {'labels': boston.target.astype(np.float64)}
+ est.fit(x=boston_input, y=float64_target, steps=50)
+ scores = est.evaluate(
+ x=boston_input,
+ y=float64_target,
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ del est
+ # Create another estimator object with the same output dir.
+ est2 = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
+ model_dir=output_dir)
+
+ # Check we can evaluate and predict.
+ scores2 = est2.evaluate(
+ x=boston_input,
+ y=float64_target,
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ self.assertAllClose(scores2['MSE'],
+ scores['MSE'])
+ predictions = np.array(list(est2.predict(x=boston_input)))
+ other_score = _sklearn.mean_squared_error(predictions, float64_target['labels'])
+ self.assertAllClose(other_score, scores['MSE'])
+
def testContinueTraining(self):
boston = tf.contrib.learn.datasets.load_boston()
output_dir = tempfile.mkdtemp()
@@ -405,6 +446,22 @@ class EstimatorTest(tf.test.TestCase):
self.assertTrue('global_step' in scores)
self.assertEqual(100, scores['global_step'])
+ def testBostonAllDictionaryInput(self):
+ boston = tf.contrib.learn.datasets.load_boston()
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
+ boston_input = {'input': boston.data}
+ float64_target = {'labels': boston.target.astype(np.float64)}
+ est.fit(x=boston_input, y=float64_target, steps=100)
+ scores = est.evaluate(
+ x=boston_input,
+ y=float64_target,
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ predictions = np.array(list(est.predict(x=boston_input)))
+ other_score = _sklearn.mean_squared_error(predictions, boston.target)
+ self.assertAllClose(other_score, scores['MSE'])
+ self.assertTrue('global_step' in scores)
+ self.assertEqual(scores['global_step'], 100)
+
def testIrisAll(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.SKCompat(
@@ -428,6 +485,31 @@ class EstimatorTest(tf.test.TestCase):
self.assertTrue('global_step' in scores)
self.assertEqual(100, scores['global_step'])
+ def testIrisAllDictionaryInput(self):
+ iris = tf.contrib.learn.datasets.load_iris()
+ est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
+ iris_data = {'input': iris.data}
+ iris_target = {'labels': iris.target}
+ est.fit(iris_data, iris_target, steps=100)
+ scores = est.evaluate(
+ x=iris_data,
+ y=iris_target,
+ metrics={('accuracy', 'class'): tf.contrib.metrics.streaming_accuracy})
+ predictions = list(est.predict(x=iris_data))
+ predictions_class = list(est.predict(x=iris_data, outputs=['class']))
+ self.assertEqual(len(predictions), iris.target.shape[0])
+ classes_batch = np.array([p['class'] for p in predictions])
+ self.assertAllClose(
+ classes_batch,
+ np.array([p['class'] for p in predictions_class]))
+ self.assertAllClose(
+ classes_batch,
+ np.argmax(np.array([p['prob'] for p in predictions]), axis=1))
+ other_score = _sklearn.accuracy_score(iris.target, classes_batch)
+ self.assertAllClose(other_score, scores['accuracy'])
+ self.assertTrue('global_step' in scores)
+ self.assertEqual(scores['global_step'], 100)
+
def testIrisInputFn(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
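
Reviewer note: condensed from the new tests above, dictionary-valued x and y now flow through fit/evaluate/predict as long as the model_fn unpacks them by key (linear_model_fn here is the test helper defined earlier in this file):

import numpy as np
import tensorflow as tf

boston = tf.contrib.learn.datasets.load_boston()
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
boston_input = {'input': boston.data}
float64_target = {'labels': boston.target.astype(np.float64)}
est.fit(x=boston_input, y=float64_target, steps=1)
scores = est.evaluate(
    x=boston_input,
    y=float64_target,
    metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
predictions = np.array(list(est.predict(x=boston_input)))
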
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 45e430717f..c4a257b8d4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -660,9 +660,10 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
"""
self._feature_columns = feature_columns
assert self._feature_columns
- self._optimizer = _get_default_optimizer(feature_columns)
if optimizer:
self._optimizer = _get_optimizer(optimizer)
+ else:
+ self._optimizer = _get_default_optimizer(feature_columns)
chief_hook = None
if (isinstance(optimizer, sdca_optimizer.SDCAOptimizer) and
diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py
index 14cf5f01b8..aff0d70cd5 100644
--- a/tensorflow/contrib/learn/python/learn/evaluable.py
+++ b/tensorflow/contrib/learn/python/learn/evaluable.py
@@ -51,12 +51,14 @@ class Evaluable(object):
for which this evaluation was performed.
Args:
- x: Matrix of shape [n_samples, n_features...] containing the input samples
- for fitting the model. Can be iterator that returns arrays of features.
- If set, `input_fn` must be `None`.
+ x: Matrix of shape [n_samples, n_features...] or dictionary of many matrices
+ containing the input samples for fitting the model. Can be iterator that returns
+ arrays of features or dictionary of array of features. If set, `input_fn` must
+ be `None`.
y: Vector or matrix [n_samples] or [n_samples, n_outputs] containing the
label values (class labels in classification, real numbers in
- regression). Can be iterator that returns array of labels. If set,
+ regression) or dictionary of multiple vectors/matrices. Can be iterator
+ that returns array of targets or dictionary of array of targets. If set,
`input_fn` must be `None`. Note: For classification, label values must
be integers representing the class index (i.e. values from 0 to
n_classes-1).
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index 5781d88bb8..55be25336e 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -299,10 +299,10 @@ def _monitored_train(graph,
while not super_sess.should_stop():
_, loss = super_sess.run([train_op, loss_op], feed_fn() if feed_fn else
None)
+
summary_io.SummaryWriterCache.clear()
return loss
-
# TODO(ispir): Deprecate train in favor of supervised_train
def train(graph,
output_dir,
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
index 2ce11e813f..f665ff7644 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
@@ -36,27 +36,49 @@ from tensorflow.python.platform import tf_logging as logging
# pylint: disable=g-multiple-import,g-bad-import-order
from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
+
+
# pylint: enable=g-multiple-import,g-bad-import-order
def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
"""Returns shape for input and output of the data feeder."""
+ x_is_dict, y_is_dict = isinstance(x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
+ if y_is_dict and n_classes is not None:
+ assert (isinstance(n_classes, dict))
+
if batch_size is None:
- batch_size = x_shape[0]
+ batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
elif batch_size <= 0:
raise ValueError('Invalid batch_size %d.' % batch_size)
- x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
- input_shape = [batch_size] + x_shape
+
+ if x_is_dict:
+ input_shape = {}
+ for k, v in list(x_shape.items()):
+ input_shape[k] = [batch_size] + (list(v[1:]) if len(v) > 1 else [1])
+ else:
+ x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
+ input_shape = [batch_size] + x_shape
+
if y_shape is None:
return input_shape, None, batch_size
- y_shape = list(y_shape[1:]) if len(y_shape) > 1 else []
- # Skip first dimension if it is 1.
- if y_shape and y_shape[0] == 1:
- y_shape = y_shape[1:]
- if n_classes is not None and n_classes > 1:
- output_shape = [batch_size] + y_shape + [n_classes]
+
+ def out_el_shape(out_shape, num_classes):
+ out_shape = list(out_shape[1:]) if len(out_shape) > 1 else []
+ # Skip first dimension if it is 1.
+ if out_shape and out_shape[0] == 1:
+ out_shape = out_shape[1:]
+ if num_classes is not None and num_classes > 1:
+ return [batch_size] + out_shape + [num_classes]
+ else:
+ return [batch_size] + out_shape
+
+ if not y_is_dict:
+ output_shape = out_el_shape(y_shape, n_classes)
else:
- output_shape = [batch_size] + y_shape
+ output_shape = dict([(k, out_el_shape(v, n_classes[k] if n_classes is not None and k in n_classes else None))
+ for k, v in list(y_shape.items())])
+
return input_shape, output_shape, batch_size
@@ -78,15 +100,18 @@ def _is_iterable(x):
def setup_train_data_feeder(
- x, y, n_classes, batch_size=None, shuffle=True, epochs=None):
+ x, y, n_classes, batch_size=None, shuffle=True, epochs=None):
"""Create data feeder, to sample inputs from dataset.
If `x` and `y` are iterators, use `StreamingDataFeeder`.
Args:
- x: numpy, pandas or Dask matrix or iterable.
- y: numpy, pandas or Dask array or iterable.
- n_classes: number of classes.
+ x: numpy, pandas or Dask matrix or dictionary of aforementioned. Also
+ supports iterables.
+ y: numpy, pandas or Dask array or dictionary of aforementioned. Also supports
+ iterables.
+ n_classes: number of classes. Must be None or same type as y. In case, `y` is `dict`
+ (or iterable which returns dict) such that `n_classes[key] = n_classes for y[key]`
batch_size: size to split data into parts. Must be >= 1.
shuffle: Whether to shuffle the inputs.
epochs: Number of epochs to run.
@@ -102,7 +127,7 @@ def setup_train_data_feeder(
# pylint: disable=g-import-not-at-top
import dask.dataframe as dd
if (isinstance(x, (dd.Series, dd.DataFrame)) and
- (y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
+ (y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
data_feeder_cls = DaskDataFeeder
else:
data_feeder_cls = DataFeeder
@@ -115,31 +140,54 @@ def setup_train_data_feeder(
'streaming learning to work.')
return StreamingDataFeeder(x, y, n_classes, batch_size)
return data_feeder_cls(
- x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
+ x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
def _batch_data(x, batch_size=None):
if (batch_size is not None) and (batch_size <= 0):
raise ValueError('Invalid batch_size %d.' % batch_size)
- chunk = []
+
+ x_first_el = six.next(x)
+ x = itertools.chain([x_first_el], x)
+
+ chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(x_first_el, dict) else []
+ chunk_filled = False
for data in x:
- chunk.append(data)
- if (batch_size is not None) and (len(chunk) >= batch_size):
- yield np.matrix(chunk)
- chunk = []
- yield np.matrix(chunk)
+ if isinstance(data, dict):
+ for k, v in list(data.items()):
+ chunk[k].append(v)
+ if (batch_size is not None) and (len(chunk[k]) >= batch_size):
+ chunk[k] = np.matrix(chunk[k])
+ chunk_filled = True
+ if chunk_filled:
+ yield chunk
+ chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(x_first_el, dict) else []
+ chunk_filled = False
+ else:
+ chunk.append(data)
+ if (batch_size is not None) and (len(chunk) >= batch_size):
+ yield np.matrix(chunk)
+ chunk = []
+
+ if isinstance(x_first_el, dict):
+ for k, v in list(data.items()):
+ chunk[k] = np.matrix(chunk[k])
+ yield chunk
+ else:
+ yield np.matrix(chunk)
def setup_predict_data_feeder(x, batch_size=None):
"""Returns an iterable for feeding into predict step.
Args:
- x: numpy, pandas, Dask array or iterable.
- batch_size: Size of batches to split data into.
- If `None`, returns one batch of full size.
+ x: numpy, pandas, Dask array or dictionary of aforementioned. Also supports
+ iterable.
+ batch_size: Size of batches to split data into. If `None`, returns one
+ batch of full size.
Returns:
- List or iterator of parts of data to predict on.
+ List or iterator (or dictionary thereof) of parts of data to predict on.
Raises:
ValueError: if `batch_size` <= 0.
@@ -211,7 +259,7 @@ def _access(data, iloc):
def _check_dtype(dtype):
if dtypes.as_dtype(dtype) == dtypes.float64:
logging.warn(
- 'float64 is not supported by many models, consider casting to float32.')
+ 'float64 is not supported by many models, consider casting to float32.')
return dtype
@@ -219,63 +267,85 @@ class DataFeeder(object):
"""Data feeder is an example class to sample data for TF trainer."""
def __init__(
- self, x, y, n_classes, batch_size=None, shuffle=True, random_state=None,
- epochs=None):
+ self, x, y, n_classes, batch_size=None, shuffle=True, random_state=None,
+ epochs=None):
"""Initializes a DataFeeder instance.
Args:
- x: Feature Nd numpy matrix of shape `[n_samples, n_features, ...]`.
- y: Label vector, either floats for regression or class id for
- classification. If matrix, will consider as a sequence
- of labels. Can be `None` for unsupervised setting.
+ x: One feature sample which can either Nd numpy matrix of shape
+ `[n_samples, n_features, ...]` or dictionary of Nd numpy matrix.
+ y: label vector, either floats for regression or class id for
+ classification. If matrix, will consider as a sequence of labels.
+ Can be `None` for unsupervised setting. Also supports dictionary of
+ labels.
n_classes: Number of classes, 0 and 1 are considered regression, `None`
- will pass through the input labels without one-hot conversion.
- batch_size: Mini-batch size to accumulate.
+ will pass through the input labels without one-hot conversion. Also, if
+ `y` is `dict`, then `n_classes` must be `dict` such that
+ `n_classes[key] = n_classes for label y[key]`, `None` otherwise.
+ batch_size: Mini-batch size to accumulate samples in one mini batch.
shuffle: Whether to shuffle `x`.
random_state: Numpy `RandomState` object to reproduce sampling.
epochs: Number of times to iterate over input data before raising
`StopIteration` exception.
Attributes:
- x: Input features.
- y: Input label.
+ x: Input features (ndarray or dictionary of ndarrays).
+ y: Input label (ndarray or dictionary of ndarrays).
n_classes: Number of classes (if `None`, pass through indices without
one-hot conversion).
batch_size: Mini-batch size to accumulate.
- input_shape: Shape of the input.
- output_shape: Shape of the output.
- input_dtype: DType of input.
- output_dtype: DType of output.
+ input_shape: Shape of the input (or dictionary of shapes).
+ output_shape: Shape of the output (or dictionary of shapes).
+ input_dtype: DType of input (or dictionary of shapes).
+ output_dtype: DType of output (or dictionary of shapes.
"""
- self._x = check_array(x, dtype=x.dtype)
- # self.n_classes is None means we're passing in raw label indices.
- y_dtype = (
- np.int64 if n_classes is not None and n_classes > 1 else np.float32)
+ x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance(y, dict)
+ if isinstance(y, list):
+ y = np.array(y)
+
+ self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())]) if x_is_dict else check_array(x, x.dtype)
+ self._y = None if y is None else \
+ dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
+
+ # self.n_classes is not None means we're converting raw target indices to one-hot.
if n_classes is not None:
- self._y = (None if y is None else check_array(y, dtype=y_dtype))
- elif isinstance(y, list):
- self._y = np.array(y)
- else:
- self._y = y
+ if not y_is_dict:
+ y_dtype = (np.int64 if n_classes is not None and n_classes > 1 else np.float32)
+ self._y = (None if y is None else check_array(y, dtype=y_dtype))
+
self.n_classes = n_classes
self.max_epochs = epochs
+
+ x_shape = dict([(k, v.shape) for k, v in list(self._x.items())]) if x_is_dict else self._x.shape
+ y_shape = dict(
+ [(k, v.shape) for k, v in list(self._y.items())]) if y_is_dict else None if y is None else self._y.shape
+
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
- self._x.shape, None if self._y is None else self._y.shape, n_classes,
- batch_size)
+ x_shape, y_shape, n_classes, batch_size)
+
# Input dtype matches dtype of x.
- self._input_dtype = _check_dtype(self._x.dtype)
- # self.n_classes is None means we're passing in raw label indices
- if n_classes is not None or self._y is None:
- self._output_dtype = np.float32
- else:
- self._output_dtype = _check_dtype(self._y.dtype)
+ self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \
+ else _check_dtype(self._x.dtype)
+
+ # note: self._output_dtype = np.float32 when y is None
+ self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \
+ else _check_dtype(self._y.dtype) if y is not None else np.float32
+
+ # self.n_classes is None means we're passing in raw target indices
+ if n_classes is not None and y_is_dict:
+ for key in list(n_classes.keys()):
+ if key in self._output_dtype:
+ self._output_dtype[key] = np.float32
+
self._shuffle = shuffle
self.random_state = np.random.RandomState(
- 42) if random_state is None else random_state
+ 42) if random_state is None else random_state
+
+ num_samples = list(self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0]
if self._shuffle:
- self.indices = self.random_state.permutation(self._x.shape[0])
+ self.indices = self.random_state.permutation(num_samples)
else:
- self.indices = np.array(range(self._x.shape[0]))
+ self.indices = np.array(range(num_samples))
self.offset = 0
self.epoch = 0
self._epoch_placeholder = None
@@ -320,19 +390,27 @@ class DataFeeder(object):
Returns:
Two placeholders for inputs and outputs.
"""
- input_shape = [None] + self.input_shape[1:]
- self._input_placeholder = array_ops.placeholder(
- dtypes.as_dtype(self._input_dtype),
- input_shape,
- name='input')
- if self.output_shape is None:
- self._output_placeholder = None
- else:
- output_shape = [None] + self.output_shape[1:]
- self._output_placeholder = array_ops.placeholder(
- dtypes.as_dtype(self._output_dtype),
- output_shape,
- name='output')
+
+ def get_placeholder(shape, dtype, name_prepend):
+ if shape is None:
+ return None
+ if isinstance(shape, dict):
+ placeholder = {}
+ for key in list(shape.keys()):
+ placeholder[key] = array_ops.placeholder(
+ dtypes.as_dtype(dtype[key]),
+ [None] + shape[key][1:],
+ name=name_prepend + '_' + key
+ )
+ else:
+ placeholder = array_ops.placeholder(
+ dtypes.as_dtype(dtype),
+ [None] + shape[1:],
+ name=name_prepend)
+ return placeholder
+
+ self._input_placeholder = get_placeholder(self.input_shape, self._input_dtype, 'input')
+ self._output_placeholder = get_placeholder(self.output_shape, self._output_dtype, 'output')
return self._input_placeholder, self._output_placeholder
def set_placeholders(self, input_placeholder, output_placeholder):
@@ -342,21 +420,21 @@ class DataFeeder(object):
input_placeholder: Placeholder for `x` variable. Should match shape
of the examples in the x dataset.
output_placeholder: Placeholder for `y` variable. Should match
- shape of the examples in the y dataset. Can be None.
+ shape of the examples in the y dataset. Can be `None`.
"""
self._input_placeholder = input_placeholder
self._output_placeholder = output_placeholder
def get_feed_params(self):
- """Function returns a dict with data feed params while training.
+ """Function returns a `dict` with data feed params while training.
Returns:
- A dict with data feed params while training.
+ A `dict` with data feed params while training.
"""
return {
- 'epoch': self.epoch,
- 'offset': self.offset,
- 'batch_size': self._batch_size
+ 'epoch': self.epoch,
+ 'offset': self.offset,
+ 'batch_size': self._batch_size
}
def get_feed_dict_fn(self):
@@ -364,8 +442,35 @@ class DataFeeder(object):
Returns:
A function that when called samples a random subset of batch size
- from x and y.
+ from `x` and `y`.
"""
+ x_is_dict, y_is_dict = isinstance(self._x, dict), self._y is not None and isinstance(self._y, dict)
+
+ # Assign input features from random indices.
+ def extract(data, indices):
+ return (np.array(_access(data, indices)).reshape((indices.shape[0], 1))
+ if len(data.shape) == 1 else _access(data, indices))
+
+ # assign labels from random indices
+ def assign_label(data, shape, dtype, n_classes, indices):
+ shape[0] = indices.shape[0]
+ out = np.zeros(shape, dtype=dtype)
+ for i in xrange(out.shape[0]):
+ sample = indices[i]
+ # self.n_classes is None means we're passing in raw target indices
+ if n_classes is None:
+ out[i] = _access(data, sample)
+ else:
+ if n_classes > 1:
+ if len(shape) == 2:
+ out.itemset((i, int(_access(data, sample))), 1.0)
+ else:
+ for idx, value in enumerate(_access(data, sample)):
+ out.itemset(tuple([i, idx, value]), 1.0)
+ else:
+ out[i] = _access(data, sample)
+ return out
+
def _feed_dict_fn():
"""Function that samples data into given placeholders."""
if self.max_epochs is not None and self.epoch + 1 > self.max_epochs:
@@ -376,20 +481,19 @@ class DataFeeder(object):
feed_dict[self._epoch_placeholder.name] = [self.epoch]
# Take next batch of indices.
- end = min(self._x.shape[0], self.offset + self._batch_size)
+ x_len = list(self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0]
+ end = min(x_len, self.offset + self._batch_size)
batch_indices = self.indices[self.offset:end]
- # Assign input features from random indices.
- inp = (
- np.array(_access(self._x, batch_indices)).reshape(
- (batch_indices.shape[0], 1))
- if len(self._x.shape) == 1 else _access(self._x, batch_indices))
- feed_dict[self._input_placeholder.name] = inp
+ # adding input placeholder
+ feed_dict.update(
+ dict([(self._input_placeholder[k].name, extract(v, batch_indices)) for k, v in list(self._x.items())])
+ if x_is_dict else {self._input_placeholder.name: extract(self._x, batch_indices)})
# move offset and reset it if necessary
self.offset += self._batch_size
- if self.offset >= self._x.shape[0]:
- self.indices = self.random_state.permutation(self._x.shape[0])
+ if self.offset >= x_len:
+ self.indices = self.random_state.permutation(x_len) if self._shuffle else np.array(range(x_len))
self.offset = 0
self.epoch += 1
@@ -397,24 +501,18 @@ class DataFeeder(object):
if self._output_placeholder is None:
return feed_dict
- # assign labels from random indices
- self.output_shape[0] = batch_indices.shape[0]
- out = np.zeros(self.output_shape, dtype=self._output_dtype)
- for i in xrange(out.shape[0]):
- sample = batch_indices[i]
- # self.n_classes is None means we're passing in raw label indices
- if self.n_classes is None:
- out[i] = _access(self._y, sample)
- else:
- if self.n_classes > 1:
- if len(self.output_shape) == 2:
- out.itemset((i, int(_access(self._y, sample))), 1.0)
- else:
- for idx, value in enumerate(_access(self._y, sample)):
- out.itemset(tuple([i, idx, value]), 1.0)
- else:
- out[i] = _access(self._y, sample)
- feed_dict[self._output_placeholder.name] = out
+ # adding output placeholders
+ if y_is_dict:
+ for k, v in list(self._y.items()):
+ n_classes = (
+ self.n_classes[k] if k in self.n_classes else None) if self.n_classes is not None else None
+ shape, dtype = self.output_shape[k], self._output_dtype[k]
+ feed_dict.update(
+ {self._output_placeholder[k].name: assign_label(v, shape, dtype, n_classes, batch_indices)})
+ else:
+ shape, dtype, n_classes = self.output_shape, self._output_dtype, self.n_classes
+ feed_dict.update(
+ {self._output_placeholder.name: assign_label(self._y, shape, dtype, n_classes, batch_indices)})
return feed_dict
@@ -433,21 +531,29 @@ class StreamingDataFeeder(DataFeeder):
"""Initializes a StreamingDataFeeder instance.
Args:
- x: iterator that returns for each element, returns features.
- y: iterator that returns for each element, returns 1 or many classes /
- regression values.
- n_classes: indicator of how many classes the label has.
- batch_size: Mini batch size to accumulate.
+ x: iterator each element of which returns one feature sample. Sample can
+ be a Nd numpy matrix or dictionary of Nd numpy matrices.
+ y: iterator each element of which returns one label sample. Sample can be
+ a Nd numpy matrix or dictionary of Nd numpy matrices with 1 or many
+ classes regression values.
+ n_classes: indicator of how many classes the corresponding label sample
+ has for the purposes of one-hot conversion of label. In case where `y`
+ is a dictionary, `n_classes` must be dictionary (with same keys as `y`)
+ of how many classes there are in each label in `y`. If key is
+ present in `y` and missing in `n_classes`, the value is assumed `None`
+ and no one-hot conversion will be applied to the label with that key.
+ batch_size: Mini batch size to accumulate samples in one batch. If set
+ `None`, then assumes that iterator to return already batched element.
Attributes:
- x: input features.
- y: input label.
+ x: input features (or dictionary of input features).
+ y: input label (or dictionary of output features).
n_classes: number of classes.
batch_size: mini batch size to accumulate.
- input_shape: shape of the input.
- output_shape: shape of the output.
- input_dtype: dtype of input.
- output_dtype: dtype of output.
+ input_shape: shape of the input (can be dictionary depending on `x`).
+ output_shape: shape of the output (can be dictionary depending on `y`).
+ input_dtype: dtype of input (can be dictionary depending on `x`).
+ output_dtype: dtype of output (can be dictionary depending on `y`).
"""
# pylint: disable=invalid-name,super-init-not-called
x_first_el = six.next(x)
@@ -459,25 +565,48 @@ class StreamingDataFeeder(DataFeeder):
y_first_el = None
self._y = None
self.n_classes = n_classes
- x_first_el = ops.convert_to_tensor(x_first_el)
- y_first_el = ops.convert_to_tensor(y_first_el) if y is not None else None
- self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
- [1] + list(x_first_el.get_shape()),
- [1] + list(y_first_el.get_shape()) if y is not None else None,
- n_classes,
- batch_size)
- self._input_dtype = _check_dtype(x_first_el.dtype).as_numpy_dtype
+
+ x_is_dict, y_is_dict = isinstance(x_first_el, dict), y is not None and isinstance(y_first_el, dict)
+ if y_is_dict and n_classes is not None:
+ assert (isinstance(n_classes, dict))
+
+ # extract shapes for first_elements
+ x_first_el_shape = dict([(k, [1] + list(v.shape)) for k, v in list(x_first_el.items())]) if x_is_dict \
+ else [1] + list(x_first_el.shape)
+
+ y_first_el_shape = dict([(k, [1] + list(v.shape)) for k, v in list(y_first_el.items())]) if y_is_dict \
+ else ([1] + list(y_first_el[0].shape if isinstance(y_first_el, list) else y_first_el.shape)
+ if y is not None else None)
+
+ self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(x_first_el_shape, y_first_el_shape,
+ n_classes, batch_size)
+
+ # Input dtype of x_first_el.
+ self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(x_first_el.items())]) if x_is_dict \
+ else _check_dtype(x_first_el.dtype)
+
+ # Output dtype of y_first_el.
+ def check_y_dtype(el):
+ if isinstance(el, list) or isinstance(el, np.ndarray):
+ if isinstance(el, np.ndarray) and el.ndim == 0:
+ return el.dtype
+ else:
+ return _check_dtype(np.dtype(type(el[0])))
+ else:
+ return _check_dtype(np.dtype(type(el)))
+
# Output types are floats, due to both softmaxes and regression req.
- if n_classes is not None and n_classes > 0:
+ if n_classes is not None and (y is None or not y_is_dict) and n_classes > 0:
self._output_dtype = np.float32
- elif y is not None:
- self._output_dtype = _check_dtype(y_first_el.dtype).as_numpy_dtype
+ else:
+ self._output_dtype = dict([(k, check_y_dtype(v)) for k, v in list(y_first_el.items())]) if y_is_dict \
+ else (check_y_dtype(y_first_el) if y is not None else None)
def get_feed_params(self):
- """Function returns a dict with data feed params while training.
+ """Function returns a `dict` with data feed params while training.
Returns:
- A dict with data feed params while training.
+ A `dict` with data feed params while training.
"""
return {'batch_size': self._batch_size}
@@ -494,50 +623,76 @@ class StreamingDataFeeder(DataFeeder):
"""Samples data and provides it to placeholders.
Returns:
- Dict of input and output tensors.
+ `dict` of input and output tensors.
"""
+
+ def init_array(shape, dtype):
+ if shape is None:
+ return None
+ else:
+ return dict([(k, np.zeros(shape[k], dtype[k])) for k in list(shape.keys())]) if isinstance(shape, dict) else \
+ np.zeros(shape, dtype=dtype)
+
+ def put_data_array(dest, index, source=None, n_classes=None):
+ if source is None:
+ dest = dest[:index, :]
+ elif n_classes is not None and n_classes > 1:
+ if len(self.output_shape) == 2:
+ dest.itemset((index, source), 1.0)
+ else:
+ for idx, value in enumerate(source):
+ dest.itemset(tuple([index, idx, value]), 1.0)
+ else:
+ if len(dest.shape) > 1:
+ dest[index, :] = source
+ else:
+ dest[index] = source[0] if isinstance(source, list) else source
+ return dest
+
+ def put_data_array_or_dict(holder, index, data=None, n_classes=None):
+ if holder is None:
+ return None
+ if isinstance(holder, dict):
+ assert (isinstance(data, dict))
+ for k, v in list(holder.items()):
+ num_classes = n_classes[k] if (n_classes is not None and k in n_classes) else None
+ holder[k] = put_data_array(holder[k], index, data[k], num_classes)
+ else:
+ holder = put_data_array(holder, index, data, n_classes)
+ return holder
+
if self.stopped:
raise StopIteration
- try:
- inp = np.zeros(self.input_shape, dtype=self._input_dtype)
- except TypeError as exc:
- raise TypeError('Unrecognized dtype: {}. {}'.format(
- self._input_dtype, exc))
- if self._y is not None:
- out = np.zeros(self.output_shape, dtype=self._output_dtype)
+
+ inp = init_array(self.input_shape, self._input_dtype)
+ out = init_array(self.output_shape, self._output_dtype)
+
for i in xrange(self._batch_size):
# Add handling when queue ends.
try:
- inp[i, :] = six.next(self._x)
+ next_inp = six.next(self._x)
+ inp = put_data_array_or_dict(inp, i, next_inp, None)
except StopIteration:
self.stopped = True
if i == 0:
raise
- inp = inp[:i, :]
- if self._y is not None:
- out = out[:i]
+ inp = put_data_array_or_dict(inp, i, None, None)
+ out = put_data_array_or_dict(out, i, None, None)
break
if self._y is not None:
- y = six.next(self._y)
- if self.n_classes is not None and self.n_classes > 1:
- if len(self.output_shape) == 2:
- out.itemset((i, y), 1.0)
- else:
- for idx, value in enumerate(y):
- out.itemset(tuple([i, idx, value]), 1.0)
- else:
- # The y itertor can sometimes return scalars or singleton lists.
- try:
- out[i] = y
- except ValueError as _:
- assert len(y) == 1, ('Expected singleton label, got {}'
- .format(repr(y)))
- out[i] = y[0]
- if self._y is None:
- return {self._input_placeholder.name: inp}
- return {self._input_placeholder.name: inp,
- self._output_placeholder.name: out}
+ next_out = six.next(self._y)
+ out = put_data_array_or_dict(out, i, next_out, self.n_classes)
+
+ # creating feed_dict
+ feed_dict = dict([(self._input_placeholder[k].name, inp[k]) for k in list(self._input_placeholder.keys())]) if \
+ isinstance(inp, dict) else {self._input_placeholder.name: inp}
+ if self._y is not None:
+ feed_dict.update(
+ dict([(self._output_placeholder[k].name, out[k]) for k in list(self._output_placeholder.keys())]) \
+ if isinstance(out, dict) else {self._output_placeholder.name: out})
+
+ return feed_dict
return _feed_dict_fn
@@ -575,6 +730,10 @@ class DaskDataFeeder(object):
input_dtype: dtype of input.
output_dtype: dtype of output.
"""
+
+ if isinstance(x, dict) or isinstance(y, dict):
+ raise ValueError("DaskDataFeeder does not support dictionaries at the moment.")
+
# pylint: disable=invalid-name,super-init-not-called
import dask.dataframe as dd # pylint: disable=g-import-not-at-top
# TODO(terrytangyuan): check x and y dtypes in dask_io like pandas
@@ -601,7 +760,7 @@ class DaskDataFeeder(object):
self._shuffle = shuffle
self.epochs = epochs
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
- x_shape, y_shape, n_classes, batch_size)
+ x_shape, y_shape, n_classes, batch_size)
self.sample_fraction = self._batch_size / float(x_count)
self._input_dtype = _check_dtype(self._x.dtypes[0])
self._output_dtype = _check_dtype(self._y.dtypes[self._y_columns])
@@ -611,10 +770,10 @@ class DaskDataFeeder(object):
self.random_state = random_state
def get_feed_params(self):
- """Function returns a dict with data feed params while training.
+ """Function returns a `dict` with data feed params while training.
Returns:
- A dict with data feed params while training.
+ A `dict` with data feed params while training.
"""
return {'batch_size': self._batch_size}
@@ -629,13 +788,14 @@ class DaskDataFeeder(object):
A function that when called samples a random subset of batch size
from x and y.
"""
+
def _feed_dict_fn():
"""Samples data and provides it to placeholders."""
# TODO(ipolosukhin): option for with/without replacement (dev version of
# dask)
sample = self.df.random_split(
- [self.sample_fraction, 1 - self.sample_fraction],
- random_state=self.random_state)
+ [self.sample_fraction, 1 - self.sample_fraction],
+ random_state=self.random_state)
inp = extract_pandas_matrix(sample[0][self._x_columns].compute()).tolist()
out = extract_pandas_matrix(sample[0][self._y_columns].compute())
# convert to correct dtype
@@ -650,4 +810,5 @@ class DaskDataFeeder(object):
encoded_out[np.arange(out.size), out] = 1
return {input_placeholder.name: inp,
output_placeholder.name: encoded_out}
+
return _feed_dict_fn
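
Reviewer note: a condensed usage sketch of the dictionary support added to DataFeeder (mirroring the updated regression test in the next file): dict-valued x/y produce dict-valued placeholders, and the feed dict is keyed by each placeholder's name. The key names below are illustrative:

import numpy as np
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder

x = {'in1': np.matrix([[1, 2], [3, 4]]), 'in2': np.matrix([[5, 6], [7, 8]])}
y = {'out1': np.array([1, 2])}
# When y is a dict, n_classes must be a dict keyed like y; 0 means regression.
df = data_feeder.DataFeeder(x, y, n_classes={'out1': 0}, batch_size=2)
inp, out = df.input_builder()          # dicts of placeholders keyed 'in1', 'in2' / 'out1'
feed_dict = df.get_feed_dict_fn()()    # numpy batches keyed by placeholder name
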
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index fe675e3122..828db45757 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -32,150 +32,200 @@ class DataFeederTest(tf.test.TestCase):
# pylint: disable=undefined-variable
"""Tests for `DataFeeder`."""
+ def _wrap_dict(self, data, prepend=''):
+ return {prepend+'1': data, prepend+'2': data}
+
def _assert_raises(self, input_data):
with self.assertRaisesRegexp(TypeError, 'annot convert'):
data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)
def test_input_uint32(self):
- self._assert_raises(np.matrix([[1, 2], [3, 4]], dtype=np.uint32))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32)
+ self._assert_raises(data)
+ self._assert_raises(self._wrap_dict(data))
def test_input_uint64(self):
- self._assert_raises(np.matrix([[1, 2], [3, 4]], dtype=np.uint64))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64)
+ self._assert_raises(data)
+ self._assert_raises(self._wrap_dict(data))
def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data):
feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)
- self.assertEqual(expected_np_dtype, feeder.input_dtype)
+ if isinstance(input_data, dict):
+ for k, v in list(feeder.input_dtype.items()):
+ self.assertEqual(expected_np_dtype, v)
+ else:
+ self.assertEqual(expected_np_dtype, feeder.input_dtype)
with tf.Graph().as_default() as g, self.test_session(g):
inp, _ = feeder.input_builder()
- self.assertEqual(expected_tf_dtype, inp.dtype)
+ if isinstance(inp, dict):
+ for k, v in list(inp.items()):
+ self.assertEqual(expected_tf_dtype, v.dtype)
+ else:
+ self.assertEqual(expected_tf_dtype, inp.dtype)
def test_input_int8(self):
- self._assert_dtype(
- np.int8, tf.int8, np.matrix([[1, 2], [3, 4]], dtype=np.int8))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.int8)
+ self._assert_dtype(np.int8, tf.int8, data)
+ self._assert_dtype(np.int8, tf.int8, self._wrap_dict(data))
def test_input_int16(self):
- self._assert_dtype(
- np.int16, tf.int16, np.matrix([[1, 2], [3, 4]], dtype=np.int16))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.int16)
+ self._assert_dtype(np.int16, tf.int16, data)
+ self._assert_dtype(np.int16, tf.int16, self._wrap_dict(data))
def test_input_int32(self):
- self._assert_dtype(
- np.int32, tf.int32, np.matrix([[1, 2], [3, 4]], dtype=np.int32))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.int32)
+ self._assert_dtype(np.int32, tf.int32, data)
+ self._assert_dtype(np.int32, tf.int32, self._wrap_dict(data))
def test_input_int64(self):
- self._assert_dtype(
- np.int64, tf.int64, np.matrix([[1, 2], [3, 4]], dtype=np.int64))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.int64)
+ self._assert_dtype(np.int64, tf.int64, data)
+ self._assert_dtype(np.int64, tf.int64, self._wrap_dict(data))
def test_input_uint8(self):
- self._assert_dtype(
- np.uint8, tf.uint8, np.matrix([[1, 2], [3, 4]], dtype=np.uint8))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.uint8)
+ self._assert_dtype(np.uint8, tf.uint8, data)
+ self._assert_dtype(np.uint8, tf.uint8, self._wrap_dict(data))
def test_input_uint16(self):
- self._assert_dtype(
- np.uint16, tf.uint16, np.matrix([[1, 2], [3, 4]], dtype=np.uint16))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.uint16)
+ self._assert_dtype(np.uint16, tf.uint16, data)
+ self._assert_dtype(np.uint16, tf.uint16, self._wrap_dict(data))
def test_input_float16(self):
- self._assert_dtype(
- np.float16, tf.float16, np.matrix([[1, 2], [3, 4]], dtype=np.float16))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.float16)
+ self._assert_dtype(np.float16, tf.float16, data)
+ self._assert_dtype(np.float16, tf.float16, self._wrap_dict(data))
def test_input_float32(self):
- self._assert_dtype(
- np.float32, tf.float32, np.matrix([[1, 2], [3, 4]], dtype=np.float32))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.float32)
+ self._assert_dtype(np.float32, tf.float32, data)
+ self._assert_dtype(np.float32, tf.float32, self._wrap_dict(data))
def test_input_float64(self):
- self._assert_dtype(
- np.float64, tf.float64, np.matrix([[1, 2], [3, 4]], dtype=np.float64))
+ data = np.matrix([[1, 2], [3, 4]], dtype=np.float64)
+ self._assert_dtype(np.float64, tf.float64, data)
+ self._assert_dtype(np.float64, tf.float64, self._wrap_dict(data))
def test_input_bool(self):
- self._assert_dtype(
- np.bool, tf.bool,
- np.array([[False for _ in xrange(2)] for _ in xrange(2)]))
+ data = np.array([[False for _ in xrange(2)] for _ in xrange(2)])
+ self._assert_dtype(np.bool, tf.bool, data)
+ self._assert_dtype(np.bool, tf.bool, self._wrap_dict(data))
def test_input_string(self):
input_data = np.array([['str%d' % i for i in xrange(2)] for _ in xrange(2)])
self._assert_dtype(input_data.dtype, tf.string, input_data)
+ self._assert_dtype(input_data.dtype, tf.string, self._wrap_dict(input_data))
+
+ def _assertAllClose(self, src, dest, src_key_of=None, src_prop=None):
+ def func(x):
+ val = getattr(x, src_prop) if src_prop else x
+ return val if src_key_of is None else src_key_of[val]
+ if isinstance(src, dict):
+ for k in list(src.keys()):
+ self.assertAllClose(func(src[k]), dest)
+ else:
+ self.assertAllClose(func(src), dest)
def test_unsupervised(self):
+ def func(feeder):
+ with self.test_session():
+ inp, _ = feeder.input_builder()
+ feed_dict_fn = feeder.get_feed_dict_fn()
+ feed_dict = feed_dict_fn()
+ self._assertAllClose(inp, [[1, 2]], feed_dict, 'name')
data = np.matrix([[1, 2], [2, 3], [3, 4]])
- feeder = data_feeder.DataFeeder(data, None, n_classes=0, batch_size=1)
- with self.test_session():
- inp, _ = feeder.input_builder()
- feed_dict_fn = feeder.get_feed_dict_fn()
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[inp.name], [[1, 2]])
+ func(data_feeder.DataFeeder(data, None, n_classes=0, batch_size=1))
+ func(data_feeder.DataFeeder(self._wrap_dict(data), None, n_classes=0, batch_size=1))
def test_data_feeder_regression(self):
+ def func(df):
+ inp, out = df.input_builder()
+ feed_dict_fn = df.get_feed_dict_fn()
+ feed_dict = feed_dict_fn()
+ self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
+ self._assertAllClose(out, [2, 1], feed_dict, 'name')
x = np.matrix([[1, 2], [3, 4]])
y = np.array([1, 2])
- df = data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3)
- inp, out = df.input_builder()
- feed_dict_fn = df.get_feed_dict_fn()
- feed_dict = feed_dict_fn()
-
- self.assertAllClose(feed_dict[inp.name], [[3, 4], [1, 2]])
- self.assertAllClose(feed_dict[out.name], [2, 1])
+ func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3))
+ func(data_feeder.DataFeeder(self._wrap_dict(x, 'in'), self._wrap_dict(y, 'out'),
+ n_classes=self._wrap_dict(0, 'out'), batch_size=3))
def test_epoch(self):
+ def func(feeder):
+ with self.test_session():
+ feeder.input_builder()
+ epoch = feeder.make_epoch_variable()
+ feed_dict_fn = feeder.get_feed_dict_fn()
+ # First input
+ feed_dict = feed_dict_fn()
+ self.assertAllClose(feed_dict[epoch.name], [0])
+ # Second input
+ feed_dict = feed_dict_fn()
+ self.assertAllClose(feed_dict[epoch.name], [0])
+ # Third input
+ feed_dict = feed_dict_fn()
+ self.assertAllClose(feed_dict[epoch.name], [0])
+ # Back to the first input again, so new epoch.
+ feed_dict = feed_dict_fn()
+ self.assertAllClose(feed_dict[epoch.name], [1])
data = np.matrix([[1, 2], [2, 3], [3, 4]])
labels = np.array([0, 0, 1])
- feeder = data_feeder.DataFeeder(data, labels, n_classes=0, batch_size=1)
- with self.test_session():
- feeder.input_builder()
- epoch = feeder.make_epoch_variable()
- feed_dict_fn = feeder.get_feed_dict_fn()
- # First input
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], [0])
- # Second input
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], [0])
- # Third input
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], [0])
- # Back to the first input again, so new epoch.
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], [1])
+ func(data_feeder.DataFeeder(data, labels, n_classes=0, batch_size=1))
+ func(data_feeder.DataFeeder(self._wrap_dict(data, 'in'), self._wrap_dict(labels, 'out'),
+ n_classes=self._wrap_dict(0, 'out'), batch_size=1))
def test_data_feeder_multioutput_regression(self):
+ def func(df):
+ inp, out = df.input_builder()
+ feed_dict_fn = df.get_feed_dict_fn()
+ feed_dict = feed_dict_fn()
+ self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
+ self._assertAllClose(out, [[3, 4], [1, 2]], feed_dict, 'name')
x = np.matrix([[1, 2], [3, 4]])
y = np.array([[1, 2], [3, 4]])
- df = data_feeder.DataFeeder(x, y, n_classes=0, batch_size=2)
- inp, out = df.input_builder()
- feed_dict_fn = df.get_feed_dict_fn()
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[inp.name], [[3, 4], [1, 2]])
- self.assertAllClose(feed_dict[out.name], [[3, 4], [1, 2]])
+ func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=2))
+ func(data_feeder.DataFeeder(self._wrap_dict(x, 'in'), self._wrap_dict(y, 'out'),
+ n_classes=self._wrap_dict(0, 'out'), batch_size=2))
def test_data_feeder_multioutput_classification(self):
+ def func(df):
+ inp, out = df.input_builder()
+ feed_dict_fn = df.get_feed_dict_fn()
+ feed_dict = feed_dict_fn()
+ self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
+ self._assertAllClose(out,
+ [[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]],
+ [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]],
+ feed_dict, 'name')
+
x = np.matrix([[1, 2], [3, 4]])
y = np.array([[0, 1, 2], [2, 3, 4]])
- df = data_feeder.DataFeeder(x, y, n_classes=5, batch_size=2)
- inp, out = df.input_builder()
- feed_dict_fn = df.get_feed_dict_fn()
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[inp.name], [[3, 4], [1, 2]])
- self.assertAllClose(feed_dict[out.name],
- [[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]],
- [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]])
+ func(data_feeder.DataFeeder(x, y, n_classes=5, batch_size=2))
+ func(data_feeder.DataFeeder(self._wrap_dict(x, 'in'), self._wrap_dict(y, 'out'),
+ n_classes=self._wrap_dict(5, 'out'), batch_size=2))
def test_streaming_data_feeder(self):
+ def func(df):
+ inp, out = df.input_builder()
+ feed_dict_fn = df.get_feed_dict_fn()
+ feed_dict = feed_dict_fn()
+ self._assertAllClose(inp, [[1, 2], [3, 4]], feed_dict, 'name')
+ self._assertAllClose(out, [1, 2], feed_dict, 'name' )
- def x_iter():
- yield np.array([1, 2])
- yield np.array([3, 4])
+ def x_iter(wrap_dict=False):
+ yield np.array([1, 2]) if not wrap_dict else self._wrap_dict(np.array([1, 2]), 'in')
+ yield np.array([3, 4]) if not wrap_dict else self._wrap_dict(np.array([3, 4]), 'in')
- def y_iter():
- yield np.array([1])
- yield np.array([2])
+ def y_iter(wrap_dict=False):
+ yield np.array([1]) if not wrap_dict else self._wrap_dict(np.array([1]), 'out')
+ yield np.array([2]) if not wrap_dict else self._wrap_dict(np.array([2]), 'out')
- df = data_feeder.StreamingDataFeeder(x_iter(),
- y_iter(),
- n_classes=0,
- batch_size=2)
- inp, out = df.input_builder()
- feed_dict_fn = df.get_feed_dict_fn()
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[inp.name], [[1, 2], [3, 4]])
- self.assertAllClose(feed_dict[out.name], [1, 2])
+ func(data_feeder.StreamingDataFeeder(x_iter(), y_iter(), n_classes=0, batch_size=2))
+ func(data_feeder.StreamingDataFeeder(x_iter(True), y_iter(True),
+ n_classes=self._wrap_dict(0, 'out'), batch_size=2))
def test_dask_data_feeder(self):
if HAS_PANDAS and HAS_DASK:
@@ -196,6 +246,13 @@ class DataFeederTest(tf.test.TestCase):
self.assertAllClose(feed_dict[out.name], [[0., 0., 1.], [0., 1., 0.]])
def test_hdf5_data_feeder(self):
+ def func(df):
+ inp, out = df.input_builder()
+ feed_dict_fn = df.get_feed_dict_fn()
+ feed_dict = feed_dict_fn()
+ self._assertAllClose(inp, [[3, 4], [1, 2]], feed_dict, 'name')
+ self.assertAllClose(out, [2, 1], feed_dict, 'name')
+
try:
import h5py # pylint: disable=g-import-not-at-top
x = np.matrix([[1, 2], [3, 4]])
@@ -207,25 +264,28 @@ class DataFeederTest(tf.test.TestCase):
h5f = h5py.File('test_hdf5.h5', 'r')
x = h5f['x']
y = h5f['y']
- df = data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3)
- inp, out = df.input_builder()
- feed_dict_fn = df.get_feed_dict_fn()
- feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[inp.name], [[3, 4], [1, 2]])
- self.assertAllClose(feed_dict[out.name], [2, 1])
+ func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3))
+ func(data_feeder.DataFeeder(self._wrap_dict(x, 'in'), self._wrap_dict(y, 'out'),
+ n_classes=self._wrap_dict(0, 'out'), batch_size=3))
except ImportError:
print("Skipped test for hdf5 since it's not installed.")
-class SetupPredictDataFeederTest(tf.test.TestCase):
+class SetupPredictDataFeederTest(DataFeederTest):
"""Tests for `DataFeeder.setup_predict_data_feeder`."""
def test_iterable_data(self):
# pylint: disable=undefined-variable
- x = iter([[1, 2], [3, 4], [5, 6]])
- df = data_feeder.setup_predict_data_feeder(x, batch_size=2)
- self.assertAllClose(six.next(df), [[1, 2], [3, 4]])
- self.assertAllClose(six.next(df), [[5, 6]])
+
+ def func(df):
+ self._assertAllClose(six.next(df), [[1, 2], [3, 4]])
+ self._assertAllClose(six.next(df), [[5, 6]])
+
+ data = [[1, 2], [3, 4], [5, 6]]
+ x = iter(data)
+ x_dict = iter([self._wrap_dict(v) for v in iter(data)])
+ func(data_feeder.setup_predict_data_feeder(x, batch_size=2))
+ func(data_feeder.setup_predict_data_feeder(x_dict, batch_size=2))
if __name__ == '__main__':
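
Reviewer note: condensed from test_streaming_data_feeder above, StreamingDataFeeder likewise accepts iterators that yield dict samples, with n_classes given as a dict keyed like y:

import numpy as np
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder

def x_iter():
    yield {'in1': np.array([1, 2]), 'in2': np.array([1, 2])}
    yield {'in1': np.array([3, 4]), 'in2': np.array([3, 4])}

def y_iter():
    yield {'out1': np.array([1]), 'out2': np.array([1])}
    yield {'out1': np.array([2]), 'out2': np.array([2])}

df = data_feeder.StreamingDataFeeder(
    x_iter(), y_iter(), n_classes={'out1': 0, 'out2': 0}, batch_size=2)
inp, out = df.input_builder()          # dict placeholders for both inputs and outputs
feed_dict = df.get_feed_dict_fn()()    # one batched feed dict per call
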
diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py
index 8a1548738e..2d1d460425 100644
--- a/tensorflow/contrib/learn/python/learn/trainable.py
+++ b/tensorflow/contrib/learn/python/learn/trainable.py
@@ -33,17 +33,17 @@ class Trainable(object):
"""Trains a model given training data `x` predictions and `y` labels.
Args:
- x: Matrix of shape [n_samples, n_features...]. Can be iterator that
- returns arrays of features. The training input samples for fitting the
- model. If set, `input_fn` must be `None`.
- y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
- iterator that returns array of labels. The training label values
- (class labels in classification, real numbers in regression). If set,
- `input_fn` must be `None`. Note: For classification, label values must
+ x: Matrix of shape [n_samples, n_features...] or the dictionary of Matrices.
+ Can be iterator that returns arrays of features or dictionary of arrays of features.
+ The training input samples for fitting the model. If set, `input_fn` must be `None`.
+ y: Vector or matrix [n_samples] or [n_samples, n_outputs] or the dictionary of same.
+ Can be iterator that returns array of labels or dictionary of array of labels.
+ The training label values (class labels in classification, real numbers in regression).
+ If set, `input_fn` must be `None`. Note: For classification, label values must
be integers representing the class index (i.e. values from 0 to
n_classes-1).
input_fn: Input function returning a tuple of:
- features - Dictionary of string feature name to `Tensor` or `Tensor`.
+ features - `Tensor` or dictionary of string feature name to `Tensor`.
labels - `Tensor` or dictionary of `Tensor` with labels.
If input_fn is set, `x`, `y`, and `batch_size` must be `None`.
steps: Number of steps for which to train model. If `None`, train forever.
@@ -67,4 +67,3 @@ class Trainable(object):
`self`, for chaining.
"""
raise NotImplementedError
-
diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py
index 149a4b9772..f313699c14 100644
--- a/tensorflow/contrib/learn/python/learn/utils/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py
@@ -19,5 +19,4 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.contrib.learn.python.learn.utils.export import export_estimator