aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r--tensorflow/python/estimator/estimator_test.py35
1 files changed, 25 insertions, 10 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index e3f22d9010..05d1a04d2f 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -58,6 +58,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
+from tensorflow.python.ops.random_ops import random_uniform
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@@ -158,16 +159,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
def __init__(self):
super(_Estimator, self).__init__(model_fn=dummy_model_fn)
- def _call_input_fn(self, input_fn, mode):
- return input_fn()
-
- def _create_global_step(self, graph):
- pass
-
- def _convert_train_steps_to_hooks(self, steps, max_steps):
- pass
-
- def _convert_eval_steps_to_hooks(self, steps):
+ def _tf_api_names(self):
pass
_Estimator()
@@ -473,6 +465,29 @@ class EstimatorTrainTest(test.TestCase):
est.train(InputFn(), steps=1)
self.assertEqual(1, input_fn_call_count[0])
+ def test_nested_input_fn(self):
+ expected_params = {'batch_size': 10}
+
+ def _input_fn():
+ dataset_features = dataset_ops.Dataset.from_tensor_slices(
+ (random_uniform([4]),
+ random_uniform([4, 100], maxval=100, dtype=dtypes.int32)))
+ dataset_labels = dataset_ops.Dataset.from_tensor_slices(
+ random_uniform([4, 10]))
+ dataset = dataset_ops.Dataset.zip((dataset_features, dataset_labels))
+ dataset = dataset.repeat(-1)
+ iterator = dataset.make_initializable_iterator()
+ return iterator.get_next()
+
+ def _model_fn(features, labels, mode, params, config):
+ del params, config
+ return model_fn_global_step_incrementer(features, labels, mode)
+
+ expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
+ est = estimator.Estimator(
+ model_fn=_model_fn, params=expected_params, config=expected_config)
+ est.train(_input_fn, steps=4)
+
def test_input_fn_args(self):
expected_mode = model_fn_lib.ModeKeys.TRAIN
expected_params = {'batch_size': 10}