diff options
author | 2018-09-10 10:24:43 -0700 | |
---|---|---|
committer | 2018-09-10 10:24:57 -0700 | |
commit | 9b674c1aa986eab6a169af81791719bffc8a505d (patch) | |
tree | db62eb6aabf635561941a56db3aea780fe27944d | |
parent | 7d3884bb87dc02c4548f55749f3d6db1b8364ddc (diff) | |
parent | dce54446805ca6be5b4ecd7d5226f2a80a0e9aa1 (diff) |
Merge pull request #22083 from facaiy:ENH/add_gradient_for_broadcast_to
PiperOrigin-RevId: 212288455
-rw-r--r-- | tensorflow/python/kernel_tests/broadcast_to_ops_test.py | 44 | ||||
-rw-r--r-- | tensorflow/python/ops/array_grad.py | 19 |
2 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 6a1bd958ba..bd2339f31d 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -21,8 +21,10 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.platform import test as test_lib @@ -81,5 +83,47 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) + def testGradientForScalar(self): + # TODO(alextp): There is a bug with broadcast_to on GPU from scalars, + # hence we make this test cpu-only. + with ops.device("cpu:0"): + x = constant_op.constant(1, dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithSameRank(self): + x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)), + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 5, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithIncreasingRank(self): + x = constant_op.constant([[1], [2]], + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 2, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithBroadcastAllDimensions(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 4, 6]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 6ae869b89e..ade86e85bf 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -805,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad): indices = op.inputs[1] updates_grad = array_ops.gather_nd(grad, indices) return [grad, None, updates_grad] + + +@ops.RegisterGradient("BroadcastTo") +def _BroadcastToGrad(op, grad): + input_value = op.inputs[0] + broadcast_shape = op.inputs[1] + # Assign ids for each position in input_value. + input_value_shape = array_ops.shape(input_value) + input_value_size = array_ops.size(input_value) + ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape) + broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape) + # Group by ids and sum its gradients. + grad_flatten = array_ops.reshape(grad, [-1]) + broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1]) + updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten, + broadcast_ids_flatten, + input_value_size) + updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape) + return [updates_grad, None] |