aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-08-23 13:24:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 13:35:49 -0700
commit0a427ca13e52bc121f2d42f21c65e6f03a520a1a (patch)
treeaa82b97e3ced0d6294d5946d1cd7451bd67010e6 /tensorflow/compiler/tests
parentd7682bb16f575eb0c4cbb1622d8098c592fed2b7 (diff)
[TF:XLA] Implement BroadcastTo.
PiperOrigin-RevId: 209988299
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 4a281c37e4..ed4940f204 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1372,5 +1372,40 @@ class BinaryOpsTest(xla_test.XLATestCase):
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
dtype=dtype))
+ def testBroadcastTo(self):
+ for dtype in self.all_types:
+ x = np.random.randint(0, high=100, size=[2, 3])
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([2, 3], dtype=np.int32),
+ expected=x)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([6, 6], dtype=np.int32),
+ expected=np.tile(x, [3, 2]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 4, 3], dtype=np.int32),
+ expected=np.tile(x, [7, 2, 1]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 0, 3], dtype=np.int32),
+ expected=np.zeros([7, 0, 3], dtype=dtype))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 1, 2, 9], dtype=np.int32),
+ expected=np.tile(x, [7, 1, 1, 3]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ np.zeros([2, 0], dtype=dtype),
+ np.array([4, 0], dtype=np.int32),
+ expected=np.zeros([4, 0], dtype=dtype))
+
+
if __name__ == "__main__":
googletest.main()