From 32e96b1dc588cccf4e008259f831c4e50d948dc7 Mon Sep 17 00:00:00 2001 From: "Yan Facai (颜发才)" Date: Wed, 5 Sep 2018 15:46:09 +0800 Subject: ENH: add gradient for broadcast_to --- .../python/kernel_tests/broadcast_to_ops_test.py | 20 ++++++++++++++++++++ tensorflow/python/ops/array_grad.py | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 6a1bd958ba..282a619094 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.platform import test as test_lib @@ -81,5 +82,24 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) + def testGradient(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientForScalar(self): + x = constant_op.constant(1, dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 6ae869b89e..ade86e85bf 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -805,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad): indices = op.inputs[1] updates_grad = array_ops.gather_nd(grad, indices) return [grad, None, updates_grad] + + +@ops.RegisterGradient("BroadcastTo") +def _BroadcastToGrad(op, grad): + input_value = op.inputs[0] + broadcast_shape = op.inputs[1] + # Assign ids for each position in input_value. + input_value_shape = array_ops.shape(input_value) + input_value_size = array_ops.size(input_value) + ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape) + broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape) + # Group by ids and sum its gradients. + grad_flatten = array_ops.reshape(grad, [-1]) + broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1]) + updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten, + broadcast_ids_flatten, + input_value_size) + updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape) + return [updates_grad, None] -- cgit v1.2.3 From 8859ee06cc0cba03d05ce9677b05ff1993c34b03 Mon Sep 17 00:00:00 2001 From: "Yan Facai (颜发才)" Date: Thu, 6 Sep 2018 22:45:25 +0800 Subject: TST: add more test cases --- .../python/kernel_tests/broadcast_to_ops_test.py | 30 ++++++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 282a619094..8bcf27466c 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -82,8 +82,8 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) - def testGradient(self): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + def testGradientForScalar(self): + x = constant_op.constant(1, dtype=dtypes.float32) v = array_ops.broadcast_to(x, [2, 4, 3]) out = 2 * v with self.test_session(): @@ -91,9 +91,29 @@ class BroadcastToTest(test_util.TensorFlowTestCase): out, out.get_shape()) self.assertLess(err, 1e-4) - def testGradientForScalar(self): - x = constant_op.constant(1, dtype=dtypes.float32) - v = array_ops.broadcast_to(x, [2, 4, 3]) + def testGradientWithSameRank(self): + x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)), + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 5, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithIncreasingRank(self): + x = constant_op.constant([[1], [2]], + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 2, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithBroadcastAllDimensions(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 4, 6]) out = 2 * v with self.test_session(): err = gradient_checker.compute_gradient_error(x, x.get_shape(), -- cgit v1.2.3 From dce54446805ca6be5b4ecd7d5226f2a80a0e9aa1 Mon Sep 17 00:00:00 2001 From: "Yan Facai (颜发才)" Date: Fri, 7 Sep 2018 07:44:43 +0800 Subject: TST: make scalar test cpu-only --- tensorflow/python/kernel_tests/broadcast_to_ops_test.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 8bcf27466c..bd2339f31d 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker @@ -83,12 +84,15 @@ class BroadcastToTest(test_util.TensorFlowTestCase): self.assertAllEqual(shape, v_np.shape) def testGradientForScalar(self): - x = constant_op.constant(1, dtype=dtypes.float32) - v = array_ops.broadcast_to(x, [2, 4, 3]) - out = 2 * v - with self.test_session(): - err = gradient_checker.compute_gradient_error(x, x.get_shape(), - out, out.get_shape()) + # TODO(alextp): There is a bug with broadcast_to on GPU from scalars, + # hence we make this test cpu-only. + with ops.device("cpu:0"): + x = constant_op.constant(1, dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) self.assertLess(err, 1e-4) def testGradientWithSameRank(self): -- cgit v1.2.3