aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator_test.py
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-11-28 14:10:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-28 14:13:52 -0800
commitc294fcfd85c03a801d3aad83cfd08055dadbad1a (patch)
tree39fd3b4178ad224fa1f265dd4e74b936fead3298 /tensorflow/python/estimator/estimator_test.py
parentf93c8a72154fd22fe1578bf448df156acd54fddf (diff)
Dataset support within Estimator. With this cl Input_fn can return a Dataset.
PiperOrigin-RevId: 177215252
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r--tensorflow/python/estimator/estimator_test.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index c1b773b8c4..db64fbc9cc 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -913,6 +913,80 @@ class EstimatorGetVariablesTest(test.TestCase):
self.assertEqual(3., est.get_variable_value('three'))
+class EstimatorDatasetIntegrationTest(test.TestCase):
+ """Tests dataset integration."""
+
+ def test_returned_by_input_fn(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors(([1.], [2.]))
+
+ def _model_fn(features, labels, mode):
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=features + labels, # 1 + 2
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn, steps=1)
+ scores = est.evaluate(_input_fn, steps=1)
+ self.assertEqual(3., scores[model_fn_lib.LOSS_METRIC_KEY])
+
+ def test_with_none_labels(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors([7.])
+
+ def _model_fn(features, labels, mode):
+ self.assertIsNone(labels)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=features, # 7
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn, steps=1)
+ scores = est.evaluate(_input_fn, steps=1)
+ self.assertEqual(7., scores[model_fn_lib.LOSS_METRIC_KEY])
+
+ def test_with_predict(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors([10.])
+
+ def _model_fn(features, labels, mode):
+ _ = labels
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=features, # 10
+ loss=features, # 10
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn, steps=1)
+ self.assertEqual([10.], next(est.predict(input_fn=_input_fn)))
+
+ def test_batching(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensor_slices(([[1.], [2.]],
+ [[10.], [20.]])).batch(1)
+
+ def _model_fn(features, labels, mode):
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=features,
+ loss=features + (0 if labels is None else labels), # 11, 22
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn)
+ scores = est.evaluate(_input_fn)
+ # (11 + 22)/2 = 16.5
+ self.assertEqual(16.5, scores[model_fn_lib.LOSS_METRIC_KEY])
+ self.assertEqual([1., 2.], list(est.predict(_input_fn)))
+
+
class EstimatorEvaluateTest(test.TestCase):
def test_input_fn_args(self):