aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/broadcast_to_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/broadcast_to_ops_test.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index 6a1bd958ba..bd2339f31d 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -21,8 +21,10 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import test as test_lib
@@ -81,5 +83,47 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
# check shape inference when shape input is constant
self.assertAllEqual(shape, v_np.shape)
+ def testGradientForScalar(self):
+ # TODO(alextp): There is a bug with broadcast_to on GPU from scalars,
+ # hence we make this test cpu-only.
+ with ops.device("cpu:0"):
+ x = constant_op.constant(1, dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 4, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithSameRank(self):
+ x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)),
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 5, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithIncreasingRank(self):
+ x = constant_op.constant([[1], [2]],
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 2, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithBroadcastAllDimensions(self):
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 4, 6])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+
if __name__ == "__main__":
test_lib.main()