diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-19 08:46:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-19 10:03:22 -0700 |
commit | c3a5ff336bb64a43353fd799772220a1a07f7a00 (patch) | |
tree | 61b31ddbe4d43810b50a25ae371b5390964e5888 | |
parent | 1a5210752d444e7c0e6c2ab58ad034e7b736d573 (diff) |
Make DeprecatedMixin and more some tf.learn tests compatible with the iterable
version of predict.
Change: 130761564
3 files changed, 21 insertions, 20 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/base.py b/tensorflow/contrib/learn/python/learn/estimators/base.py index ae7ba2cfa4..584384a4ec 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/base.py +++ b/tensorflow/contrib/learn/python/learn/estimators/base.py @@ -125,13 +125,14 @@ class DeprecatedMixin(object): x, None, n_classes=None, batch_size=batch_size or self.batch_size, shuffle=False, epochs=1) - result = super(DeprecatedMixin, self)._infer_model( + result_iter = super(DeprecatedMixin, self)._infer_model( input_fn=predict_data_feeder.input_builder, feed_fn=predict_data_feeder.get_feed_dict_fn(), - outputs=outputs) + outputs=outputs, as_iterable=True) else: - result = super(DeprecatedMixin, self)._infer_model( - input_fn=input_fn, outputs=outputs) + result_iter = super(DeprecatedMixin, self)._infer_model( + input_fn=input_fn, outputs=outputs, as_iterable=True) + result = np.array(list(result_iter)) if self.__deprecated_n_classes > 1 and axis is not None: return np.argmax(result, axis) return result @@ -327,13 +328,12 @@ class TensorFlowEstimator(estimator.Estimator, DeprecatedMixin): batch_size=batch_size, shuffle=False, epochs=1) - preds = self._infer_model( + preds = np.array(list(self._infer_model( input_fn=predict_data_feeder.input_builder, - feed_fn=predict_data_feeder.get_feed_dict_fn()) + feed_fn=predict_data_feeder.get_feed_dict_fn(), + as_iterable=True))) if self.n_classes > 1 and axis != -1: preds = preds.argmax(axis=axis) - else: - preds = preds return preds def predict(self, x, axis=1, batch_size=None): diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 1120f3cb36..0dabf0ffc8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import tempfile import numpy as np @@ -84,12 +85,11 @@ class LinearClassifierTest(tf.test.TestCase): def testTrainSaveLoad(self): """Tests that insures you can save and reload a trained model.""" - def input_fn(): + def input_fn(num_epochs=None): return { - 'age': tf.constant([1]), - 'language': tf.SparseTensor(values=['english'], - indices=[[0, 0]], - shape=[1, 1]) + 'age': tf.train.limit_epochs(tf.constant([1]), num_epochs=num_epochs), + 'language': tf.SparseTensor( + values=['english'], indices=[[0, 0]], shape=[1, 1]), }, tf.constant([[1]]) language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100) @@ -100,14 +100,15 @@ class LinearClassifierTest(tf.test.TestCase): model_dir=model_dir, feature_columns=[age, language]) classifier.fit(input_fn=input_fn, steps=30) - out1 = classifier.predict(input_fn=input_fn) + predict_input_fn = functools.partial(input_fn, num_epochs=1) + out1 = classifier.predict(input_fn=predict_input_fn, as_iterable=True) del classifier classifier2 = tf.contrib.learn.LinearClassifier( model_dir=model_dir, feature_columns=[age, language]) - out2 = classifier2.predict(input_fn=input_fn) - self.assertEqual(out1, out2) + out2 = classifier2.predict(input_fn=predict_input_fn, as_iterable=True) + self.assertEqual(list(out1), list(out2)) def testExport(self): """Tests that export model for servo works.""" diff --git a/tensorflow/contrib/learn/python/learn/tests/stability_test.py b/tensorflow/contrib/learn/python/learn/tests/stability_test.py index 9de02264ee..84ae45bb5d 100644 --- a/tensorflow/contrib/learn/python/learn/tests/stability_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/stability_test.py @@ -97,8 +97,8 @@ class StabilityTest(tf.test.TestCase): self.assertAllClose(regressor1.weights_, regressor2.weights_) self.assertAllClose(regressor1.bias_, regressor2.bias_) self.assertAllClose( - regressor1.predict(boston.data), regressor2.predict(boston.data), - atol=1e-05) + list(regressor1.predict(boston.data, as_iterable=True)), + list(regressor2.predict(boston.data, as_iterable=True)), atol=1e-05) def testDNNRegression(self): my_seed = 42 @@ -129,8 +129,8 @@ class StabilityTest(tf.test.TestCase): for b1, b2 in zip(regressor2.bias_, regressor2.bias_): self.assertAllClose(b1, b2) self.assertAllClose( - regressor1.predict(boston.data), regressor2.predict(boston.data), - atol=1e-05) + list(regressor1.predict(boston.data, as_iterable=True)), + list(regressor2.predict(boston.data, as_iterable=True)), atol=1e-05) if __name__ == '__main__': |