author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-05-24 10:35:27 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-05-24 10:40:04 -0700
commit 0ebcfe9cf241655adaaa5af53385980bfd605212 (patch)
tree 0101faec4131f189028d400a7e5e8261c4cd5426
parent ed316e4b53ef0de6a6cd6403213ed9dd9a76272f (diff)
Use 'global_gradient_norm' in summaries to enable only overall gradient stats
PiperOrigin-RevId: 157006018
-rw-r--r--  tensorflow/contrib/layers/python/layers/optimizers.py | 8
1 file changed, 5 insertions(+), 3 deletions(-)
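
For context, a minimal TF 1.x sketch of how a caller might request only the overall gradient statistics after this change; the placeholder, weight variable, and loss below are illustrative, not part of the commit:

import tensorflow as tf

# Toy graph for illustration; any model producing a scalar loss works
# the same way.
x = tf.placeholder(tf.float32, [None, 10])
w = tf.get_variable("w", [10, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
global_step = tf.contrib.framework.get_or_create_global_step()

# With this change, "global_gradient_norm" requests the single
# global_norm/gradient_norm scalar without the per-variable
# "gradients" / "gradient_norm" summaries.
train_op = tf.contrib.layers.optimize_loss(
    loss=loss,
    global_step=global_step,
    learning_rate=0.01,
    optimizer="SGD",
    summaries=["loss", "global_gradient_norm"])
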
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 1a6dfc12e9..50c11c696a 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -51,6 +51,7 @@ OPTIMIZER_SUMMARIES = [
     "loss",
     "gradients",
     "gradient_norm",
+    "global_gradient_norm",
 ]
@@ -182,7 +183,7 @@ def optimize_loss(loss,
                          "Got %s of type %s" % (str(learning_rate),
                                                 str(type(learning_rate))))
     if summaries is None:
-      summaries = ["loss", "learning_rate"]
+      summaries = ["loss", "learning_rate", "global_gradient_norm"]
     else:
       for summ in summaries:
         if summ not in OPTIMIZER_SUMMARIES:
@@ -250,7 +251,7 @@ def optimize_loss(loss,
           "Empty list of (gradient, var) pairs encountered. This is most "
           "likely to be caused by an improper value of gradient_multipliers.")

-    if "gradient_norm" in summaries:
+    if "global_gradient_norm" in summaries or "gradient_norm" in summaries:
       summary.scalar("global_norm/gradient_norm",
                      clip_ops.global_norm(list(zip(*gradients))[0]))
@@ -282,7 +283,8 @@ def optimize_loss(loss,
           summary.scalar("gradient_norm/%s" % var_name,
                          clip_ops.global_norm([grad_values]))

-    if clip_gradients is not None and "gradient_norm" in summaries:
+    if clip_gradients is not None and ("global_gradient_norm" in summaries or
+                                       "gradient_norm" in summaries):
       summary.scalar("global_norm/clipped_gradient_norm",
                      clip_ops.global_norm(list(zip(*gradients))[0]))
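
For reference, the scalar written under global_norm/gradient_norm by either key is the single Euclidean norm over all gradient tensors, the same computation tf.global_norm performs; a small sketch with hypothetical gradient values:

import tensorflow as tf

# Stand-ins for the gradient tensors that optimize_loss collects; the
# patch computes clip_ops.global_norm(list(zip(*gradients))[0]) over them.
g1 = tf.constant([3.0, 4.0])
g2 = tf.constant([12.0])

# sqrt(3^2 + 4^2 + 12^2) = sqrt(169) = 13.0
overall = tf.global_norm([g1, g2])

with tf.Session() as sess:
  print(sess.run(overall))  # 13.0
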