aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-06-07 13:37:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-07 13:45:11 -0700
commit38249d6be21e77bbd0663b71af598c0bdb99d6dc (patch)
tree0fecd933878f6bfedc90db6e7f0cc6e6cf00f7b9
parent599727c654aac53ee6f290b3d5e36c0e0852e951 (diff)
Swap the order of NanTensorHook and custom hooks
to ensure that when the training encounteres NaN's in the loss function, user-supplied hooks such as tf_debug.LocalCLIDebugHook can still be used to debug the root cause of the numeric issues. PiperOrigin-RevId: 158310249
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py2
-rw-r--r--tensorflow/python/estimator/estimator.py2
-rw-r--r--tensorflow/python/estimator/estimator_test.py42
3 files changed, 44 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 534aac644a..ac5ef565c8 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -957,6 +957,7 @@ class BaseEstimator(
self._check_inputs(features, labels)
model_fn_ops = self._get_train_ops(features, labels)
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
+ all_hooks.extend(hooks)
all_hooks.extend([
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
basic_session_run_hooks.LoggingTensorHook(
@@ -966,7 +967,6 @@ class BaseEstimator(
},
every_n_iter=100)
])
- all_hooks.extend(hooks)
scaffold = model_fn_ops.scaffold or monitored_session.Scaffold()
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index f424598ccb..8e6edf6da7 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -589,6 +589,7 @@ class Estimator(object):
estimator_spec = self._call_model_fn(features, labels,
model_fn_lib.ModeKeys.TRAIN)
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
+ all_hooks.extend(hooks)
all_hooks.extend([
training.NanTensorHook(estimator_spec.loss),
training.LoggingTensorHook(
@@ -598,7 +599,6 @@ class Estimator(object):
},
every_n_iter=100)
])
- all_hooks.extend(hooks)
all_hooks.extend(estimator_spec.training_hooks)
if not (estimator_spec.scaffold.saver or
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 4119a07bd8..b86afece43 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
@@ -1520,6 +1521,47 @@ class EstimatorExportTest(test.TestCase):
est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)
+class EstimatorHookOrderingTest(test.TestCase):
+
+ def testCustomHooksAreCalledBeforeNanTensorHook(self):
+
+ def nan_making_model_fn(mode, features, labels):
+ """A graph that generates NaN's for testing."""
+ del features, labels
+
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, name='global_step')
+ inc_global_step = state_ops.assign_add(global_step, 1)
+ nan_const = constant_op.constant(np.nan, dtype=dtypes.float32)
+ loss = control_flow_ops.cond(
+ inc_global_step > 1, lambda: nan_const, lambda: 1.0)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=global_step.read_value(),
+ loss=loss,
+ train_op=inc_global_step)
+
+ def empty_input_fn():
+ return dict(), None
+
+ class AfterRunCountingHook(session_run_hook.SessionRunHook):
+ """Hooks that counts the number of times after_run() is called."""
+
+ def __init__(self):
+ self.after_run_count = 0
+
+ def after_run(self, run_context, run_values):
+ del run_context, run_values
+ self.after_run_count += 1
+
+ test_hook = AfterRunCountingHook()
+ est = estimator.Estimator(model_fn=nan_making_model_fn)
+ with self.assertRaises(basic_session_run_hooks.NanLossDuringTrainingError):
+ est.train(input_fn=empty_input_fn, steps=2, hooks=[test_hook])
+ self.assertEqual(2, test_hook.after_run_count)
+
+
class EstimatorIntegrationTest(test.TestCase):
def test_complete_flow_with_a_simple_linear_model(self):