diff options
author | 2018-03-29 19:57:49 -0700 | |
---|---|---|
committer | 2018-03-29 20:03:58 -0700 | |
commit | bf170839d2a8be1b16e0a6c6a74ac2f0dc427f96 (patch) | |
tree | 2751916850df09cc737fcf91008f6fb0d9a9766f /tensorflow | |
parent | df847112f0a8805dab02cc5581870a8460032ef3 (diff) |
Add summaries from only the first tower in distributed strategy.
PiperOrigin-RevId: 191024726
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/summary_op_util.py | 11 | ||||
-rw-r--r-- | tensorflow/python/ops/summary_ops.py | 3 | ||||
-rw-r--r-- | tensorflow/python/summary/summary.py | 11 |
4 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index fa63a0525d..aa0acd243c 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4174,6 +4174,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":constant_op", ":errors", ":framework", ":framework_for_generated_wrappers", diff --git a/tensorflow/python/ops/summary_op_util.py b/tensorflow/python/ops/summary_op_util.py index 37b80d5e20..a793f634bd 100644 --- a/tensorflow/python/ops/summary_op_util.py +++ b/tensorflow/python/ops/summary_op_util.py @@ -23,6 +23,7 @@ import re from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging +from tensorflow.python.training import distribute def collect(val, collections, default_collections): @@ -42,6 +43,16 @@ def collect(val, collections, default_collections): _INVALID_TAG_CHARACTERS = re.compile(r'[^-/\w\.]') +def skip_summary(): + # If using multiple towers in distributed strategy, skip summaries on all + # towers except the first one (tower_id=0). + # TODO(priyag): Add a new optional argument that will provide multiple + # alternatives to override default behavior. (e.g. run on last tower, + # compute sum or mean across towers). + tower_context = distribute.get_tower_context() + return tower_context and tower_context.tower_id > 0 + + def clean_tag(name): """Cleans a tag. Removes illegal characters for instance. diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py index 037bc9845a..ec4d4a6e92 100644 --- a/tensorflow/python/ops/summary_ops.py +++ b/tensorflow/python/ops/summary_ops.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import summary_pb2 +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import summary_op_util @@ -71,6 +72,8 @@ def tensor_summary(name, serialized_summary_metadata = summary_metadata.SerializeToString() + if summary_op_util.skip_summary(): + return constant_op.constant("") with summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): val = gen_logging_ops.tensor_summary_v2( diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 97f2ddfdfc..1286ed6703 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -50,6 +50,7 @@ from tensorflow.core.util.event_pb2 import TaggedRunMetadata from tensorflow.python.eager import context as _context +from tensorflow.python.framework import constant_op as _constant_op from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops as _ops from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops @@ -98,6 +99,8 @@ def scalar(name, tensor, collections=None, family=None): Raises: ValueError: If tensor has the wrong shape or type. """ + if _summary_op_util.skip_summary(): + return _constant_op.constant('') with _summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): val = _gen_logging_ops.scalar_summary(tags=tag, values=tensor, name=scope) @@ -151,6 +154,8 @@ def image(name, tensor, max_outputs=3, collections=None, family=None): A scalar `Tensor` of type `string`. The serialized `Summary` protocol buffer. """ + if _summary_op_util.skip_summary(): + return _constant_op.constant('') with _summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): val = _gen_logging_ops.image_summary( @@ -189,6 +194,8 @@ def histogram(name, values, collections=None, family=None): A scalar `Tensor` of type `string`. The serialized `Summary` protocol buffer. """ + if _summary_op_util.skip_summary(): + return _constant_op.constant('') with _summary_op_util.summary_scope( name, family, values=[values], default_name='HistogramSummary') as (tag, scope): @@ -234,6 +241,8 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None, A scalar `Tensor` of type `string`. The serialized `Summary` protocol buffer. """ + if _summary_op_util.skip_summary(): + return _constant_op.constant('') with _summary_op_util.summary_scope( name, family=family, values=[tensor]) as (tag, scope): sample_rate = _ops.convert_to_tensor( @@ -282,6 +291,8 @@ def merge(inputs, collections=None, name=None): raise RuntimeError( 'Merging tf.summary.* ops is not compatible with eager execution. ' 'Use tf.contrib.summary instead.') + if _summary_op_util.skip_summary(): + return _constant_op.constant('') name = _summary_op_util.clean_tag(name) with _ops.name_scope(name, 'Merge', inputs): val = _gen_logging_ops.merge_summary(inputs=inputs, name=name) |