-rw-r--r--  tensorflow/contrib/summary/summary_ops_graph_test.py         20
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py           17
-rw-r--r--  tensorflow/core/kernels/summary_kernels.cc                    2
-rw-r--r--  tensorflow/python/BUILD                                       1
-rw-r--r--  tensorflow/python/estimator/estimator.py                     24
-rw-r--r--  tensorflow/python/estimator/estimator_test.py               260
-rw-r--r--  tensorflow/python/estimator/model_fn.py                       3
-rw-r--r--  tensorflow/python/estimator/training_test.py                 10
-rw-r--r--  tensorflow/python/ops/summary_ops_v2.py                      68
-rw-r--r--  tensorflow/python/saved_model/builder_impl.py                 5
-rw-r--r--  tensorflow/python/summary/writer/event_file_writer_v2.py    71
-rw-r--r--  tensorflow/python/summary/writer/writer.py                    8
-rw-r--r--  tensorflow/python/summary/writer/writer_test.py              54
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks.py       182
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks_test.py  476
-rw-r--r--  tensorflow/python/training/monitored_session.py              11
-rw-r--r--  tensorflow/python/training/optimizer.py                       6
17 files changed, 939 insertions(+), 279 deletions(-)
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index ae8336daaf..409fdf4583 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -228,6 +228,26 @@ class GraphFileTest(test_util.TensorFlowTestCase):
sess.run(writer.flush())
self.assertEqual(2, get_total())
+ def testSummaryOpsCollector(self):
+ summary_ops.scalar('x', 1.0, step=1)
+ with summary_ops.create_file_writer(self.get_temp_dir()).as_default():
+ s2 = summary_ops.scalar('x', 1.0, step=1)
+ collector1 = summary_ops._SummaryOpsCollector()
+ collector2 = summary_ops._SummaryOpsCollector()
+ with collector1.capture():
+ s3 = summary_ops.scalar('x', 1.0, step=1)
+ with collector2.capture():
+ s4 = summary_ops.scalar('x', 1.0, step=1)
+ s5 = summary_ops.scalar('x', 1.0, step=1)
+ s6 = summary_ops.scalar('x', 1.0, step=1)
+ summary_ops.scalar('six', 1.0, step=1)
+
+ # Ops defined outside summary writer context are ignored; ops defined inside
+ # SummaryOpsCollector capture context are stored in the innermost such context.
+ self.assertItemsEqual([s2, s6], summary_ops.all_summary_ops())
+ self.assertItemsEqual([s3, s5], collector1.collected_ops)
+ self.assertItemsEqual([s4], collector2.collected_ops)
+
class GraphDbTest(summary_test_util.SummaryDbTest):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 42406db88a..1eb43ac7f7 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1506,13 +1506,17 @@ class _OutfeedHostCall(object):
_OutfeedHostCall.validate(host_calls)
ret = {}
for name, host_call in host_calls.items():
+ # Isolate host call summary ops from main graph.
+ summary_collector = contrib_summary._SummaryOpsCollector() # pylint: disable=protected-access
host_fn, tensors = host_call
if isinstance(tensors, (tuple, list)):
- ret[name] = host_fn(*tensors)
+ with summary_collector.capture():
+ ret[name] = host_fn(*tensors)
else:
# Must be dict.
try:
- ret[name] = host_fn(**tensors)
+ with summary_collector.capture():
+ ret[name] = host_fn(**tensors)
except TypeError as e:
logging.warning(
'Exception while calling %s: %s. It is likely the tensors '
@@ -1627,11 +1631,14 @@ class _OutfeedHostCall(object):
# dimension.
dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
+ # Isolate host call summary ops from main graph.
+ summary_collector = contrib_summary._SummaryOpsCollector() # pylint: disable=protected-access
if self._tensor_keys[name] is not None:
# The user-provided eval_metrics[1] is a dict.
dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops))
try:
- ret[name] = self._host_fns[name](**dequeue_ops)
+ with summary_collector.capture():
+ ret[name] = self._host_fns[name](**dequeue_ops)
except TypeError as e:
logging.warning(
'Exception while calling %s: %s. It is likely the tensors '
@@ -1639,8 +1646,8 @@ class _OutfeedHostCall(object):
'function\'s arguments', name, e, name)
raise e
else:
- ret[name] = self._host_fns[name](*dequeue_ops)
-
+ with summary_collector.capture():
+ ret[name] = self._host_fns[name](*dequeue_ops)
return ret
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index b287f0cc2f..b518c3cbf4 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -53,6 +53,7 @@ class CreateSummaryFileWriterOp : public OpKernel {
max_queue, flush_millis, logdir,
filename_suffix, ctx->env(), s);
}));
+ core::ScopedUnref unref(s);
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
@@ -89,6 +90,7 @@ class CreateSummaryDbWriterOp : public OpKernel {
db, experiment_name, run_name, user_name, ctx->env(), s));
return Status::OK();
}));
+ core::ScopedUnref unref(s);
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 814239533c..b5a0051c28 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2822,6 +2822,7 @@ py_library(
":framework_ops",
":math_ops",
":resource_variable_ops",
+ ":resources",
":smart_cond",
":summary_op_util",
":summary_ops_gen",
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 915ceeb98b..b7185e8966 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -46,6 +46,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import resources
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
@@ -65,6 +66,7 @@ from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import estimator_export
@@ -1156,7 +1158,8 @@ class Estimator(object):
Loss from training
"""
worker_hooks = []
- with ops.Graph().as_default() as g, g.device(self._device_fn):
+ with ops.Graph().as_default() as g, g.device(
+ self._device_fn), self._summary_writing_context():
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
training_util._get_or_create_global_step_read() # pylint: disable=protected-access
@@ -1190,7 +1193,7 @@ class Estimator(object):
is_tpu_strategy = self._distribution.__class__.__name__ == 'TPUStrategy'
worker_hooks = []
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default() as g, self._summary_writing_context():
with self._distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -1519,6 +1522,23 @@ class Estimator(object):
(self._warm_start_settings,))
warm_starting_util.warm_start(*self._warm_start_settings)
+ @tf_contextlib.contextmanager
+ def _summary_writing_context(self):
+ """Context manager for enabling V2 summary writing."""
+ # Avoid creating a file writer at all if no summary writing was requested.
+ if self._config.save_summary_steps <= 0:
+ yield
+ return
+ file_writer = summary_ops_v2.create_file_writer(
+ logdir=self._model_dir, filename_suffix='')
+ with file_writer.as_default():
+ # Create a boolean placeholder, default False, that SummarySaverHook can
+ # use to enable/disable V2 summary writing according to its own logic.
+ placeholder = array_ops.placeholder_with_default(False, shape=[])
+ training.SummarySaverHook._set_placeholder(placeholder) # pylint: disable=protected-access
+ with summary_ops_v2.record_summaries_if(placeholder):
+ yield
+
def create_per_tower_ready_op(scaffold):
"""Create a Scaffold.ready_op inside a tower."""
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 8bc410ba0b..1dd45a07c2 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -22,6 +22,7 @@ import functools
import glob
import os
import tempfile
+import time
import numpy as np
import six
@@ -29,6 +30,7 @@ import six
from google.protobuf import text_format
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
@@ -40,6 +42,7 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@@ -55,6 +58,7 @@ from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
@@ -85,13 +89,32 @@ def dummy_model_fn(features, labels, params):
_, _, _ = features, labels, params
-def summaries_with_matching_keyword(keyword, dir_):
- """Yields summary protos matching given keyword from event file."""
-
+def load_eventfile_contents(directory_path):
+ """Returns the contents of the singular event file in the given directory."""
writer_cache.FileWriterCache.clear()
- event_paths = glob.glob(os.path.join(dir_, 'events*'))
- for event in summary_iterator.summary_iterator(event_paths[-1]):
+ # Expect exactly one event file in the directory.
+ event_paths = glob.glob(os.path.join(directory_path, '*tfevent*'))
+ if len(event_paths) != 1:
+ raise AssertionError('Expected one eventfile, got %s' % str(event_paths))
+ return list(summary_iterator.summary_iterator(event_paths[0]))
+
+
+def make_summary_steps(eventlist):
+ """Returns dict of tags in eventlist mapped to steps where they're logged."""
+ tag_to_steps = {}
+ for event in eventlist:
+ if event.summary is not None:
+ for value in event.summary.value:
+ if value.tag not in tag_to_steps:
+ tag_to_steps[value.tag] = []
+ tag_to_steps[value.tag].append(event.step)
+ return tag_to_steps
+
+
+def summaries_with_matching_keyword(keyword, dir_):
+ """Yields summary protos matching given keyword from event file."""
+ for event in load_eventfile_contents(dir_):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
@@ -366,13 +389,51 @@ def dummy_input_fn():
constant_op.constant([[1], [1]]))
+class StableGlobalStepEstimator(estimator.Estimator):
+ """Estimator subclass using a ResourceVariable global_step for testing."""
+ # TODO(nickfelt): remove after standard global_step is a ResourceVariable.
+
+ def _create_global_step(self, graph):
+ """Creates a stable ResourceVariable-based global step suitable for tests.
+
+ Args:
+ graph: The graph in which to create the global step.
+
+ Returns:
+ A global step `Tensor`.
+ """
+ with graph.as_default(), graph.name_scope(None):
+ return variable_scope.get_variable(
+ ops.GraphKeys.GLOBAL_STEP,
+ shape=[],
+ dtype=dtypes.int64,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ collections=[
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
+ ],
+ # Use a ResourceVariable and set caching_device to make the read
+ # behavior deterministic and well-defined.
+ caching_device='cpu:0',
+ use_resource=True)
+
+
def model_fn_global_step_incrementer(features, labels, mode):
_, _ = features, labels
- global_step = training.get_global_step()
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
- train_op=state_ops.assign_add(global_step, 1))
+ train_op=training.get_global_step().assign_add(1))
+
+
+def model_fn_with_v1_and_v2_summaries(features, labels, mode):
+ del features, labels
+ summary.scalar('foo-v1', 1.0)
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(1.),
+ train_op=training.get_global_step().assign_add(1))
def assert_features_op(expected_features, actual_features):
@@ -408,6 +469,25 @@ def _make_input_fn(features, labels):
return _input_fn
+class RaiseOnceAtStepHook(session_run_hook.SessionRunHook):
+ """Hook that raises an Exception the first time it reaches step N."""
+
+ def __init__(self, n, ex):
+ self.n = n
+ self.ex = ex
+ self.raised = False
+
+ def before_run(self, run_context):
+ # Raise the first time we reach step N.
+ self.n -= 1
+ if self.n == 0 and not self.raised:
+ # Sleep just over 1 second so event file names get distinct UNIX timestamps.
+ time.sleep(1.2)
+ self.raised = True
+ raise self.ex
+ return None
+
+
class EstimatorTrainTest(test.TestCase):
def test_callable_model_fn(self):
@@ -617,17 +697,171 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(
5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
- def test_loss_summary(self):
+ def test_summary_loss(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer,
config=run_config.RunConfig(save_summary_steps=1))
est.train(dummy_input_fn, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual({'loss': [1]}, make_summary_steps(events))
- # Make sure nothing is stuck in limbo.
- writer_cache.FileWriterCache.clear()
+ def test_summary_user_defined_v1_and_v2(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=1))
+ est.train(dummy_input_fn, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1], 'foo-v2': [0], 'loss': [1]},
+ make_summary_steps(events))
- if check_eventfile_for_keyword('loss', est.model_dir):
- return
- self.fail('{} should be part of reported summaries.'.format('loss'))
+ def test_summary_writing_disabled(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=0))
+ est.train(dummy_input_fn, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual({}, make_summary_steps(events))
+
+ def test_summary_saving_steps(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=2))
+ est.train(dummy_input_fn, steps=5)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1, 3, 5], 'foo-v2': [0, 2, 4], 'loss': [1, 3, 5]},
+ make_summary_steps(events))
+
+ def test_summary_additional_hook(self):
+ def model_fn_extra_summary_hook(features, labels, mode, config):
+ del features, labels
+ v1_op = summary.scalar('foo-v1', 1.0)
+ v2_op = summary_ops_v2.scalar('foo-v2', 2.0)
+ extra_hook = basic_session_run_hooks.SummarySaverHook(
+ output_dir=os.path.join(config.model_dir, 'extra'),
+ save_steps=3,
+ summary_op=control_flow_ops.with_dependencies([v2_op], v1_op))
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(1.),
+ train_op=training.get_global_step().assign_add(1),
+ training_hooks=[extra_hook])
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_extra_summary_hook,
+ config=run_config.RunConfig(save_summary_steps=2))
+ est.train(dummy_input_fn, steps=7)
+
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1, 3, 5, 7], 'foo-v2': [0, 2, 4, 6], 'loss': [1, 3, 5, 7]},
+ make_summary_steps(events))
+ extra_dir = os.path.join(est.model_dir, 'extra')
+ extra_events = load_eventfile_contents(extra_dir)
+ self.assertEqual({'foo-v1': [1, 4, 7]}, make_summary_steps(extra_events))
+
+ def test_summary_user_defined_in_input_fn(self):
+ def input_fn_custom_summaries():
+ summary.scalar('foo-v1', 1.0)
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ return ({'x': constant_op.constant([[1], [1]])},
+ constant_op.constant([[1], [1]]))
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig(save_summary_steps=1))
+ est.train(input_fn_custom_summaries, steps=1)
+ events = load_eventfile_contents(est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1], 'foo-v2': [0], 'loss': [1]},
+ make_summary_steps(events))
+
+ def test_summary_with_warm_start(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=1))
+ est.train(dummy_input_fn, steps=5)
+ warm_started_est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(save_summary_steps=1),
+ warm_start_from=est.model_dir)
+ warm_started_est.train(dummy_input_fn, steps=3)
+ events = load_eventfile_contents(warm_started_est.model_dir)
+ self.assertEqual(
+ {'foo-v1': [1, 2, 3], 'foo-v2': [0, 1, 2], 'loss': [1, 2, 3]},
+ make_summary_steps(events))
+
+ def test_summary_with_error_and_auto_restart(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(
+ save_summary_steps=2, save_checkpoints_steps=5))
+ abort_hook = RaiseOnceAtStepHook(
+ 7, errors_impl.AbortedError(None, None, 'Abort'))
+ est.train(dummy_input_fn, steps=10, hooks=[abort_hook])
+
+ # We expect two event files: one for the aborted run, and one post-restart.
+ event_paths = sorted(glob.glob(os.path.join(est.model_dir, '*tfevent*')))
+ self.assertEqual(2, len(event_paths))
+
+ # First file should have summaries up to the last checkpoint.
+ first_events = list(summary_iterator.summary_iterator(event_paths[0]))
+ first_summaries = make_summary_steps(first_events)
+ self.assertEqual([0, 2, 4], first_summaries['foo-v2'])
+ # The V1 summaries may or may not include step 5 (depending on the flush()
+ # sequence) so just check that at least 1 and 3 are there.
+ # TODO(nickfelt): ensure summaries *at* checkpoint step get flushed too.
+ self.assertEqual([1, 3], first_summaries['foo-v1'][:2])
+ self.assertEqual([1, 3], first_summaries['loss'][:2])
+
+ # Second file should pick up from global_step=5. Note that the 2 step save
+ # interval will reset at this step as well, so summaries logged at steps
+ # 2 and 4 continue not with 6, 8, ... but at steps 5, 7, ... instead.
+ second_events = list(summary_iterator.summary_iterator(event_paths[1]))
+ self.assertEqual(
+ {'foo-v1': [6, 8, 10], 'foo-v2': [5, 7, 9], 'loss': [6, 8, 10]},
+ make_summary_steps(second_events))
+ # Second file should contain a session START event at resumed global_step.
+ session_start_event = next(event for event in second_events
+ if event.session_log.status == SessionLog.START)
+ self.assertEqual(5, session_start_event.step)
+
+ def test_summary_with_error_and_explicit_restart(self):
+ est = StableGlobalStepEstimator(
+ model_fn=model_fn_with_v1_and_v2_summaries,
+ config=run_config.RunConfig(
+ save_summary_steps=2, save_checkpoints_steps=5))
+ abort_hook = RaiseOnceAtStepHook(
+ 7, errors_impl.UnknownError(None, None, 'Unknown failure'))
+ self.assertRaises(
+ errors_impl.UnknownError,
+ lambda: est.train(dummy_input_fn, max_steps=10, hooks=[abort_hook]))
+ # Explicitly retry after the error.
+ est.train(dummy_input_fn, max_steps=10, hooks=[abort_hook])
+
+ # We expect two event files: one for the failed run, and one post-restart.
+ event_paths = sorted(glob.glob(os.path.join(est.model_dir, '*tfevent*')))
+ self.assertEqual(2, len(event_paths))
+
+ # First file should have summaries up to the last checkpoint.
+ first_events = list(summary_iterator.summary_iterator(event_paths[0]))
+ first_summaries = make_summary_steps(first_events)
+ self.assertEqual([0, 2, 4], first_summaries['foo-v2'])
+ # The V1 summaries may or may not include step 5 (depending on the flush()
+ # sequence) so just check that at least 1 and 3 are there.
+ # TODO(nickfelt): ensure summaries *at* checkpoint step get flushed too.
+ self.assertEqual([1, 3], first_summaries['foo-v1'][:2])
+ self.assertEqual([1, 3], first_summaries['loss'][:2])
+
+ # Second file should pick up from global_step=5. Note that the 2 step save
+ # interval will reset at this step as well, so summaries logged at steps
+ # 2 and 4 continue not with 6, 8, ... but at steps 5, 7, ... instead.
+ second_events = list(summary_iterator.summary_iterator(event_paths[1]))
+ self.assertEqual(
+ {'foo-v1': [6, 8, 10], 'foo-v2': [5, 7, 9], 'loss': [6, 8, 10]},
+ make_summary_steps(second_events))
+ # Second file should contain a session START event at resumed global_step.
+ session_start_event = next(event for event in second_events
+ if event.session_log.status == SessionLog.START)
+ self.assertEqual(5, session_start_event.step)
def test_latest_checkpoint(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index a9fd8f8e1a..b1b2f65edf 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -26,6 +26,7 @@ import six
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
@@ -432,7 +433,7 @@ class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
def _check_is_tensor_or_operation(x, name):
- if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):
+ if not (isinstance(x, ops.Operation) or tensor_util.is_tensor(x)):
raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index dc106c7d3b..121439a2cd 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -2059,7 +2059,7 @@ class TrainAndEvaluateIntegrationTest(test.TestCase):
def _extract_loss_and_global_step(self, event_folder):
"""Returns the loss and global step in last event."""
- event_paths = glob.glob(os.path.join(event_folder, 'events*'))
+ event_paths = sorted(glob.glob(os.path.join(event_folder, 'events*')))
loss = None
global_step_count = None
@@ -2139,10 +2139,12 @@ class TrainAndEvaluateIntegrationTest(test.TestCase):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
- # Examine the training events. Use a range to check global step to avoid
- # flakyness due to global step race condition.
- training_loss, _ = self._extract_loss_and_global_step(est.model_dir)
+ # Examine the training events.
+ training_loss, training_global_step = self._extract_loss_and_global_step(
+ est.model_dir)
self.assertIsNotNone(training_loss)
+ # Training summaries are logged for steps 1 and 10, so we see final step.
+ self.assertEqual(max_steps, training_global_step)
# Examine the eval events. The global step should be accurate.
eval_loss, eval_global_step = self._extract_loss_and_global_step(
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 00150fe688..669358d9db 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_summary_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resources
from tensorflow.python.ops import summary_op_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
@@ -66,41 +67,39 @@ def should_record_summaries():
return should_record_collection[0]
+@tf_contextlib.contextmanager
+def always_record_summaries():
+ """Sets the should_record_summaries Tensor to always true."""
+ with record_summaries_if(True):
+ yield
+
+
+@tf_contextlib.contextmanager
+def never_record_summaries():
+ """Sets the should_record_summaries Tensor to always false."""
+ with record_summaries_if(False):
+ yield
+
+
# TODO(apassos) consider how to handle local step here.
@tf_contextlib.contextmanager
def record_summaries_every_n_global_steps(n, global_step=None):
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
if global_step is None:
global_step = training_util.get_or_create_global_step()
- collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
- old = collection_ref[:]
- try:
- with ops.device("cpu:0"):
- collection_ref[:] = [math_ops.equal(global_step % n, 0)]
- yield
- finally:
- collection_ref[:] = old
-
-
-@tf_contextlib.contextmanager
-def always_record_summaries():
- """Sets the should_record_summaries Tensor to always true."""
- collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
- old = collection_ref[:]
- try:
- collection_ref[:] = [True]
+ with ops.device("cpu:0"):
+ on_nth_global_step = math_ops.equal(global_step % n, 0)
+ with record_summaries_if(on_nth_global_step):
yield
- finally:
- collection_ref[:] = old
@tf_contextlib.contextmanager
-def never_record_summaries():
- """Sets the should_record_summaries Tensor to always false."""
+def record_summaries_if(bool_value):
+ """Sets the should_record_summaries Tensor to the given boolean value."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
old = collection_ref[:]
try:
- collection_ref[:] = [False]
+ collection_ref[:] = [bool_value]
yield
finally:
collection_ref[:] = old
@@ -143,7 +142,6 @@ class SummaryWriter(object):
finally:
context.context().summary_writer_resource = old
-
def init(self):
"""Operation to initialize the summary writer resource."""
if self._resource is not None:
@@ -311,6 +309,9 @@ def _make_summary_writer(name, factory, **kwargs):
# TODO(apassos): Consider doing this instead.
# ops.get_default_session().run(init_op)
ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op)
+ # TODO(nickfelt): expose an actual op for this
+ is_initialized_op = constant_op.constant(True)
+ resources.register_resource(resource, init_op, is_initialized_op)
return SummaryWriter(resource, init_op_fn)
@@ -325,6 +326,27 @@ def _nothing():
return constant_op.constant(False)
+class _SummaryOpsCollector(object):
+ """Defines a context manager for isolating out a subset of summary ops.
+
+ Summary ops defined within this context will be accumulated within this
+ collector instead of being added to the graph-wide summary ops collection that
+ is returned by @{tf.contrib.summary.all_summary_ops}.
+ """
+
+ def __init__(self):
+ self.collected_ops = []
+
+ @tf_contextlib.contextmanager
+ def capture(self):
+ collection_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
+ original_ops = collection_ref[:]
+ collection_ref[:] = []
+ yield
+ self.collected_ops = collection_ref[:]
+ collection_ref[:] = original_ops
+
+
def all_summary_ops():
"""Graph-mode only. Returns all summary ops.
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index e58be804c2..b67d0f2362 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -28,6 +28,7 @@ from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
@@ -178,10 +179,10 @@ class SavedModelBuilder(object):
stored as a collection with key TRAIN_OP_KEY, but not executed.
Raises:
- TypeError if Train op is not of type `Operation`.
+ TypeError if Train op is not of type `Operation` or a Tensor.
"""
if train_op is not None:
- if (not isinstance(train_op, ops.Tensor) and
+ if (not tensor_util.is_tensor(train_op) and
not isinstance(train_op, ops.Operation)):
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
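
The motivation for switching to `tensor_util.is_tensor` here: it accepts Tensor-like values (for example a `ResourceVariable`, which registers itself as a dense-tensor-like type) that fail a bare `isinstance(x, ops.Tensor)` check. A small illustrative sketch, assuming TF 1.x graph mode:

    from tensorflow.python.framework import ops
    from tensorflow.python.framework import tensor_util
    from tensorflow.python.ops import resource_variable_ops

    v = resource_variable_ops.ResourceVariable(0)
    assert not isinstance(v, ops.Tensor)  # rejected by the old check
    assert tensor_util.is_tensor(v)       # accepted by the new check
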
diff --git a/tensorflow/python/summary/writer/event_file_writer_v2.py b/tensorflow/python/summary/writer/event_file_writer_v2.py
index 5c66c0f7a8..262182d3b8 100644
--- a/tensorflow/python/summary/writer/event_file_writer_v2.py
+++ b/tensorflow/python/summary/writer/event_file_writer_v2.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -43,11 +44,11 @@ class EventFileWriterV2(object):
"""Creates an `EventFileWriterV2` and an event file to write to.
On construction, this calls `tf.contrib.summary.create_file_writer` within
- the graph from `session.graph` to look up a shared summary writer resource
- for `logdir` if one exists, and create one if not. Creating the summary
+ the default graph, which finds and returns a shared summary writer resource
+ for `logdir` if one exists, and creates one if not. Creating the summary
writer resource in turn creates a new event file in `logdir` to be filled
with `Event` protocol buffers passed to `add_event`. Graph ops to control
- this writer resource are added to `session.graph` during this init call;
+ this writer resource are added to the default graph during this init call;
stateful methods on this class will call `session.run()` on these ops.
Note that because the underlying resource is shared, it is possible that
@@ -61,38 +62,50 @@ class EventFileWriterV2(object):
no effect. See `tf.contrib.summary.create_file_writer` for details.
Args:
- session: A `tf.Session`. Session that will hold shared writer resource.
- The writer ops will be added to session.graph during this init call.
+ session: A `tf.Session`, or a callable that provides one which will be
+ called on-demand. The session will hold the shared writer resource.
logdir: A string. Directory where event file will be written.
max_queue: Integer. Size of the queue for pending events and summaries.
flush_secs: Number. How often, in seconds, to flush the
pending events and summaries to disk.
filename_suffix: A string. Every event file's name is suffixed with
`filename_suffix`.
+
+ Raises:
+ ValueError: if `session` is not a `tf.Session` or a callable
"""
- self._session = session
+ if isinstance(session, tf_session.SessionInterface):
+ self._session = lambda: session
+ elif callable(session):
+ self._session = session
+ else:
+ raise ValueError('session must be tf.Session or callable')
self._logdir = logdir
+ self._initialized = False
self._closed = False
if not gfile.IsDirectory(self._logdir):
gfile.MakeDirs(self._logdir)
- with self._session.graph.as_default():
- with ops.name_scope('filewriter'):
- file_writer = summary_ops_v2.create_file_writer(
- logdir=self._logdir,
- max_queue=max_queue,
- flush_millis=flush_secs * 1000,
- filename_suffix=filename_suffix)
- with summary_ops_v2.always_record_summaries(), file_writer.as_default():
- self._event_placeholder = array_ops.placeholder_with_default(
- constant_op.constant('unused', dtypes.string),
- shape=[])
- self._add_event_op = summary_ops_v2.import_event(
- self._event_placeholder)
- self._init_op = file_writer.init()
- self._flush_op = file_writer.flush()
- self._close_op = file_writer.close()
- self._session.run(self._init_op)
+ with ops.name_scope('filewriter'):
+ file_writer = summary_ops_v2.create_file_writer(
+ logdir=self._logdir,
+ max_queue=max_queue,
+ flush_millis=flush_secs * 1000,
+ filename_suffix=filename_suffix)
+ with summary_ops_v2.always_record_summaries(), file_writer.as_default():
+ self._event_placeholder = array_ops.placeholder_with_default(
+ constant_op.constant('unused', dtypes.string),
+ shape=[])
+ self._add_event_op = summary_ops_v2.import_event(
+ self._event_placeholder)
+ self._init_op = file_writer.init()
+ self._flush_op = file_writer.flush()
+ self._close_op = file_writer.close()
+
+ def _init_if_needed(self):
+ if not self._initialized:
+ self._session().run(self._init_op)
+ self._initialized = True
def get_logdir(self):
"""Returns the directory where event file will be written."""
@@ -108,7 +121,6 @@ class EventFileWriterV2(object):
"""
if self._closed:
self._closed = False
- self._session.run(self._init_op)
def add_event(self, event):
"""Adds an event to the event file.
@@ -117,8 +129,9 @@ class EventFileWriterV2(object):
event: An `Event` protocol buffer.
"""
if not self._closed:
+ self._init_if_needed()
event_pb = event.SerializeToString()
- self._session.run(
+ self._session().run(
self._add_event_op, feed_dict={self._event_placeholder: event_pb})
def flush(self):
@@ -127,7 +140,9 @@ class EventFileWriterV2(object):
Call this method to make sure that all pending events have been written to
disk.
"""
- self._session.run(self._flush_op)
+ if not self._closed:
+ self._init_if_needed()
+ self._session().run(self._flush_op)
def close(self):
"""Flushes the event file to disk and close the file.
@@ -135,6 +150,8 @@ class EventFileWriterV2(object):
Call this method when you do not need the summary writer anymore.
"""
if not self._closed:
+ self._init_if_needed()
self.flush()
- self._session.run(self._close_op)
+ self._session().run(self._close_op)
self._closed = True
+ self._initialized = False
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index aca084fc91..2a967ae3a5 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -332,8 +332,11 @@ class FileWriter(SummaryToEventTransformer):
the same shared resource name (which by default scoped to the logdir). If
no such resource exists, one will be created using the remaining arguments
to this constructor, but if one already exists those arguments are ignored.
- In either case, ops will be added to `session.graph` to control the
+ In either case, ops will be added to the default graph to control the
underlying file writer resource. See `tf.contrib.summary` for more details.
+ Instead of an actual `tf.Session`, this argument may also be a callable that
+ provides a `tf.Session` when invoked (e.g. `tf.get_default_session`), which
+ will be called on-demand when a session is needed.
Args:
logdir: A string. Directory where event file will be written.
@@ -344,7 +347,8 @@ class FileWriter(SummaryToEventTransformer):
graph_def: DEPRECATED: Use the `graph` argument instead.
filename_suffix: A string. Every event file's name is suffixed with
`suffix`.
- session: A `tf.Session` object. See details above.
+ session: A `tf.Session` object or a callable that provides `tf.Session`
+ objects. See details above.
Raises:
RuntimeError: If called with eager execution enabled.
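
A sketch of the callable-session usage described above (illustrative; the logdir and summary values are hypothetical). Passing `tf.get_default_session` defers the session lookup until a write actually happens, so the writer always targets whichever session is current:

    import tensorflow as tf
    from tensorflow.core.framework import summary_pb2

    w = tf.summary.FileWriter(
        logdir='/tmp/demo_logdir', session=tf.get_default_session)
    demo = summary_pb2.Summary(
        value=[summary_pb2.Summary.Value(tag='demo', simple_value=1.0)])
    sess = tf.Session()
    with sess.as_default():
      # The session is resolved on demand, here to `sess`.
      w.add_summary(demo, global_step=1)
      w.flush()
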
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index dc990c2602..3380dea317 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for training_coordinator.py."""
+"""Tests for writer.py."""
from __future__ import absolute_import
from __future__ import division
@@ -574,6 +574,58 @@ class SessionBasedFileWriterTestCase(FileWriterTestCase):
# No more files
self.assertRaises(StopIteration, lambda: next(event_paths))
+ def testSessionArgument_callableProvider(self):
+ logdir = self.get_temp_dir()
+ setup_writer = summary_ops_v2.create_file_writer(logdir=logdir)
+ with summary_ops_v2.always_record_summaries(), setup_writer.as_default():
+ summary1 = summary_ops_v2.scalar("one", 0.0, step=0)
+ summary2 = summary_ops_v2.scalar("two", 0.0, step=0)
+ sess1 = session.Session()
+ sess1.run(setup_writer.init())
+ sess1.run(summary1)
+ sess1.run(setup_writer.flush())
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ sess2 = session.Session()
+ sess2.run(setup_writer.init())
+ sess2.run(summary2)
+ sess2.run(setup_writer.flush())
+
+ # Using get_default_session as session provider should make this FileWriter
+ # send its summaries to the current default session's shared summary writer
+ # resource (initializing it as needed).
+ test_writer = writer.FileWriter(
+ session=ops.get_default_session, logdir=logdir)
+ with sess1.as_default():
+ test_writer.add_summary(self._createTaggedSummary("won"), 1)
+ test_writer.flush()
+ with sess2.as_default():
+ test_writer.add_summary(self._createTaggedSummary("too"), 1)
+ test_writer.flush()
+
+ event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))
+
+ # First file should have tags "one", "won"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("one", next(events).summary.value[0].tag)
+ self.assertEqual("won", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # Second file should have tags "two", "too"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("two", next(events).summary.value[0].tag)
+ self.assertEqual("too", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # No more files
+ self.assertRaises(StopIteration, lambda: next(event_paths))
+
+ def testSessionArgument_notSessionOrCallable(self):
+ logdir = self.get_temp_dir()
+ self.assertRaises(
+ ValueError, lambda: writer.FileWriter(session=[], logdir=logdir))
+
class FileWriterCacheTest(test.TestCase):
"""FileWriterCache tests."""
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index b0dd188db1..b8df7fe51b 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -31,12 +31,13 @@ from tensorflow.python.client import timeline
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary.writer import writer
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
-from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export
@@ -422,7 +423,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
self._steps_per_run = steps_per_run
def begin(self):
- self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
+ self._summary_writer = writer.FileWriter(
+ self._checkpoint_dir, session=ops.get_default_session,
+ filename_suffix="")
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError(
@@ -431,10 +434,12 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
l.begin()
def after_create_session(self, session, coord):
+ del coord
+ # Ensure summary writer resource has been initialized.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
global_step = session.run(self._global_step_tensor)
- # We do write graph and saver_def at the first call of before_run.
- # We cannot do this in begin, since we let other hooks to change graph and
- # add variables in begin. Graph is finalized after all begin calls.
+ # Write graph and saver_def once graph is finalized, which isn't true yet
+ # in begin() since later hooks can still change the graph.
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir,
@@ -444,8 +449,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True),
saver_def=saver_def)
- self._summary_writer.add_graph(graph)
- self._summary_writer.add_meta_graph(meta_graph_def)
+ with ops.default_session(session):
+ self._summary_writer.add_graph(graph)
+ self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
self._save(session, global_step)
self._timer.update_last_triggered_step(global_step)
@@ -470,6 +476,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
self._save(session, last_step)
for l in self._listeners:
l.end(session, last_step)
+ with ops.default_session(session):
+ self._summary_writer.flush()
def _save(self, session, step):
"""Saves the latest checkpoint, returns should_stop."""
@@ -479,10 +487,12 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
l.before_save(session, step)
self._get_saver().save(session, self._save_path, global_step=step)
- self._summary_writer.add_session_log(
- SessionLog(
- status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
- step)
+ with ops.default_session(session):
+ self._summary_writer.add_session_log(
+ SessionLog(
+ status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
+ step)
+ self._summary_writer.flush()
should_stop = False
for l in self._listeners:
@@ -543,13 +553,23 @@ class StepCounterHook(session_run_hook.SessionRunHook):
def begin(self):
if self._summary_writer is None and self._output_dir:
- self._summary_writer = SummaryWriterCache.get(self._output_dir)
+ self._summary_writer = writer.FileWriter(
+ self._output_dir, session=ops.get_default_session,
+ filename_suffix="")
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")
self._summary_tag = training_util.get_global_step().op.name + "/sec"
+ def after_create_session(self, session, coord):
+ del coord
+ # Reset any stale state in case we're recovering from a previous error.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
+ self._last_global_step = None
+ self._global_step_check_count = 0
+ self._timer.reset()
+
def before_run(self, run_context): # pylint: disable=unused-argument
return SessionRunArgs(self._global_step_tensor)
@@ -562,8 +582,6 @@ class StepCounterHook(session_run_hook.SessionRunHook):
logging.info("%s: %g", self._summary_tag, steps_per_sec)
def after_run(self, run_context, run_values):
- _ = run_context
-
stale_global_step = run_values.results
if self._timer.should_trigger_for_step(
stale_global_step + self._steps_per_run):
@@ -573,7 +591,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
- self._log_and_record(elapsed_steps, elapsed_time, global_step)
+ with ops.default_session(run_context.session):
+ self._log_and_record(elapsed_steps, elapsed_time, global_step)
# Check whether the global step has been increased. Here, we do not use the
# timer.last_triggered_step as the timer might record a different global
@@ -599,6 +618,11 @@ class StepCounterHook(session_run_hook.SessionRunHook):
self._last_global_step = stale_global_step
+ def end(self, session):
+ if self._summary_writer is not None:
+ with ops.default_session(session):
+ self._summary_writer.flush()
+
@tf_export("train.NanLossDuringTrainingError")
class NanLossDuringTrainingError(RuntimeError):
@@ -643,6 +667,25 @@ class NanTensorHook(session_run_hook.SessionRunHook):
class SummarySaverHook(session_run_hook.SessionRunHook):
"""Saves summaries every N steps."""
+ _SUMMARY_PLACEHOLDER_COLLECTION = "_SUMMARY_SAVER_PLACEHOLDER"
+
+ @classmethod
+ def _set_placeholder(cls, placeholder):
+ """Sets a `tf.placeholder` to be fed by the first SummarySaverHook.
+
+ If a placeholder is provided, the first instance of SummarySaverHook in use
+ will feed it a boolean indicating whether summaries should be written,
+ according to the `save_steps` and `save_secs` parameters of that hook. This
+ makes the placeholder usable with `tf.contrib.summary.record_summaries_if`
+ to control `tf.contrib.summary` summary writing using the same schedule as
+ the `tf.summary` summary writing (which the hook controls directly).
+
+ Args:
+ placeholder: `tf.placeholder` for the first SummarySaverHook to feed
+ """
+ collection = ops.get_collection_ref(cls._SUMMARY_PLACEHOLDER_COLLECTION)
+ collection[:] = [placeholder]
+
def __init__(self,
save_steps=None,
save_secs=None,
@@ -680,53 +723,82 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
self._scaffold = scaffold
self._timer = SecondOrStepTimer(every_secs=save_secs,
every_steps=save_steps)
+ self._placeholder = None
# TODO(mdan): Throw an error if output_dir and summary_writer are None.
def begin(self):
if self._summary_writer is None and self._output_dir:
- self._summary_writer = SummaryWriterCache.get(self._output_dir)
- self._next_step = None
- self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+ self._summary_writer = writer.FileWriter(
+ self._output_dir, filename_suffix="", session=ops.get_default_session)
+ # Designate the first SummarySaverHook to call begin() as the "primary"
+ # hook; it will control writing of v2 summaries via a placeholder bool.
+ collection = ops.get_collection_ref(self._SUMMARY_PLACEHOLDER_COLLECTION)
+ if collection:
+ self._placeholder = collection[0]
+ collection[:] = []
+ self._current_step = None
+ self._global_step_tensor = training_util.get_or_create_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use SummarySaverHook.")
- def before_run(self, run_context): # pylint: disable=unused-argument
- self._request_summary = (
- self._next_step is None or
- self._timer.should_trigger_for_step(self._next_step))
+ def after_create_session(self, session, coord):
+ del coord
+ # Reset any stale state in case we're recovering from a previous error.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
+ self._current_step = None
+ self._timer.reset()
+
+ def before_run(self, run_context):
+ # For the first run, record a SessionLog.START at the pre-run global step.
+ if self._current_step is None:
+ self._current_step = run_context.session.run(self._global_step_tensor)
+ with ops.default_session(run_context.session):
+ self._summary_writer.add_session_log(
+ SessionLog(status=SessionLog.START), self._current_step)
requests = {"global_step": self._global_step_tensor}
+ self._request_summary = self._timer.should_trigger_for_step(
+ self._current_step)
if self._request_summary:
+ self._timer.update_last_triggered_step(self._current_step)
if self._get_summary_op() is not None:
requests["summary"] = self._get_summary_op()
-
- return SessionRunArgs(requests)
+ feeds = {}
+ if self._placeholder is not None and self._request_summary:
+ feeds[self._placeholder] = self._request_summary
+ args = SessionRunArgs(fetches=requests, feed_dict=feeds)
+ return args
def after_run(self, run_context, run_values):
- _ = run_context
- if not self._summary_writer:
- return
-
+ # Collect any legacy v1 summaries to emit.
+ summaries_to_emit = []
+ if self._summary_writer and self._request_summary:
+ for summary in run_values.results.get("summary", []):
+ # Skip None results corresponding to V2 summary operations.
+ if summary is not None:
+ summaries_to_emit.append(summary)
+ # Heuristically estimate current step as possibly-stale value plus one.
stale_global_step = run_values.results["global_step"]
- global_step = stale_global_step + 1
- if self._next_step is None or self._request_summary:
- global_step = run_context.session.run(self._global_step_tensor)
-
- if self._next_step is None:
- self._summary_writer.add_session_log(
- SessionLog(status=SessionLog.START), global_step)
-
- if self._request_summary:
- self._timer.update_last_triggered_step(global_step)
- if "summary" in run_values.results:
- for summary in run_values.results["summary"]:
- self._summary_writer.add_summary(summary, global_step)
-
- self._next_step = global_step + 1
+ self._current_step = stale_global_step + 1
+ # Read the actual post-run global step if we need better accuracy because
+ # 1) we will request summaries on the next run (based on estimate now) and
+ # must ensure we record an accurate "last triggered step" value, or
+ # 2) we have legacy v1 summaries to emit using the post-run step value.
+ # Note: we could have dealt with (1) separately in before_run() but by doing
+ # it here we can consolidate the reads in case both (1) and (2) apply.
+ near_next_trigger = self._timer.should_trigger_for_step(self._current_step)
+ if near_next_trigger or summaries_to_emit:
+ self._current_step = run_context.session.run(self._global_step_tensor)
+ # Emit any legacy v1 summaries.
+ if summaries_to_emit:
+ with ops.default_session(run_context.session):
+ for summary in summaries_to_emit:
+ self._summary_writer.add_summary(summary, self._current_step)
def end(self, session=None):
- if self._summary_writer:
- self._summary_writer.flush()
+ if self._summary_writer and session:
+ with ops.default_session(session):
+ self._summary_writer.flush()
def _get_summary_op(self):
"""Fetches the summary op either from self._summary_op or self._scaffold.
@@ -893,19 +965,27 @@ class ProfilerHook(session_run_hook.SessionRunHook):
show_memory: `bool`, if True, add object snapshot events to the trace
showing the sizes and lifetimes of tensors.
"""
+ self._output_dir = output_dir
self._output_file = os.path.join(output_dir, "timeline-{}.json")
- self._file_writer = SummaryWriterCache.get(output_dir)
self._show_dataflow = show_dataflow
self._show_memory = show_memory
self._timer = SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)
def begin(self):
+ self._file_writer = writer.FileWriter(
+ self._output_dir, filename_suffix="", session=ops.get_default_session)
self._next_step = None
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use ProfilerHook.")
+ def after_create_session(self, session, coord):
+ del coord
+ # Reset any stale state in case we're recovering from a previous error.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
+ self._timer.reset()
+
def before_run(self, run_context):
self._request_summary = (
self._next_step is None or
@@ -925,8 +1005,10 @@ class ProfilerHook(session_run_hook.SessionRunHook):
self._save(global_step,
self._output_file.format(global_step),
run_values.run_metadata.step_stats)
- self._file_writer.add_run_metadata(run_values.run_metadata,
- "step_%d" % global_step)
+ with ops.default_session(run_context.session):
+ self._file_writer.add_run_metadata(run_values.run_metadata,
+ "step_%d" % global_step,
+ global_step=global_step)
self._next_step = global_step + 1
@@ -938,6 +1020,10 @@ class ProfilerHook(session_run_hook.SessionRunHook):
trace.generate_chrome_trace_format(
show_dataflow=self._show_dataflow, show_memory=self._show_memory))
+ def end(self, session):
+ with ops.default_session(session):
+ self._file_writer.flush()
+
def _as_graph_element(obj):
"""Retrieves Graph element."""
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index b49a871a56..b89167f3c1 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -19,8 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import glob
import os.path
import shutil
+import sys
import tempfile
import threading
import time
@@ -28,6 +30,9 @@ import time
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.testing.python.framework import fake_summary_writer
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -35,9 +40,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -45,13 +53,27 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.summary import summary as summary_lib
+from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
+from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
+def load_eventfile_contents(directory_path):
+ """Returns the contents of the singular event file in the given directory."""
+ writer_cache.FileWriterCache.clear()
+
+ # Expect exactly one event file in the directory.
+ event_paths = glob.glob(os.path.join(directory_path, '*tfevent*'))
+ if len(event_paths) != 1:
+ raise AssertionError('Expected one eventfile, got %s' % str(event_paths))
+ result = list(summary_iterator.summary_iterator(event_paths[0]))
+ return result
+
+
class MockCheckpointSaverListener(
basic_session_run_hooks.CheckpointSaverListener):
@@ -717,11 +739,12 @@ class CheckpointSaverHookTest(test.TestCase):
checkpoint_utils.load_variable(self.model_dir,
self.global_step.name))
- def test_summary_writer_defs(self):
- fake_summary_writer.FakeSummaryWriter.install()
- writer_cache.FileWriterCache.clear()
- summary_writer = writer_cache.FileWriterCache.get(self.model_dir)
+ def _assertCheckpointEvent(self, event, step, checkpoint_path):
+ self.assertEqual(step, event.step)
+ self.assertEqual(SessionLog.CHECKPOINT, event.session_log.status)
+ self.assertEqual(checkpoint_path, event.session_log.checkpoint_path)
+ def test_summary_writer_defs(self):
with self.graph.as_default():
hook = basic_session_run_hooks.CheckpointSaverHook(
self.model_dir, save_steps=2, scaffold=self.scaffold)
@@ -730,18 +753,40 @@ class CheckpointSaverHookTest(test.TestCase):
with session_lib.Session() as sess:
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
- hook.after_create_session(sess, None)
- mon_sess.run(self.train_op)
- summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.model_dir,
- expected_added_meta_graphs=[
- meta_graph.create_meta_graph_def(
- graph_def=self.graph.as_graph_def(add_shapes=True),
- saver_def=self.scaffold.saver.saver_def)
- ])
-
- fake_summary_writer.FakeSummaryWriter.uninstall()
+ hook.after_create_session(sess, None) # Checkpoint saved at step 0.
+ expected_graph_def = self.graph.as_graph_def(add_shapes=True)
+ expected_meta_graph_def = meta_graph.create_meta_graph_def(
+ graph_def=expected_graph_def,
+ saver_def=self.scaffold.saver.saver_def)
+ mon_sess.run(self.train_op) # No checkpoint saved at step 1.
+ mon_sess.run(self.train_op) # Checkpoint saved at step 2.
+ mon_sess.run(self.train_op) # No checkpoint saved at step 3.
+      hook.end(sess)  # Checkpoint saved at the last step (3).
+ events = iter(load_eventfile_contents(self.model_dir))
+ next(events) # Skip version event that's always there.
+
+ # Graph.
+ event = next(events)
+ self.assertEqual(0, event.step)
+ actual_graph_def = graph_pb2.GraphDef()
+ actual_graph_def.ParseFromString(event.graph_def)
+ test_util.assert_equal_graph_def(actual_graph_def, expected_graph_def)
+
+ # Metagraph.
+ event = next(events)
+ self.assertEqual(0, event.step)
+ actual_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ actual_meta_graph_def.ParseFromString(event.meta_graph_def)
+ test_util.assert_meta_graph_protos_equal(
+ self, expected_meta_graph_def, actual_meta_graph_def)
+
+ # Checkpoints.
+ # Strip the "-step#" suffix off the latest checkpoint to get base path.
+ checkpoint_path = saver.latest_checkpoint(self.model_dir).rsplit('-', 1)[0]
+ self._assertCheckpointEvent(next(events), 0, checkpoint_path)
+ self._assertCheckpointEvent(next(events), 2, checkpoint_path)
+ self._assertCheckpointEvent(next(events), 3, checkpoint_path)
+ self.assertRaises(StopIteration, lambda: next(events)) # No more events.
def test_save_checkpoint_before_first_train_step(self):
with self.graph.as_default():
@@ -1102,167 +1147,305 @@ class StepCounterHookTest(test.TestCase):
self.assertEqual('global_step/sec', summary_value.tag)
self.assertGreater(summary_value.simple_value, 0)
+ def test_summary_writer(self):
+ with ops.Graph().as_default(), session_lib.Session() as sess:
+ variables.get_or_create_global_step()
+ train_op = training_util._increment_global_step(1)
+ hook = basic_session_run_hooks.StepCounterHook(
+ output_dir=self.log_dir, every_n_steps=10)
+ hook.begin()
+ sess.run(variables_lib.global_variables_initializer())
+ mon_sess = monitored_session._HookedSession(sess, [hook])
+ for _ in range(30):
+ mon_sess.run(train_op)
+ hook.end(sess)
+ events = iter(load_eventfile_contents(self.log_dir))
+ next(events) # Skip version event that's always there.
+
+ event = next(events)
+ self.assertEqual(11, event.step)
+ self.assertEqual('global_step/sec', event.summary.value[0].tag)
+ self.assertLess(0, event.summary.value[0].simple_value)
-class SummarySaverHookTest(test.TestCase):
+ event = next(events)
+ self.assertEqual(21, event.step)
+ self.assertEqual('global_step/sec', event.summary.value[0].tag)
+ self.assertLess(0, event.summary.value[0].simple_value)
- def setUp(self):
- test.TestCase.setUp(self)
+ self.assertRaises(StopIteration, lambda: next(events)) # No more events.
- self.log_dir = 'log/dir'
- self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)
- var = variables_lib.Variable(0.0)
- tensor = state_ops.assign_add(var, 1.0)
- tensor2 = tensor * 2
- self.summary_op = summary_lib.scalar('my_summary', tensor)
- self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)
+class SummarySaverHookTest(test.TestCase):
- variables.get_or_create_global_step()
- self.train_op = training_util._increment_global_step(1)
+ def setUp(self):
+ test.TestCase.setUp(self)
+ self.logdir = self.get_temp_dir()
+ self._create_stable_global_step()
+
+ def _create_stable_global_step(self):
+ """Returns a new ResourceVariable global_step for deterministic tests."""
+ # TODO(nickfelt): remove after standard global_step is a ResourceVariable.
+ with ops.get_default_graph().name_scope(None):
+ return variable_scope.get_variable(
+ ops.GraphKeys.GLOBAL_STEP,
+ shape=[],
+ dtype=dtypes.int64,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ collections=[
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
+ ],
+ # Use a ResourceVariable and set caching_device to make the read
+ # behavior deterministic and well-defined.
+ caching_device='cpu:0',
+ use_resource=True)
def test_raise_when_scaffold_and_summary_op_both_missing(self):
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook()
def test_raise_when_scaffold_and_summary_op_both_present(self):
+    # merge_all() would be None here (no summaries defined), so use a real op.
+    summary_op = summary_lib.scalar('foo-v1', 1.0)
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook(
- scaffold=monitored_session.Scaffold(), summary_op=self.summary_op)
+ scaffold=monitored_session.Scaffold(), summary_op=summary_op)
- def test_raise_in_both_secs_and_steps(self):
+ def test_raise_when_secs_and_steps_both_missing(self):
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook(
- save_secs=10, save_steps=20, summary_writer=self.summary_writer)
+ save_secs=None, save_steps=None, output_dir=self.logdir)
- def test_raise_in_none_secs_and_steps(self):
+ def test_raise_when_secs_and_steps_both_present(self):
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook(
- save_secs=None, save_steps=None, summary_writer=self.summary_writer)
+ save_secs=10, save_steps=20, output_dir=self.logdir)
- def test_save_steps(self):
- hook = basic_session_run_hooks.SummarySaverHook(
- save_steps=8,
- summary_writer=self.summary_writer,
- summary_op=self.summary_op)
+
+  def _makeHook(self, **kwargs):
+ kwargs['output_dir'] = self.logdir
+ kwargs['scaffold'] = monitored_session.Scaffold()
+ return basic_session_run_hooks.SummarySaverHook(**kwargs)
+
+  def _runForSteps(self, hook, steps, loop_body_fn=None):
+ train_op = training_util.get_global_step().assign_add(1)
with self.test_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
+ scaffold = hook._scaffold # pylint: disable=protected-access
+ if scaffold is not None:
+ scaffold.finalize()
+ sess.run(scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(30):
- mon_sess.run(self.train_op)
+ for _ in range(steps):
+ mon_sess.run(train_op)
+ if loop_body_fn is not None:
+ loop_body_fn()
hook.end(sess)
- self.summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.log_dir,
- expected_summaries={
- 1: {
- 'my_summary': 1.0
- },
- 9: {
- 'my_summary': 2.0
- },
- 17: {
- 'my_summary': 3.0
- },
- 25: {
- 'my_summary': 4.0
- },
- })
+ def _assertSessionEvent(self, event, step, session_status):
+ self.assertEqual(step, event.step)
+ self.assertEqual(session_status, event.session_log.status)
+
+ def _assertSummaryEvent(self, event, step, tag_value_list):
+ self.assertEqual(step, event.step)
+ tag_value_actual_list = [
+ (value.tag, value.simple_value) for value in event.summary.value
+ ]
+ self.assertItemsEqual(tag_value_list, tag_value_actual_list)
+
+ def test_no_summaries(self):
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 3)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_basic_summaries(self):
+ summary_lib.scalar('foo-v1', 1.0)
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.always_record_summaries():
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 3)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 1, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 2, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 2, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 3, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
def test_multiple_summaries(self):
- hook = basic_session_run_hooks.SummarySaverHook(
- save_steps=8,
- summary_writer=self.summary_writer,
- summary_op=[self.summary_op, self.summary_op2])
-
+ summary_lib.scalar('foo-v1', 1.0)
+ summary_lib.scalar('bar-v1', 10.0)
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.always_record_summaries():
+ foo = summary_ops_v2.scalar('foo-v2', 2.0)
+        # Use a control dependency to force a deterministic write order.
+ with ops.control_dependencies([foo]):
+ summary_ops_v2.scalar('bar-v2', 20.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 1)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 0, [('bar-v2', 20.0)])
+ self._assertSummaryEvent(
+ next(events), 1, [('foo-v1', 1.0), ('bar-v1', 10.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_v2_summaries_only(self):
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.always_record_summaries():
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 1)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_v2_summaries_custom_file_writer(self):
+ other_dir = os.path.join(self.logdir, 'other')
+ other_writer = summary_ops_v2.create_file_writer(other_dir)
+ # SummarySaverHook only flushes the writer for logdir; this one needs to be
+ # manually flushed.
+ flush_op = other_writer.flush()
+ with summary_ops_v2.always_record_summaries():
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ with other_writer.as_default():
+ summary_ops_v2.scalar('other-v2', 3.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 1)
with self.test_session() as sess:
- hook.begin()
- sess.run(variables_lib.global_variables_initializer())
- mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(10):
- mon_sess.run(self.train_op)
- hook.end(sess)
+ sess.run(flush_op)
- self.summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.log_dir,
- expected_summaries={
- 1: {
- 'my_summary': 1.0,
- 'my_summary2': 2.0
- },
- 9: {
- 'my_summary': 2.0,
- 'my_summary2': 4.0
- },
- })
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
- def test_save_secs_saving_once_every_step(self):
- hook = basic_session_run_hooks.SummarySaverHook(
- save_secs=0.5,
- summary_writer=self.summary_writer,
- summary_op=self.summary_op)
+ events = iter(load_eventfile_contents(other_dir))
+ next(events) # Skip version event that's always there.
+ self._assertSummaryEvent(next(events), 0, [('other-v2', 3.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
- with self.test_session() as sess:
- hook.begin()
- sess.run(variables_lib.global_variables_initializer())
- mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(4):
- mon_sess.run(self.train_op)
- time.sleep(0.5)
- hook.end(sess)
+ def test_save_steps(self):
+ summary_lib.scalar('foo-v1', 1.0)
+ placeholder = array_ops.placeholder_with_default(False, shape=[])
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.record_summaries_if(placeholder):
+ summary_ops_v2.scalar('foo-v2', 2.0)
- self.summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.log_dir,
- expected_summaries={
- 1: {
- 'my_summary': 1.0
- },
- 2: {
- 'my_summary': 2.0
- },
- 3: {
- 'my_summary': 3.0
- },
- 4: {
- 'my_summary': 4.0
- },
- })
+
+    basic_session_run_hooks.SummarySaverHook._set_placeholder(placeholder)
+ hook = self._makeHook(save_steps=8)
+ self._runForSteps(hook, 30)
+
+    events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 8, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 9, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 16, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 17, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 24, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 25, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
@test.mock.patch.object(time, 'time')
- def test_save_secs_saving_once_every_three_steps(self, mock_time):
- mock_time.return_value = 1484695987.209386
- hook = basic_session_run_hooks.SummarySaverHook(
- save_secs=9.,
- summary_writer=self.summary_writer,
- summary_op=self.summary_op)
+ def test_save_secs_saving_once_every_step(self, mock_time):
+ mock_time.return_value = 1000.0
+ summary_lib.scalar('foo-v1', 1.0)
+ placeholder = array_ops.placeholder_with_default(False, shape=[])
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.record_summaries_if(placeholder):
+ summary_ops_v2.scalar('foo-v2', 2.0)
- with self.test_session() as sess:
- hook.begin()
- sess.run(variables_lib.global_variables_initializer())
- mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(8):
- mon_sess.run(self.train_op)
- mock_time.return_value += 3.1
- hook.end(sess)
+ basic_session_run_hooks.SummarySaverHook._set_placeholder(placeholder)
+ hook = self._makeHook(save_secs=0.5)
+ def fake_sleep():
+ mock_time.return_value += 0.5
+ self._runForSteps(hook, 4, fake_sleep)
+
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 1, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 2, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 2, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 3, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 3, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 4, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ @test.mock.patch.object(time, 'time')
+ def test_save_secs_saving_once_every_three_steps(self, mock_time):
+ mock_time.return_value = 1000.0
+ summary_lib.scalar('foo-v1', 1.0)
+ placeholder = array_ops.placeholder_with_default(False, shape=[])
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.record_summaries_if(placeholder):
+ summary_ops_v2.scalar('foo-v2', 2.0)
+
+ basic_session_run_hooks.SummarySaverHook._set_placeholder(placeholder)
+ hook = self._makeHook(save_secs=9)
+ def fake_sleep():
+ mock_time.return_value += 3.1
+ self._runForSteps(hook, 8, fake_sleep)
+
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
    # 24.8 seconds pass in total (3.1 * 8); the hook saves once every 9
    # seconds, starting with the first step:
- self.summary_writer.assert_summaries(
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 3, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 4, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 6, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 7, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_explicit_summary_writer_and_op(self):
+ summary_writer = fake_summary_writer.FakeSummaryWriter(self.logdir)
+ hook = basic_session_run_hooks.SummarySaverHook(
+ save_steps=1,
+ summary_writer=summary_writer,
+ summary_op=summary_lib.scalar('foo-v1', 1.0))
+ self._runForSteps(hook, 3)
+ summary_writer.assert_summaries(
test_case=self,
- expected_logdir=self.log_dir,
+ expected_logdir=self.logdir,
expected_summaries={
- 1: {
- 'my_summary': 1.0
- },
- 4: {
- 'my_summary': 2.0
- },
- 7: {
- 'my_summary': 3.0
- },
+ 1: {'foo-v1': 1.0},
+ 2: {'foo-v1': 1.0},
+ 3: {'foo-v1': 1.0},
})
@@ -1518,18 +1701,23 @@ class ProfilerHookTest(test.TestCase):
sess.run(self.train_op) # Saved.
self.assertEqual(3, self._count_timeline_files())
- def test_run_metadata_saves_in_first_step(self):
- writer_cache.FileWriterCache.clear()
- fake_summary_writer.FakeSummaryWriter.install()
- fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
+ def test_run_metadata_summary_saving(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
sess.run(self.train_op) # Saved.
- self.assertEqual(
- list(fake_writer._added_run_metadata.keys()), ['step_1'])
- fake_summary_writer.FakeSummaryWriter.uninstall()
+ sess.run(self.train_op) # Not saved.
+ sess.run(self.train_op) # Saved.
+ events = iter(load_eventfile_contents(self.output_dir))
+ next(events) # Skip version event that's always there.
+ event = next(events)
+ self.assertEqual(1, event.step)
+ self.assertEqual('step_1', event.tagged_run_metadata.tag)
+ event = next(events)
+ self.assertEqual(3, event.step)
+ self.assertEqual('step_3', event.tagged_run_metadata.tag)
+ self.assertRaises(StopIteration, lambda: next(events)) # No more events.
if __name__ == '__main__':
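
Note: a minimal, hypothetical sketch of the pattern these SummarySaverHook
tests exercise, using the same TF 1.x-era internal imports; 'logdir' is a
placeholder. V1 summaries reach the events file through the hook's summary op,
while V2 summary ops write through their own file writer for the same
directory:

    from tensorflow.python.ops import summary_ops_v2
    from tensorflow.python.summary import summary as summary_lib
    from tensorflow.python.training import basic_session_run_hooks
    from tensorflow.python.training import monitored_session

    logdir = '/tmp/logdir'  # Placeholder output directory.
    summary_lib.scalar('loss-v1', 1.0)  # Picked up by Scaffold's merge_all.
    with summary_ops_v2.create_file_writer(logdir).as_default():
      with summary_ops_v2.always_record_summaries():
        summary_ops_v2.scalar('loss-v2', 2.0)  # Graph-mode V2 summary op.
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=1, output_dir=logdir, scaffold=monitored_session.Scaffold())
    # Running a MonitoredSession with this hook emits both summary families.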
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 7b06bffa4b..8a4ca04b1e 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
@@ -204,13 +205,17 @@ class Scaffold(object):
'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
Scaffold.default_local_init_op)
if self._summary_op is None:
+ def default_summary_op():
+ v1_op = summary.merge_all()
+ v2_ops = summary_ops_v2.all_summary_ops() or []
+ if v1_op is not None:
+ return control_flow_ops.with_dependencies(v2_ops, v1_op)
+ return control_flow_ops.group(v2_ops) if v2_ops else None
self._summary_op = Scaffold.get_or_default('summary_op',
ops.GraphKeys.SUMMARY_OP,
- summary.merge_all)
- # pylint: disable=g-long-lambda
+ default_summary_op)
if self._saver is None:
self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access
- # pylint: enable=g-long-lambda
self._saver.build()
ops.get_default_graph().finalize()
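
Note: for clarity, a standalone sketch of the merge semantics the new
default_summary_op gives Scaffold (a restatement under the same APIs, not a
separate implementation). When V1 summaries exist, the merged V1 tensor is
returned with every V2 summary op attached as a control dependency, so a
single run of the summary op drives both systems:

    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import summary_ops_v2
    from tensorflow.python.summary import summary

    def merged_summary_op():
      v1_op = summary.merge_all()                      # None if no V1 summaries.
      v2_ops = summary_ops_v2.all_summary_ops() or []  # V2 write side effects.
      if v1_op is not None:
        # Evaluating the V1 tensor now also runs all V2 summary ops first.
        return control_flow_ops.with_dependencies(v2_ops, v1_op)
      # V2-only graphs get a grouping op; graphs with no summaries get None.
      return control_flow_ops.group(*v2_ops) if v2_ops else None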
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f75db08059..b9d42b034e 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -611,10 +611,8 @@ class Optimizer(
if isinstance(global_step, resource_variable_ops.ResourceVariable):
# TODO(apassos): the implicit read in assign_add is slow; consider
# making it less so.
- apply_updates = resource_variable_ops.assign_add_variable_op(
- global_step.handle,
- ops.convert_to_tensor(1, dtype=global_step.dtype),
- name=name)
+ apply_updates = global_step.assign_add(
+ 1, name=name, read_value=False)
else:
apply_updates = state_ops.assign_add(global_step, 1, name=name)
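
Note: a hypothetical usage sketch of the resource-variable path above (names
are placeholders). Passing read_value=False makes assign_add return the update
Operation itself rather than a tensor holding the new value, so no read of the
variable is forced as part of the update:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import resource_variable_ops

    global_step = resource_variable_ops.ResourceVariable(
        0, dtype=dtypes.int64, trainable=False, name='global_step')
    # An Operation with no value output; run it purely for its side effect.
    apply_updates = global_step.assign_add(1, name='update', read_value=False)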