diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/broadcast_to_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/broadcast_to_ops_test.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 6a1bd958ba..bd2339f31d 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -21,8 +21,10 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.platform import test as test_lib @@ -81,5 +83,47 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) + def testGradientForScalar(self): + # TODO(alextp): There is a bug with broadcast_to on GPU from scalars, + # hence we make this test cpu-only. + with ops.device("cpu:0"): + x = constant_op.constant(1, dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithSameRank(self): + x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)), + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 5, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithIncreasingRank(self): + x = constant_op.constant([[1], [2]], + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 2, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithBroadcastAllDimensions(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 4, 6]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + if __name__ == "__main__": test_lib.main() |