aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks.py
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-18 21:00:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 21:04:00 -0700
commit6070ae0e148f50dbc8f36e1654f0a3f53b8b067e (patch)
tree165e4c050220180a76512e304b70eee0cd02a2db /tensorflow/python/keras/callbacks.py
parent60b78d6152e6f8d985f3086930ff986c140c36bf (diff)
Merge changes from github.
PiperOrigin-RevId: 201110240
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--tensorflow/python/keras/callbacks.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 70b6a8431a..9f91368e5b 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -724,15 +724,6 @@ class TensorBoard(Callback):
for weight in layer.weights:
mapped_weight_name = weight.name.replace(':', '_')
tf_summary.histogram(mapped_weight_name, weight)
- if self.write_grads:
- grads = model.optimizer.get_gradients(model.total_loss, weight)
-
- def is_indexed_slices(grad):
- return type(grad).__name__ == 'IndexedSlices'
-
- grads = [grad.values if is_indexed_slices(grad) else grad
- for grad in grads]
- tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
if self.write_images:
w_img = array_ops.squeeze(weight)
shape = K.int_shape(w_img)
@@ -759,6 +750,18 @@ class TensorBoard(Callback):
assert len(shape) == 4 and shape[-1] in [1, 3, 4]
tf_summary.image(mapped_weight_name, w_img)
+ if self.write_grads:
+ for weight in layer.trainable_weights:
+ mapped_weight_name = weight.name.replace(':', '_')
+ grads = model.optimizer.get_gradients(model.total_loss, weight)
+
+ def is_indexed_slices(grad):
+ return type(grad).__name__ == 'IndexedSlices'
+
+ grads = [grad.values if is_indexed_slices(grad) else grad
+ for grad in grads]
+ tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
+
if hasattr(layer, 'output'):
tf_summary.histogram('{}_out'.format(layer.name), layer.output)
self.merged = tf_summary.merge_all()