diff options
author | 2017-10-16 18:06:20 -0700 | |
---|---|---|
committer | 2017-10-16 18:10:53 -0700 | |
commit | ecaa2eee832bd5b4286377f0f853c961c6ac2ab2 (patch) | |
tree | 892c9cc2b6cb8901dcb00ed1ea4c51b31ba6af8f /tensorflow/python/ops/math_grad.py | |
parent | 5c5dc8d5641b7c915f681109921dfb2b3e082a9b (diff) |
math_grad: Fast path for when broadcasting is not needed.
PiperOrigin-RevId: 172407754
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r-- | tensorflow/python/ops/math_grad.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 3754e039ed..38fe093ba7 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -700,10 +700,26 @@ def _AddNGrad(op, grad): return [grad] * len(op.inputs) +def _ShapesFullySpecifiedAndEqual(x, y, grad): + # pylint: disable=protected-access + x_shape = x._shape_tuple() + y_shape = y._shape_tuple() + grad_shape = grad._shape_tuple() + # pylint: enable=protected-access + return (x_shape == y_shape and + x_shape == grad_shape and + x_shape is not None and + None not in x_shape) + + @ops.RegisterGradient("Add") def _AddGrad(op, grad): + """Gradient for Add.""" x = op.inputs[0] y = op.inputs[1] + if (isinstance(grad, ops.Tensor) and + _ShapesFullySpecifiedAndEqual(x, y, grad)): + return grad, grad sx = array_ops.shape(x) sy = array_ops.shape(y) # pylint: disable=protected-access @@ -731,10 +747,14 @@ def _MulGrad(op, grad): """The gradient of scalar multiplication.""" x = op.inputs[0] y = op.inputs[1] + # pylint: disable=protected-access + if (isinstance(grad, ops.Tensor) and + _ShapesFullySpecifiedAndEqual(x, y, grad) and + grad.dtype in (dtypes.int32, dtypes.float32)): + return gen_math_ops._mul(grad, y), gen_math_ops._mul(grad, x) assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) sx = array_ops.shape(x) sy = array_ops.shape(y) - # pylint: disable=protected-access rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) # pylint: enable=protected-access x = math_ops.conj(x) |