diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc | 63 |
1 files changed, 4 insertions, 59 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 696c1c39be..9bb11fb67e 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,16 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { @@ -37,59 +32,9 @@ class BroadcastToOp : public XlaOpKernel { TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); - OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), - errors::InvalidArgument( - "Input rank (", input_shape.dims(), - ") must be less than or equal to the output rank (", - output_shape.dims(), ")")); - - auto input_dims = input_shape.dim_sizes(); - auto output_dims = output_shape.dim_sizes(); - - // Broadcasting is done right-to-left on right-aligned dimensions; reverse - // the two vectors so elements to be broadcast are aligned. - absl::c_reverse(input_dims); - absl::c_reverse(output_dims); - - std::vector<int64> broadcast_dims; - std::vector<int64> broadcast_shape; - for (int i = 0; i < output_shape.dims(); ++i) { - if (i < input_shape.dims()) { - OP_REQUIRES( - context, - (output_dims[i] == 0 && input_dims[i] == 0) || - (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), - errors::InvalidArgument("invalid shape to broadcast from ", - input_shape.DebugString(), " to ", - output_shape.DebugString())); - - broadcast_dims.push_back(broadcast_shape.size()); - if (output_dims[i] == input_dims[i]) { - broadcast_shape.push_back(output_dims[i]); - } else if (output_dims[i] != input_dims[i]) { - // Add dimensions [I, O/I], which we will later flatten to just - // [O]. We must do this in two phases since XLA broadcasting does not - // support tiling. - broadcast_shape.push_back(input_dims[i]); - broadcast_shape.push_back(output_dims[i] / input_dims[i]); - } - } else { - broadcast_shape.push_back(output_dims[i]); - } - } - absl::c_reverse(broadcast_dims); - int broadcast_shape_size = broadcast_shape.size(); - for (int64& broadcast_dim : broadcast_dims) { - broadcast_dim = broadcast_shape_size - broadcast_dim - 1; - } - absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::Reshape( - xla::BroadcastInDim(context->Input(0), - xla::ShapeUtil::MakeShape( - context->input_xla_type(0), broadcast_shape), - broadcast_dims), - output_shape.dim_sizes()); - context->SetOutput(0, output); + auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output.status()); + context->SetOutput(0, output.ValueOrDie()); } }; |