aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-09-26 14:59:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 15:08:52 -0700
commitdc90d6c486f2ec1741766b0989e6f6e842d94437 (patch)
tree6e81a94617c78722b4b05b325d302340cf1c4cb2 /tensorflow/compiler/tf2xla
parent82af048bc8c3c044c98a27b1c4c27bb62d4e4a14 (diff)
[TF:XLA] Fix XLA lowering of TF BroadcastTo operator.
PiperOrigin-RevId: 214675055
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 4bd7c74dca..696c1c39be 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -64,10 +64,9 @@ class BroadcastToOp : public XlaOpKernel {
output_shape.DebugString()));
broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
+ if (output_dims[i] == input_dims[i]) {
broadcast_shape.push_back(output_dims[i]);
- }
- if (output_dims[i] != input_dims[i]) {
+ } else if (output_dims[i] != input_dims[i]) {
// Add dimensions [I, O/I], which we will later flatten to just
// [O]. We must do this in two phases since XLA broadcasting does not
// support tiling.