about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-10 10:24:43 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-10 10:24:57 -0700
commit 9b674c1aa986eab6a169af81791719bffc8a505d (patch)
tree db62eb6aabf635561941a56db3aea780fe27944d
parent 7d3884bb87dc02c4548f55749f3d6db1b8364ddc (diff)
parent dce54446805ca6be5b4ecd7d5226f2a80a0e9aa1 (diff)
Merge pull request #22083 from facaiy:ENH/add_gradient_for_broadcast_to
PiperOrigin-RevId: 212288455
-rw-r--r-- tensorflow/python/kernel_tests/broadcast_to_ops_test.py | 44
-rw-r--r-- tensorflow/python/ops/array_grad.py | 19
2 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index 6a1bd958ba..bd2339f31d 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -21,8 +21,10 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import test as test_lib
@@ -81,5 +83,47 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
# check shape inference when shape input is constant
self.assertAllEqual(shape, v_np.shape)
+ def testGradientForScalar(self):
+ # TODO(alextp): There is a bug with broadcast_to on GPU from scalars,
+ # hence we make this test cpu-only.
+ with ops.device("cpu:0"):
+ x = constant_op.constant(1, dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 4, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithSameRank(self):
+ x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)),
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 5, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithIncreasingRank(self):
+ x = constant_op.constant([[1], [2]],
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 2, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithBroadcastAllDimensions(self):
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 4, 6])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 6ae869b89e..ade86e85bf 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -805,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad):
indices = op.inputs[1]
updates_grad = array_ops.gather_nd(grad, indices)
return [grad, None, updates_grad]
+
+
+@ops.RegisterGradient("BroadcastTo")
+def _BroadcastToGrad(op, grad):
+ input_value = op.inputs[0]
+ broadcast_shape = op.inputs[1]
+ # Assign ids for each position in input_value.
+ input_value_shape = array_ops.shape(input_value)
+ input_value_size = array_ops.size(input_value)
+ ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
+ broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
+ # Group by ids and sum its gradients.
+ grad_flatten = array_ops.reshape(grad, [-1])
+ broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
+ updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
+ broadcast_ids_flatten,
+ input_value_size)
+ updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
+ return [updates_grad, None]