aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-24 12:12:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 12:17:12 -0700
commit7919d64414ed47d217b8fc508d1be56b2a531a3c (patch)
tree9d6fd19ca3932d89743ab4643a56529afc974303 /tensorflow/python/eager
parentf361fb8e4b4a9838e60a11ab45391c308bcb90da (diff)
Wrap forward and backward pass in a defun for L2HMC.
Also a small bugfix to handle unknown shapes in backprop._num_elements. Before: entry { name: "L2hmcBenchmark.eager_train_cpu_defun" iters: 10 wall_time: 0.594115018845 extras { key: "examples_per_sec" value { double_value: 336.635152548 } } } After: entry { name: "L2hmcBenchmark.eager_train_cpu_defun" iters: 10 wall_time: 0.322251081467 extras { key: "examples_per_sec" value { double_value: 620.634069216 } } } PiperOrigin-RevId: 214308142
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/backprop.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index d95e0fe721..78f3198011 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -564,7 +564,10 @@ def _aggregate_grads(gradients):
def _num_elements(grad):
"""The number of elements in the `grad` tensor."""
if isinstance(grad, ops.Tensor):
- return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
+ shape_tuple = grad._shape_tuple() # pylint: disable=protected-access
+ if shape_tuple is None or None in shape_tuple:
+ return 0
+ return functools.reduce(operator.mul, shape_tuple, 1)
if isinstance(grad, ops.IndexedSlices):
return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
raise ValueError("`grad` not a Tensor or IndexedSlices.")