diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-09-24 12:12:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 12:17:12 -0700 |
commit | 7919d64414ed47d217b8fc508d1be56b2a531a3c (patch) | |
tree | 9d6fd19ca3932d89743ab4643a56529afc974303 /tensorflow/python/eager | |
parent | f361fb8e4b4a9838e60a11ab45391c308bcb90da (diff) |
Wrap forward and backward pass in a defun for L2HMC.
Also a small bugfix to handle unknown shapes in backprop._num_elements.
Before:
entry {
name: "L2hmcBenchmark.eager_train_cpu_defun"
iters: 10
wall_time: 0.594115018845
extras {
key: "examples_per_sec"
value {
double_value: 336.635152548
}
}
}
After:
entry {
name: "L2hmcBenchmark.eager_train_cpu_defun"
iters: 10
wall_time: 0.322251081467
extras {
key: "examples_per_sec"
value {
double_value: 620.634069216
}
}
}
PiperOrigin-RevId: 214308142
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/backprop.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index d95e0fe721..78f3198011 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -564,7 +564,10 @@ def _aggregate_grads(gradients): def _num_elements(grad): """The number of elements in the `grad` tensor.""" if isinstance(grad, ops.Tensor): - return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access + shape_tuple = grad._shape_tuple() # pylint: disable=protected-access + if shape_tuple is None or None in shape_tuple: + return 0 + return functools.reduce(operator.mul, shape_tuple, 1) if isinstance(grad, ops.IndexedSlices): return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access raise ValueError("`grad` not a Tensor or IndexedSlices.") |