aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-03-29 19:57:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 20:03:58 -0700
commitbf170839d2a8be1b16e0a6c6a74ac2f0dc427f96 (patch)
tree2751916850df09cc737fcf91008f6fb0d9a9766f /tensorflow
parentdf847112f0a8805dab02cc5581870a8460032ef3 (diff)
Add summaries from only the first tower in distributed strategy.
PiperOrigin-RevId: 191024726
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/ops/summary_op_util.py11
-rw-r--r--tensorflow/python/ops/summary_ops.py3
-rw-r--r--tensorflow/python/summary/summary.py11
4 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index fa63a0525d..aa0acd243c 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4174,6 +4174,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":constant_op",
":errors",
":framework",
":framework_for_generated_wrappers",
diff --git a/tensorflow/python/ops/summary_op_util.py b/tensorflow/python/ops/summary_op_util.py
index 37b80d5e20..a793f634bd 100644
--- a/tensorflow/python/ops/summary_op_util.py
+++ b/tensorflow/python/ops/summary_op_util.py
@@ -23,6 +23,7 @@ import re
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import distribute
def collect(val, collections, default_collections):
@@ -42,6 +43,16 @@ def collect(val, collections, default_collections):
_INVALID_TAG_CHARACTERS = re.compile(r'[^-/\w\.]')
+def skip_summary():
+ # If using multiple towers in distributed strategy, skip summaries on all
+ # towers except the first one (tower_id=0).
+ # TODO(priyag): Add a new optional argument that will provide multiple
+ # alternatives to override default behavior. (e.g. run on last tower,
+ # compute sum or mean across towers).
+ tower_context = distribute.get_tower_context()
+ return tower_context and tower_context.tower_id > 0
+
+
def clean_tag(name):
"""Cleans a tag. Removes illegal characters for instance.
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py
index 037bc9845a..ec4d4a6e92 100644
--- a/tensorflow/python/ops/summary_ops.py
+++ b/tensorflow/python/ops/summary_ops.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import summary_pb2
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import summary_op_util
@@ -71,6 +72,8 @@ def tensor_summary(name,
serialized_summary_metadata = summary_metadata.SerializeToString()
+ if summary_op_util.skip_summary():
+ return constant_op.constant("")
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
val = gen_logging_ops.tensor_summary_v2(
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index 97f2ddfdfc..1286ed6703 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -50,6 +50,7 @@ from tensorflow.core.util.event_pb2 import TaggedRunMetadata
from tensorflow.python.eager import context as _context
+from tensorflow.python.framework import constant_op as _constant_op
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops
@@ -98,6 +99,8 @@ def scalar(name, tensor, collections=None, family=None):
Raises:
ValueError: If tensor has the wrong shape or type.
"""
+ if _summary_op_util.skip_summary():
+ return _constant_op.constant('')
with _summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
val = _gen_logging_ops.scalar_summary(tags=tag, values=tensor, name=scope)
@@ -151,6 +154,8 @@ def image(name, tensor, max_outputs=3, collections=None, family=None):
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
buffer.
"""
+ if _summary_op_util.skip_summary():
+ return _constant_op.constant('')
with _summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
val = _gen_logging_ops.image_summary(
@@ -189,6 +194,8 @@ def histogram(name, values, collections=None, family=None):
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
buffer.
"""
+ if _summary_op_util.skip_summary():
+ return _constant_op.constant('')
with _summary_op_util.summary_scope(
name, family, values=[values],
default_name='HistogramSummary') as (tag, scope):
@@ -234,6 +241,8 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None,
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
buffer.
"""
+ if _summary_op_util.skip_summary():
+ return _constant_op.constant('')
with _summary_op_util.summary_scope(
name, family=family, values=[tensor]) as (tag, scope):
sample_rate = _ops.convert_to_tensor(
@@ -282,6 +291,8 @@ def merge(inputs, collections=None, name=None):
raise RuntimeError(
'Merging tf.summary.* ops is not compatible with eager execution. '
'Use tf.contrib.summary instead.')
+ if _summary_op_util.skip_summary():
+ return _constant_op.constant('')
name = _summary_op_util.clean_tag(name)
with _ops.name_scope(name, 'Merge', inputs):
val = _gen_logging_ops.merge_summary(inputs=inputs, name=name)