diff options
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 108 |
1 files changed, 95 insertions, 13 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 8bc410ba0b..d316742a83 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -58,6 +58,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses +from tensorflow.python.ops.random_ops import random_uniform from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -69,6 +70,7 @@ from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import checkpoint_state_pb2 from tensorflow.python.training import saver from tensorflow.python.training import saver_test_utils @@ -157,16 +159,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase): def __init__(self): super(_Estimator, self).__init__(model_fn=dummy_model_fn) - def _call_input_fn(self, input_fn, mode): - return input_fn() - - def _create_global_step(self, graph): - pass - - def _convert_train_steps_to_hooks(self, steps, max_steps): - pass - - def _convert_eval_steps_to_hooks(self, steps): + def _tf_api_names(self): pass _Estimator() @@ -175,7 +168,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase): class EstimatorConstructorTest(test.TestCase): def test_config_must_be_a_run_config(self): - with self.assertRaisesRegexp(ValueError, 'an instance of RunConfig'): + with self.assertRaisesRegexp(ValueError, 'an instance of `RunConfig`'): estimator.Estimator(model_fn=None, config='NotARunConfig') def test_model_fn_must_be_provided(self): @@ -228,6 +221,15 @@ class EstimatorConstructorTest(test.TestCase): self.assertEqual(_TMP_DIR, est.config.model_dir) self.assertEqual(_TMP_DIR, est.model_dir) + def test_empty_model_dir(self): + def model_fn(features, labels): + _, _ = features, labels + + with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR): + est = estimator.Estimator(model_fn=model_fn, model_dir='') + self.assertEqual(_TMP_DIR, est.config.model_dir) + self.assertEqual(_TMP_DIR, est.model_dir) + def test_model_dir_in_run_config(self): class FakeConfig(run_config.RunConfig): @@ -272,7 +274,7 @@ class EstimatorConstructorTest(test.TestCase): with self.assertRaisesRegexp( ValueError, - 'model_dir are set both in constructor and RunConfig, but ' + '`model_dir` are set both in constructor and `RunConfig`, but ' 'with different values'): estimator.Estimator( model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR) @@ -463,6 +465,29 @@ class EstimatorTrainTest(test.TestCase): est.train(InputFn(), steps=1) self.assertEqual(1, input_fn_call_count[0]) + def test_nested_input_fn(self): + expected_params = {'batch_size': 10} + + def _input_fn(): + dataset_features = dataset_ops.Dataset.from_tensor_slices( + (random_uniform([4]), + random_uniform([4, 100], maxval=100, dtype=dtypes.int32))) + dataset_labels = dataset_ops.Dataset.from_tensor_slices( + random_uniform([4, 10])) + dataset = dataset_ops.Dataset.zip((dataset_features, dataset_labels)) + dataset = dataset.repeat(-1) + iterator = dataset.make_initializable_iterator() + return iterator.get_next() + + def _model_fn(features, labels, mode, params, config): + del params, config + return model_fn_global_step_incrementer(features, labels, mode) + + expected_config = run_config.RunConfig().replace(tf_random_seed=4321) + est = estimator.Estimator( + model_fn=_model_fn, params=expected_params, config=expected_config) + est.train(_input_fn, steps=4) + def test_input_fn_args(self): expected_mode = model_fn_lib.ModeKeys.TRAIN expected_params = {'batch_size': 10} @@ -930,6 +955,19 @@ class EstimatorTrainTest(test.TestCase): est = estimator.Estimator(model_fn=_model_fn) est.train(dummy_input_fn, steps=1) + def test_config_should_not_be_evaluator_or_ps(self): + + class FakeEvaluatorConfig(run_config.RunConfig): + + @property + def task_type(self): + return run_config.TaskType.EVALUATOR + + est = estimator.Estimator( + model_fn=dummy_model_fn, config=FakeEvaluatorConfig()) + with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'): + est.train(dummy_input_fn, steps=1) + def _model_fn_with_eval_metric_ops(features, labels, mode, params): _, _ = features, labels @@ -1448,6 +1486,48 @@ class EstimatorEvaluateTest(test.TestCase): self.assertProtoEquals(expected_tensor_proto, next(summaries).value[0].tensor) + def test_summary_writing_with_tensor(self): + + def model_fn_with_prediction_mean_tensor_eval_metric_ops( + features, labels, mode, params): + _, _ = features, labels + global_step = training.get_global_step() + + metric_name = params.get('metric_name') or 'metric' + predictions = constant_op.constant([1., .5, 0.]) + eval_metric_ops = {metric_name: metrics_lib.mean_tensor(predictions)} + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(1.), + predictions={'predictions': predictions}, + train_op=state_ops.assign_add(global_step, 1), + eval_metric_ops=eval_metric_ops) + + metric_key = 'PMT' + params = { + 'metric_name': metric_key, + } + est = estimator.Estimator( + model_fn=model_fn_with_prediction_mean_tensor_eval_metric_ops, + params=params, + config=run_config.RunConfig(save_summary_steps=1)) + est.train(input_fn=dummy_input_fn, steps=10) + est.evaluate( + input_fn=dummy_input_fn, + steps=10, + ) + + writer_cache.FileWriterCache.clear() + + self.assertTrue( + check_eventfile_for_keyword(metric_key, est.eval_dir()), + '{} should be part of reported summaries.'.format(metric_key)) + + summaries = summaries_with_matching_keyword(metric_key, est.eval_dir()) + for value in next(summaries).value: + if value.tag == metric_key: + self.assertTrue(value.HasField('tensor')) + class EstimatorPredictTest(test.TestCase): @@ -1539,7 +1619,8 @@ class EstimatorPredictTest(test.TestCase): next( est.predict( dummy_input_fn, - checkpoint_path=saver.latest_checkpoint('fakedir'))) + checkpoint_path= + checkpoint_management.latest_checkpoint('fakedir'))) def test_tensor_predictions(self): @@ -2630,6 +2711,7 @@ class EstimatorExportTest(test.TestCase): _, _ = features, labels my_int = variables.Variable(1, name='my_int', collections=[ops.GraphKeys.LOCAL_VARIABLES]) + _ = training.get_or_create_steps_per_run_variable() scores = constant_op.constant([3.]) with ops.control_dependencies([ variables.local_variables_initializer(), |