aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator_test.py
diff options
context:
space:
mode:
authorGravatar Nick Felt <nickfelt@google.com>2018-07-24 09:49:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-24 09:53:22 -0700
commit568727eed199dba04e37f500265b50f96fed455e (patch)
tree999f31d1469b3b5f2dc12d5ca04061cfe6062faa /tensorflow/python/estimator/estimator_test.py
parentf8bbd3ceb7e86b7595ba74a9a03cfc7c1be252a8 (diff)
Add v2 summary support to Estimator.train() and MonitoredSession hooks
This change makes Estimator.train() support v2 summaries (tf.contrib.summary.*) out-of-the-box, to match the support for v1 summaries. Estimator.train() will now handle the boilerplate necessary to initialize a file writer and enable summary writing every N steps, and will ensure that its own automatically exported summaries (for loss and global_step/sec) get written to the same underlying events file. As part of this change, tf.train.SummarySaverHook, tf.train.CheckpointSaverHook, tf.train.StepCounterHook, and tf.train.ProfilerHook have also been adapted to write summaries using the v2 summary system (via a compatibility layer), instead of using FileWriterCache. A couple additional smaller changes are: - the 'session' parameter to FileWriter() can now be a callable returning a tf.Session instance. - the introduction of tf.contrib.summary.record_summaries_if() which takes a boolean tensor for direct control of tf.contrib.summary.should_record_summaries(). - EstimatorSpec.train_op, besides a tf.Operation, is now allowed to be any Tensor-equivalent object rather than just a tf.Tensor. PiperOrigin-RevId: 205843986
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r--tensorflow/python/estimator/estimator_test.py260
1 files changed, 247 insertions, 13 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 8bc410ba0b..1dd45a07c2 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -22,6 +22,7 @@ import functools
import glob
import os
import tempfile
+import time
import numpy as np
import six
@@ -29,6 +30,7 @@ import six
from google.protobuf import text_format
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
@@ -40,6 +42,7 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@@ -55,6 +58,7 @@ from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
@@ -85,13 +89,32 @@ def dummy_model_fn(features, labels, params):
_, _, _ = features, labels, params
-def summaries_with_matching_keyword(keyword, dir_):
- """Yields summary protos matching given keyword from event file."""
-
+def load_eventfile_contents(directory_path):
+ """Returns the contents of the singular event file in the given directory."""
writer_cache.FileWriterCache.clear()
- event_paths = glob.glob(os.path.join(dir_, 'events*'))
- for event in summary_iterator.summary_iterator(event_paths[-1]):
+ # Get last Event written.
+ event_paths = glob.glob(os.path.join(directory_path, '*tfevent*'))
+ if len(event_paths) != 1:
+ raise AssertionError('Expected one eventfile, got %s' % str(event_paths))
+ return list(summary_iterator.summary_iterator(event_paths[0]))
+
+
+def make_summary_steps(eventlist):
+ """Returns dict of tags in eventlist mapped to steps where they're logged."""
+ tag_to_steps = {}
+ for event in eventlist:
+ if event.summary is not None:
+ for value in event.summary.value:
+ if value.tag not in tag_to_steps:
+ tag_to_steps[value.tag] = []
+ tag_to_steps[value.tag].append(event.step)
+ return tag_to_steps
+
+
+def summaries_with_matching_keyword(keyword, dir_):
+ """Yields summary protos matching given keyword from event file."""
+ for event in load_eventfile_contents(dir_):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
@@ -366,13 +389,51 @@ def dummy_input_fn():
constant_op.constant([[1], [1]]))
+class StableGlobalStepEstimator(estimator.Estimator):
+ """Estimator subclass using a ResourceVariable global_step for testing."""
+ # TODO(nickfelt): remove after standard global_step is a ResourceVariable.
+
+ def _create_global_step(self, graph):
+ """Creates a stable ResourceVariable-based global step suitable for tests.
+
+ Args:
+ graph: The graph in which to create the global step.
+
+ Returns:
+ A global step `Tensor`.
+ """
+ with graph.as_default(), graph.name_scope(None):
+ return variable_scope.get_variable(
+ ops.GraphKeys.GLOBAL_STEP,
+ shape=[],
+ dtype=dtypes.int64,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ collections=[
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
+ ],
+ # Use a ResourceVariable and set caching_device to make the read
+ # behavior deterministic and well-defined.
+ caching_device='cpu:0',
+ use_resource=True)
+
+
def model_fn_global_step_incrementer(features, labels, mode):
_, _ = features, labels
- global_step = training.get_global_step()
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
- train_op=state_ops.assign_add(global_step, 1))
+ train_op=training.get_global_step().assign_add(1))
+
+
+def model_fn_with_v1_and_v2_summaries(features, labels, mode):
+ del features, labels
+ summary.scalar('foo-v1', 1.0)
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(1.),
+ train_op=training.get_global_step().assign_add(1))
def assert_features_op(expected_features, actual_features):
@@ -408,6 +469,25 @@ def _make_input_fn(features, labels):
return _input_fn
+class RaiseOnceAtStepHook(session_run_hook.SessionRunHook):
+ """Hook that raises an Exception the first time it reaches step N."""
+
+ def __init__(self, n, ex):
+ self.n = n
+ self.ex = ex
+ self.raised = False
+
+ def before_run(self, run_context):
+ # Raise the first time we reach step N.
+ self.n -= 1
+ if 0 == self.n and not self.raised:
+ # Wait 1 sec so that event file names have different UNIX timestamps.
+ time.sleep(1.2)
+ self.raised = True
+ raise self.ex
+ return None
+
+
class EstimatorTrainTest(test.TestCase):
def test_callable_model_fn(self):
@@ -617,17 +697,171 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(
5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
- def test_loss_summary(self):
+ def test_summary_loss(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer,
config=run_config.RunConfig(save_summary_steps=1))
est.train(dummy_input_fn, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual({'loss': [1]}, make_summary_steps(events))
- # Make sure nothing is stuck in limbo.
- writer_cache.FileWriterCache.clear()
+ def test_summary_user_defined_v1_and_v2(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=1))
+ est.train(dummy_input_fn, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1], 'foo-v2': [0], 'loss': [1]},
+ make_summary_steps(events))
- if check_eventfile_for_keyword('loss', est.model_dir):
- return
- self.fail('{} should be part of reported summaries.'.format('loss'))
+ def test_summary_writing_disabled(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=0))
+ est.train(dummy_input_fn, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual({}, make_summary_steps(events))
+
+ def test_summary_saving_steps(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=2))
+ est.train(dummy_input_fn, steps=5)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1, 3, 5], 'foo-v2': [0, 2, 4], 'loss': [1, 3, 5]},
+ make_summary_steps(events))
+
+ def test_summary_additional_hook(self):
+ def model_fn_extra_summary_hook(features, labels, mode, config):
+ del features, labels
+ v1_op = summary.scalar('foo-v1', 1.0)
+ v2_op = summary_ops_v2.scalar('foo-v2', 2.0)
+ extra_hook = basic_session_run_hooks.SummarySaverHook(
+ output_dir=os.path.join(config.model_dir, 'extra'),
+ save_steps=3,
+ summary_op=control_flow_ops.with_dependencies([v2_op], v1_op))
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(1.),
+ train_op=training.get_global_step().assign_add(1),
+ training_hooks=[extra_hook])
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_extra_summary_hook,
+ config=run_config.RunConfig(save_summary_steps=2))
+ est.train(dummy_input_fn, steps=7)
+
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1, 3, 5, 7], 'foo-v2': [0, 2, 4, 6], 'loss': [1, 3, 5, 7]},
+ make_summary_steps(events))
+ extra_dir = os.path.join(est.model_dir, 'extra')
+ extra_events = load_eventfile_contents(extra_dir)
+ self.assertEqual({'foo-v1': [1, 4, 7]}, make_summary_steps(extra_events))
+
+ def test_summary_user_defined_in_input_fn(self):
+ def input_fn_custom_summaries():
+ summary.scalar('foo-v1', 1.0)
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ return ({'x': constant_op.constant([[1], [1]])},
+ constant_op.constant([[1], [1]]))
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig(save_summary_steps=1))
+ est.train(input_fn_custom_summaries, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1], 'foo-v2': [0], 'loss': [1]},
+ make_summary_steps(events))
+
+ def test_summary_with_warm_start(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=1))
+ est.train(dummy_input_fn, steps=5)
+ warm_started_est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=1),
+ warm_start_from=est.model_dir)
+ warm_started_est.train(dummy_input_fn, steps=3)
+ events = load_eventfile_contents(warm_started_est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1, 2, 3], 'foo-v2': [0, 1, 2], 'loss': [1, 2, 3]},
+ make_summary_steps(events))
+
+ def test_summary_with_error_and_auto_restart(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(
+ save_summary_steps=2, save_checkpoints_steps=5))
+ abort_hook = RaiseOnceAtStepHook(
+ 7, errors_impl.AbortedError(None, None, 'Abort'))
+ est.train(dummy_input_fn, steps=10, hooks=[abort_hook])
+
+ # We expect two event files: one for the aborted run, and one post-restart.
+ event_paths = sorted(glob.glob(os.path.join(est.model_dir, '*tfevent*')))
+ self.assertEqual(2, len(event_paths))
+
+ # First file should have summaries up to the last checkpoint.
+ first_events = list(summary_iterator.summary_iterator(event_paths[0]))
+ first_summaries = make_summary_steps(first_events)
+ self.assertEqual([0, 2, 4], first_summaries['foo-v2'])
+ # The V1 summaries may or may not include step 5 (depending on the flush()
+ # sequence) so just check that at least 1 and 3 are there.
+ # TODO(nickfelt): ensure summaries *at* checkpoint step get flushed too.
+ self.assertEqual([1, 3], first_summaries['foo-v1'][:2])
+ self.assertEqual([1, 3], first_summaries['loss'][:2])
+
+ # Second file should pick up from global_step=5. Note that the 2 step save
+ # interval will reset at this step as well, so summaries logged at steps
+ # 2 and 4 continue not with 6, 8, ... but at steps 5, 7, ... instead.
+ second_events = list(summary_iterator.summary_iterator(event_paths[1]))
+ self.assertEqual(
+ {'foo-v1': [6, 8, 10], 'foo-v2': [5, 7, 9], 'loss': [6, 8, 10]},
+ make_summary_steps(second_events))
+ # Second file should contain a session START event at resumed global_step.
+ session_start_event = next(event for event in second_events
+ if event.session_log.status == SessionLog.START)
+ self.assertEqual(5, session_start_event.step)
+
+ def test_summary_with_error_and_explicit_restart(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(
+ save_summary_steps=2, save_checkpoints_steps=5))
+ abort_hook = RaiseOnceAtStepHook(
+ 7, errors_impl.UnknownError(None, None, 'Unknown failure'))
+ self.assertRaises(
+ errors_impl.UnknownError,
+ lambda: est.train(dummy_input_fn, max_steps=10, hooks=[abort_hook]))
+ # Explicitly retry after the error.
+ est.train(dummy_input_fn, max_steps=10, hooks=[abort_hook])
+
+ # We expect two event files: one for the failed run, and one post-restart.
+ event_paths = sorted(glob.glob(os.path.join(est.model_dir, '*tfevent*')))
+ self.assertEqual(2, len(event_paths))
+
+ # First file should have summaries up to the last checkpoint.
+ first_events = list(summary_iterator.summary_iterator(event_paths[0]))
+ first_summaries = make_summary_steps(first_events)
+ self.assertEqual([0, 2, 4], first_summaries['foo-v2'])
+ # The V1 summaries may or may not include step 5 (depending on the flush()
+ # sequence) so just check that at least 1 and 3 are there.
+ # TODO(nickfelt): ensure summaries *at* checkpoint step get flushed too.
+ self.assertEqual([1, 3], first_summaries['foo-v1'][:2])
+ self.assertEqual([1, 3], first_summaries['loss'][:2])
+
+ # Second file should pick up from global_step=5. Note that the 2 step save
+ # interval will reset at this step as well, so summaries logged at steps
+ # 2 and 4 continue not with 6, 8, ... but at steps 5, 7, ... instead.
+ second_events = list(summary_iterator.summary_iterator(event_paths[1]))
+ self.assertEqual(
+ {'foo-v1': [6, 8, 10], 'foo-v2': [5, 7, 9], 'loss': [6, 8, 10]},
+ make_summary_steps(second_events))
+ # Second file should contain a session START event at resumed global_step.
+ session_start_event = next(event for event in second_events
+ if event.session_log.status == SessionLog.START)
+ self.assertEqual(5, session_start_event.step)
def test_latest_checkpoint(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)