diff options
author | Mustafa Ispir <ispir@google.com> | 2017-11-28 14:10:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-28 14:13:52 -0800 |
commit | c294fcfd85c03a801d3aad83cfd08055dadbad1a (patch) | |
tree | 39fd3b4178ad224fa1f265dd4e74b936fead3298 /tensorflow/python/estimator/estimator_test.py | |
parent | f93c8a72154fd22fe1578bf448df156acd54fddf (diff) |
Dataset support within Estimator. With this cl Input_fn can return a Dataset.
PiperOrigin-RevId: 177215252
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index c1b773b8c4..db64fbc9cc 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -913,6 +913,80 @@ class EstimatorGetVariablesTest(test.TestCase): self.assertEqual(3., est.get_variable_value('three')) +class EstimatorDatasetIntegrationTest(test.TestCase): + """Tests dataset integration.""" + + def test_returned_by_input_fn(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensors(([1.], [2.])) + + def _model_fn(features, labels, mode): + return model_fn_lib.EstimatorSpec( + mode, + loss=features + labels, # 1 + 2 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn, steps=1) + scores = est.evaluate(_input_fn, steps=1) + self.assertEqual(3., scores[model_fn_lib.LOSS_METRIC_KEY]) + + def test_with_none_labels(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensors([7.]) + + def _model_fn(features, labels, mode): + self.assertIsNone(labels) + return model_fn_lib.EstimatorSpec( + mode, + loss=features, # 7 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn, steps=1) + scores = est.evaluate(_input_fn, steps=1) + self.assertEqual(7., scores[model_fn_lib.LOSS_METRIC_KEY]) + + def test_with_predict(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensors([10.]) + + def _model_fn(features, labels, mode): + _ = labels + return model_fn_lib.EstimatorSpec( + mode, + predictions=features, # 10 + loss=features, # 10 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn, steps=1) + self.assertEqual([10.], next(est.predict(input_fn=_input_fn))) + + def test_batching(self): + + def _input_fn(): + return dataset_ops.Dataset.from_tensor_slices(([[1.], [2.]], + [[10.], [20.]])).batch(1) + + def _model_fn(features, labels, mode): + return model_fn_lib.EstimatorSpec( + mode, + predictions=features, + loss=features + (0 if labels is None else labels), # 11, 22 + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(_input_fn) + scores = est.evaluate(_input_fn) + # (11 + 22)/2 = 16.5 + self.assertEqual(16.5, scores[model_fn_lib.LOSS_METRIC_KEY]) + self.assertEqual([1., 2.], list(est.predict(_input_fn))) + + class EstimatorEvaluateTest(test.TestCase): def test_input_fn_args(self): |