diff options
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 2 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 2 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 42 |
3 files changed, 44 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 534aac644a..ac5ef565c8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -957,6 +957,7 @@ class BaseEstimator( self._check_inputs(features, labels) model_fn_ops = self._get_train_ops(features, labels) ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) + all_hooks.extend(hooks) all_hooks.extend([ basic_session_run_hooks.NanTensorHook(model_fn_ops.loss), basic_session_run_hooks.LoggingTensorHook( @@ -966,7 +967,6 @@ class BaseEstimator( }, every_n_iter=100) ]) - all_hooks.extend(hooks) scaffold = model_fn_ops.scaffold or monitored_session.Scaffold() if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index f424598ccb..8e6edf6da7 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -589,6 +589,7 @@ class Estimator(object): estimator_spec = self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN) ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss) + all_hooks.extend(hooks) all_hooks.extend([ training.NanTensorHook(estimator_spec.loss), training.LoggingTensorHook( @@ -598,7 +599,6 @@ class Estimator(object): }, every_n_iter=100) ]) - all_hooks.extend(hooks) all_hooks.extend(estimator_spec.training_hooks) if not (estimator_spec.scaffold.saver or diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 4119a07bd8..b86afece43 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import checkpoint_state_pb2 from tensorflow.python.training import saver from tensorflow.python.training import saver_test_utils @@ -1520,6 +1521,47 @@ class EstimatorExportTest(test.TestCase): est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn) +class EstimatorHookOrderingTest(test.TestCase): + + def testCustomHooksAreCalledBeforeNanTensorHook(self): + + def nan_making_model_fn(mode, features, labels): + """A graph that generates NaN's for testing.""" + del features, labels + + global_step = variables.Variable( + 0, dtype=dtypes.int64, name='global_step') + inc_global_step = state_ops.assign_add(global_step, 1) + nan_const = constant_op.constant(np.nan, dtype=dtypes.float32) + loss = control_flow_ops.cond( + inc_global_step > 1, lambda: nan_const, lambda: 1.0) + + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=global_step.read_value(), + loss=loss, + train_op=inc_global_step) + + def empty_input_fn(): + return dict(), None + + class AfterRunCountingHook(session_run_hook.SessionRunHook): + """Hooks that counts the number of times after_run() is called.""" + + def __init__(self): + self.after_run_count = 0 + + def after_run(self, run_context, run_values): + del run_context, run_values + self.after_run_count += 1 + + test_hook = AfterRunCountingHook() + est = estimator.Estimator(model_fn=nan_making_model_fn) + with self.assertRaises(basic_session_run_hooks.NanLossDuringTrainingError): + est.train(input_fn=empty_input_fn, steps=2, hooks=[test_hook]) + self.assertEqual(2, test_hook.after_run_count) + + class EstimatorIntegrationTest(test.TestCase): def test_complete_flow_with_a_simple_linear_model(self): |