aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-09-26 14:59:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 15:08:52 -0700
commitdc90d6c486f2ec1741766b0989e6f6e842d94437 (patch)
tree6e81a94617c78722b4b05b325d302340cf1c4cb2 /tensorflow/compiler
parent82af048bc8c3c044c98a27b1c4c27bb62d4e4a14 (diff)
[TF:XLA] Fix XLA lowering of TF BroadcastTo operator.
PiperOrigin-RevId: 214675055
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc5
2 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index e219cf3d88..1b39d53dc0 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1445,6 +1445,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([4, 0], dtype=np.int32),
expected=np.zeros([4, 0], dtype=dtype))
+ x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array((3, 7, 8, 9), dtype=np.int32),
+ expected=np.tile(x, (1, 7, 8, 9)))
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 4bd7c74dca..696c1c39be 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -64,10 +64,9 @@ class BroadcastToOp : public XlaOpKernel {
output_shape.DebugString()));
broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
+ if (output_dims[i] == input_dims[i]) {
broadcast_shape.push_back(output_dims[i]);
- }
- if (output_dims[i] != input_dims[i]) {
+ } else if (output_dims[i] != input_dims[i]) {
// Add dimensions [I, O/I], which we will later flatten to just
// [O]. We must do this in two phases since XLA broadcasting does not
// support tiling.