diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-06-18 21:00:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-18 21:04:00 -0700 |
commit | 6070ae0e148f50dbc8f36e1654f0a3f53b8b067e (patch) | |
tree | 165e4c050220180a76512e304b70eee0cd02a2db /tensorflow/python/keras/callbacks.py | |
parent | 60b78d6152e6f8d985f3086930ff986c140c36bf (diff) |
Merge changes from github.
PiperOrigin-RevId: 201110240
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 70b6a8431a..9f91368e5b 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -724,15 +724,6 @@ class TensorBoard(Callback): for weight in layer.weights: mapped_weight_name = weight.name.replace(':', '_') tf_summary.histogram(mapped_weight_name, weight) - if self.write_grads: - grads = model.optimizer.get_gradients(model.total_loss, weight) - - def is_indexed_slices(grad): - return type(grad).__name__ == 'IndexedSlices' - - grads = [grad.values if is_indexed_slices(grad) else grad - for grad in grads] - tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) if self.write_images: w_img = array_ops.squeeze(weight) shape = K.int_shape(w_img) @@ -759,6 +750,18 @@ class TensorBoard(Callback): assert len(shape) == 4 and shape[-1] in [1, 3, 4] tf_summary.image(mapped_weight_name, w_img) + if self.write_grads: + for weight in layer.trainable_weights: + mapped_weight_name = weight.name.replace(':', '_') + grads = model.optimizer.get_gradients(model.total_loss, weight) + + def is_indexed_slices(grad): + return type(grad).__name__ == 'IndexedSlices' + + grads = [grad.values if is_indexed_slices(grad) else grad + for grad in grads] + tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) + if hasattr(layer, 'output'): tf_summary.histogram('{}_out'.format(layer.name), layer.output) self.merged = tf_summary.merge_all() |