aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-16 18:06:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-16 18:10:53 -0700
commitecaa2eee832bd5b4286377f0f853c961c6ac2ab2 (patch)
tree892c9cc2b6cb8901dcb00ed1ea4c51b31ba6af8f /tensorflow/python/ops/math_grad.py
parent5c5dc8d5641b7c915f681109921dfb2b3e082a9b (diff)
math_grad: Fast path for when broadcasting is not needed.
PiperOrigin-RevId: 172407754
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 3754e039ed..38fe093ba7 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -700,10 +700,26 @@ def _AddNGrad(op, grad):
return [grad] * len(op.inputs)
+def _ShapesFullySpecifiedAndEqual(x, y, grad):
+ # pylint: disable=protected-access
+ x_shape = x._shape_tuple()
+ y_shape = y._shape_tuple()
+ grad_shape = grad._shape_tuple()
+ # pylint: enable=protected-access
+ return (x_shape == y_shape and
+ x_shape == grad_shape and
+ x_shape is not None and
+ None not in x_shape)
+
+
@ops.RegisterGradient("Add")
def _AddGrad(op, grad):
+ """Gradient for Add."""
x = op.inputs[0]
y = op.inputs[1]
+ if (isinstance(grad, ops.Tensor) and
+ _ShapesFullySpecifiedAndEqual(x, y, grad)):
+ return grad, grad
sx = array_ops.shape(x)
sy = array_ops.shape(y)
# pylint: disable=protected-access
@@ -731,10 +747,14 @@ def _MulGrad(op, grad):
"""The gradient of scalar multiplication."""
x = op.inputs[0]
y = op.inputs[1]
+ # pylint: disable=protected-access
+ if (isinstance(grad, ops.Tensor) and
+ _ShapesFullySpecifiedAndEqual(x, y, grad) and
+ grad.dtype in (dtypes.int32, dtypes.float32)):
+ return gen_math_ops._mul(grad, y), gen_math_ops._mul(grad, x)
assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
sx = array_ops.shape(x)
sy = array_ops.shape(y)
- # pylint: disable=protected-access
rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
# pylint: enable=protected-access
x = math_ops.conj(x)