aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-19 08:46:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-19 10:03:22 -0700
commitc3a5ff336bb64a43353fd799772220a1a07f7a00 (patch)
tree61b31ddbe4d43810b50a25ae371b5390964e5888
parent1a5210752d444e7c0e6c2ab58ad034e7b736d573 (diff)
Make DeprecatedMixin and more some tf.learn tests compatible with the iterable
version of predict. Change: 130761564
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/base.py16
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py17
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/stability_test.py8
3 files changed, 21 insertions, 20 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/base.py b/tensorflow/contrib/learn/python/learn/estimators/base.py
index ae7ba2cfa4..584384a4ec 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/base.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/base.py
@@ -125,13 +125,14 @@ class DeprecatedMixin(object):
x, None, n_classes=None,
batch_size=batch_size or self.batch_size,
shuffle=False, epochs=1)
- result = super(DeprecatedMixin, self)._infer_model(
+ result_iter = super(DeprecatedMixin, self)._infer_model(
input_fn=predict_data_feeder.input_builder,
feed_fn=predict_data_feeder.get_feed_dict_fn(),
- outputs=outputs)
+ outputs=outputs, as_iterable=True)
else:
- result = super(DeprecatedMixin, self)._infer_model(
- input_fn=input_fn, outputs=outputs)
+ result_iter = super(DeprecatedMixin, self)._infer_model(
+ input_fn=input_fn, outputs=outputs, as_iterable=True)
+ result = np.array(list(result_iter))
if self.__deprecated_n_classes > 1 and axis is not None:
return np.argmax(result, axis)
return result
@@ -327,13 +328,12 @@ class TensorFlowEstimator(estimator.Estimator, DeprecatedMixin):
batch_size=batch_size,
shuffle=False, epochs=1)
- preds = self._infer_model(
+ preds = np.array(list(self._infer_model(
input_fn=predict_data_feeder.input_builder,
- feed_fn=predict_data_feeder.get_feed_dict_fn())
+ feed_fn=predict_data_feeder.get_feed_dict_fn(),
+ as_iterable=True)))
if self.n_classes > 1 and axis != -1:
preds = preds.argmax(axis=axis)
- else:
- preds = preds
return preds
def predict(self, x, axis=1, batch_size=None):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index 1120f3cb36..0dabf0ffc8 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import tempfile
import numpy as np
@@ -84,12 +85,11 @@ class LinearClassifierTest(tf.test.TestCase):
def testTrainSaveLoad(self):
"""Tests that insures you can save and reload a trained model."""
- def input_fn():
+ def input_fn(num_epochs=None):
return {
- 'age': tf.constant([1]),
- 'language': tf.SparseTensor(values=['english'],
- indices=[[0, 0]],
- shape=[1, 1])
+ 'age': tf.train.limit_epochs(tf.constant([1]), num_epochs=num_epochs),
+ 'language': tf.SparseTensor(
+ values=['english'], indices=[[0, 0]], shape=[1, 1]),
}, tf.constant([[1]])
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
@@ -100,14 +100,15 @@ class LinearClassifierTest(tf.test.TestCase):
model_dir=model_dir,
feature_columns=[age, language])
classifier.fit(input_fn=input_fn, steps=30)
- out1 = classifier.predict(input_fn=input_fn)
+ predict_input_fn = functools.partial(input_fn, num_epochs=1)
+ out1 = classifier.predict(input_fn=predict_input_fn, as_iterable=True)
del classifier
classifier2 = tf.contrib.learn.LinearClassifier(
model_dir=model_dir,
feature_columns=[age, language])
- out2 = classifier2.predict(input_fn=input_fn)
- self.assertEqual(out1, out2)
+ out2 = classifier2.predict(input_fn=predict_input_fn, as_iterable=True)
+ self.assertEqual(list(out1), list(out2))
def testExport(self):
"""Tests that export model for servo works."""
diff --git a/tensorflow/contrib/learn/python/learn/tests/stability_test.py b/tensorflow/contrib/learn/python/learn/tests/stability_test.py
index 9de02264ee..84ae45bb5d 100644
--- a/tensorflow/contrib/learn/python/learn/tests/stability_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/stability_test.py
@@ -97,8 +97,8 @@ class StabilityTest(tf.test.TestCase):
self.assertAllClose(regressor1.weights_, regressor2.weights_)
self.assertAllClose(regressor1.bias_, regressor2.bias_)
self.assertAllClose(
- regressor1.predict(boston.data), regressor2.predict(boston.data),
- atol=1e-05)
+ list(regressor1.predict(boston.data, as_iterable=True)),
+ list(regressor2.predict(boston.data, as_iterable=True)), atol=1e-05)
def testDNNRegression(self):
my_seed = 42
@@ -129,8 +129,8 @@ class StabilityTest(tf.test.TestCase):
for b1, b2 in zip(regressor2.bias_, regressor2.bias_):
self.assertAllClose(b1, b2)
self.assertAllClose(
- regressor1.predict(boston.data), regressor2.predict(boston.data),
- atol=1e-05)
+ list(regressor1.predict(boston.data, as_iterable=True)),
+ list(regressor2.predict(boston.data, as_iterable=True)), atol=1e-05)
if __name__ == '__main__':