aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/bcast_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/bcast_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/bcast_ops_test.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py
new file mode 100644
index 0000000000..c62a910496
--- /dev/null
+++ b/tensorflow/python/kernel_tests/bcast_ops_test.py
@@ -0,0 +1,76 @@
+"""Tests for tensorflow.kernels.bcast_ops."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.python.ops.gen_array_ops import _broadcast_gradient_args
+
+
+class BcastOpsTest(tf.test.TestCase):
+
+ def _GetGradientArgs(self, xs, ys):
+ with self.test_session() as sess:
+ return sess.run(_broadcast_gradient_args(xs, ys))
+
+ def testBasic(self):
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [1])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1, 2])
+
+ r0, r1 = self._GetGradientArgs([1], [2, 3, 5])
+ self.assertAllEqual(r0, [0, 1, 2])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1])
+
+ r0, r1 = self._GetGradientArgs([5], [2, 3, 5])
+ self.assertAllEqual(r0, [0, 1])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [3, 5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0])
+
+ r0, r1 = self._GetGradientArgs([3, 5], [2, 3, 5])
+ self.assertAllEqual(r0, [0])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [3, 1])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 2])
+
+ r0, r1 = self._GetGradientArgs([3, 1], [2, 3, 5])
+ self.assertAllEqual(r0, [0, 2])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 1, 5], [3, 1])
+ self.assertAllEqual(r0, [1])
+ self.assertAllEqual(r1, [0, 2])
+
+ r0, r1 = self._GetGradientArgs([3, 1], [2, 1, 5])
+ self.assertAllEqual(r0, [0, 2])
+ self.assertAllEqual(r1, [1])
+
+ def testZeroDims(self):
+ r0, r1 = self._GetGradientArgs([2, 0, 3, 0, 5], [3, 0, 5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1])
+
+ r0, r1 = self._GetGradientArgs([3, 0, 5], [2, 0, 3, 0, 5])
+ self.assertAllEqual(r0, [0, 1])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 0, 3, 0, 5], [3, 1, 5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1, 3])
+
+ r0, r1 = self._GetGradientArgs([3, 1, 5], [2, 0, 3, 0, 5])
+ self.assertAllEqual(r0, [0, 1, 3])
+ self.assertAllEqual(r1, [])
+
+
+if __name__ == "__main__":
+ tf.test.main()