diff options
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index e3f22d9010..05d1a04d2f 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -58,6 +58,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses +from tensorflow.python.ops.random_ops import random_uniform from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -158,16 +159,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase): def __init__(self): super(_Estimator, self).__init__(model_fn=dummy_model_fn) - def _call_input_fn(self, input_fn, mode): - return input_fn() - - def _create_global_step(self, graph): - pass - - def _convert_train_steps_to_hooks(self, steps, max_steps): - pass - - def _convert_eval_steps_to_hooks(self, steps): + def _tf_api_names(self): pass _Estimator() @@ -473,6 +465,29 @@ class EstimatorTrainTest(test.TestCase): est.train(InputFn(), steps=1) self.assertEqual(1, input_fn_call_count[0]) + def test_nested_input_fn(self): + expected_params = {'batch_size': 10} + + def _input_fn(): + dataset_features = dataset_ops.Dataset.from_tensor_slices( + (random_uniform([4]), + random_uniform([4, 100], maxval=100, dtype=dtypes.int32))) + dataset_labels = dataset_ops.Dataset.from_tensor_slices( + random_uniform([4, 10])) + dataset = dataset_ops.Dataset.zip((dataset_features, dataset_labels)) + dataset = dataset.repeat(-1) + iterator = dataset.make_initializable_iterator() + return iterator.get_next() + + def _model_fn(features, labels, mode, params, config): + del params, config + return model_fn_global_step_incrementer(features, labels, mode) + + expected_config = run_config.RunConfig().replace(tf_random_seed=4321) + est = estimator.Estimator( + model_fn=_model_fn, params=expected_params, config=expected_config) + est.train(_input_fn, steps=4) + def test_input_fn_args(self): expected_mode = model_fn_lib.ModeKeys.TRAIN expected_params = {'batch_size': 10} |