aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-10-05 08:03:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 08:06:48 -0700
commit388ed2929ea024adcfb76ea9ddd78a38a87470b7 (patch)
tree47bc22ef16b7630c8df784bec6e796a393bb47e5 /tensorflow/compiler
parent92c8a77ba480bf4aeddea412cc1d2988f6ad81cd (diff)
[TF:XLA] Move broadcasting code out of BroadcastTo op into a common helper library.
Change XlaBinaryOp::Broadcast to use the BroadcastTo lowering, since it produces fewer extraneous reshapes and transposes. Even if the reshapes and transposes would later optimize away, this yields more readable output and makes life easier for HLO rewrites that run early. Change in preparation for removing reshapes from SoftmaxCrossEntropyWithLogits. PiperOrigin-RevId: 215906847
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc63
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc57
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.h3
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD16
-rw-r--r--tensorflow/compiler/tf2xla/lib/broadcast.cc93
-rw-r--r--tensorflow/compiler/tf2xla/lib/broadcast.h32
8 files changed, 165 insertions, 110 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 95a010a119..224e5ea123 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -121,6 +121,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
+ "//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index a988d3c33e..47e517a657 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -64,7 +64,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
// }
static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto y_equals_0 = xla::Eq(y, zero);
auto zeros = xla::ZerosLike(x);
@@ -84,7 +84,7 @@ XLA_MAKE_BINARY(DivNoNan,
// }
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
if (DataTypeIsUnsigned(dtype)) {
return xla::Div(x, y);
}
@@ -105,7 +105,7 @@ XLA_MAKE_BINARY(FloorDiv,
static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto is_zero = xla::Eq(x, zero);
return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
@@ -114,7 +114,7 @@ XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto is_zero = xla::Eq(x, zero);
return xla::Select(is_zero, zero, xla::Div(x, y));
@@ -126,7 +126,7 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
auto trunc_mod = xla::Rem(x, y);
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 696c1c39be..9bb11fb67e 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "absl/algorithm/container.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
@@ -37,59 +32,9 @@ class BroadcastToOp : public XlaOpKernel {
TensorShape output_shape;
OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
- OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
- errors::InvalidArgument(
- "Input rank (", input_shape.dims(),
- ") must be less than or equal to the output rank (",
- output_shape.dims(), ")"));
-
- auto input_dims = input_shape.dim_sizes();
- auto output_dims = output_shape.dim_sizes();
-
- // Broadcasting is done right-to-left on right-aligned dimensions; reverse
- // the two vectors so elements to be broadcast are aligned.
- absl::c_reverse(input_dims);
- absl::c_reverse(output_dims);
-
- std::vector<int64> broadcast_dims;
- std::vector<int64> broadcast_shape;
- for (int i = 0; i < output_shape.dims(); ++i) {
- if (i < input_shape.dims()) {
- OP_REQUIRES(
- context,
- (output_dims[i] == 0 && input_dims[i] == 0) ||
- (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
- errors::InvalidArgument("invalid shape to broadcast from ",
- input_shape.DebugString(), " to ",
- output_shape.DebugString()));
-
- broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i]) {
- broadcast_shape.push_back(output_dims[i]);
- } else if (output_dims[i] != input_dims[i]) {
- // Add dimensions [I, O/I], which we will later flatten to just
- // [O]. We must do this in two phases since XLA broadcasting does not
- // support tiling.
- broadcast_shape.push_back(input_dims[i]);
- broadcast_shape.push_back(output_dims[i] / input_dims[i]);
- }
- } else {
- broadcast_shape.push_back(output_dims[i]);
- }
- }
- absl::c_reverse(broadcast_dims);
- int broadcast_shape_size = broadcast_shape.size();
- for (int64& broadcast_dim : broadcast_dims) {
- broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
- }
- absl::c_reverse(broadcast_shape);
- xla::XlaOp output = xla::Reshape(
- xla::BroadcastInDim(context->Input(0),
- xla::ShapeUtil::MakeShape(
- context->input_xla_type(0), broadcast_shape),
- broadcast_dims),
- output_shape.dim_sizes());
- context->SetOutput(0, output);
+ auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
+ OP_REQUIRES_OK(context, output.status());
+ context->SetOutput(0, output.ValueOrDie());
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index ef1015552d..234f7b4a01 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// compute valid broadcast shapes, but rely below on XLA to
// automatically perform the broadcast assuming its valid shapes are
// a superset of TensorFlow's valid shapes.
- BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+ BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
+ /*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
lhs_shape.DebugString(), " vs. ",
@@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper) {
- // Manually construct the broadcasting since MapN does not do
- // automatic broadcasting. The bcast helper ensures that
- // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
- // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
- // the same shape, so can be operated on by MapN.
-
- // First reshape the inputs, which should be a metadata-only
- // operation since we are flattening the dimensions in order.
- auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
-
- // Next broadcast the necessary input dimensions. We rely on the
- // XLA optimizer to be smart about the fact that we are asking
- // it to broadcast size 1 on some of these dimensions, to avoid
- // adding complexity to this code.
- auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
- int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
- int rhs_size = broadcast_helper.y_bcast().size();
-
- // Now reshape them to the correct output shape. After the
- // broadcast each side is twice as wide as it should be, since the
- // broadcast dimensions were prepended to the shape. Reshape
- // flattening each original dimension with the prepended broadcast
- // dimension. E.g. if we started out with lhs_shaped with shape
- // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
- // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
- std::vector<int64> lhs_reorder;
- for (int i = 0; i < lhs_size; ++i) {
- lhs_reorder.push_back(i);
- lhs_reorder.push_back(i + lhs_size);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
+ auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
+ if (!lhs_output.ok()) {
+ xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
+ return {error, error};
}
- auto lhs_output =
- xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
- std::vector<int64> rhs_reorder;
- for (int i = 0; i < rhs_size; ++i) {
- rhs_reorder.push_back(i);
- rhs_reorder.push_back(i + rhs_size);
+ auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
+ if (!rhs_output.ok()) {
+ xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
+ return {error, error};
}
- auto rhs_output =
- xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
-
- return {lhs_output, rhs_output};
+ return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 6653944a91..516ead4bfe 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel {
// 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
// shape.
static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper);
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 8597e7f139..1ce3930fd1 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -32,6 +32,22 @@ cc_library(
)
cc_library(
+ name = "broadcast",
+ srcs = ["broadcast.cc"],
+ hdrs = ["broadcast.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc
new file mode 100644
index 0000000000..3e402ef855
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace tensorflow {
+
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims) {
+ xla::XlaBuilder* builder = input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+ absl::Span<int64 const> input_dims =
+ xla::AsInt64Slice(input_shape.dimensions());
+
+ if (input_dims == output_dims) {
+ return input;
+ }
+
+ if (input_dims.size() > output_dims.size()) {
+ return errors::InvalidArgument(
+ "Input shape (", xla::ShapeUtil::HumanString(input_shape),
+ ") must have rank less than or equal to the output shape [",
+ absl::StrJoin(output_dims, ","), "]");
+ }
+
+ std::vector<int64> broadcast_dims;
+ std::vector<int64> broadcast_shape;
+ auto input_it = input_dims.rbegin();
+ for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend();
+ ++output_it) {
+ if (input_it != input_dims.rend()) {
+ if (!(*output_it == 0 && *input_it == 0) &&
+ !(*input_it != 0 && *output_it % *input_it == 0)) {
+ return errors::InvalidArgument("Invalid shape broadcast from ",
+ xla::ShapeUtil::HumanString(input_shape),
+ " to [", absl::StrJoin(output_dims, ","),
+ "]");
+ }
+
+ broadcast_dims.push_back(broadcast_shape.size());
+ if (*output_it == *input_it) {
+ broadcast_shape.push_back(*output_it);
+ } else if (*output_it != *input_it) {
+ // Add dimensions [I, O/I], which we will later flatten to just
+ // [O]. We must do this in two phases since XLA broadcasting does not
+ // support tiling.
+ broadcast_shape.push_back(*input_it);
+ broadcast_shape.push_back(*output_it / *input_it);
+ }
+ ++input_it;
+ } else {
+ broadcast_shape.push_back(*output_it);
+ }
+ }
+ TF_RET_CHECK(input_it == input_dims.rend());
+
+ absl::c_reverse(broadcast_dims);
+ int broadcast_shape_size = broadcast_shape.size();
+ for (int64& broadcast_dim : broadcast_dims) {
+ broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+ }
+ absl::c_reverse(broadcast_shape);
+ xla::XlaOp output = xla::BroadcastInDim(
+ input,
+ xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape),
+ broadcast_dims);
+ if (broadcast_shape != output_dims) {
+ output = xla::Reshape(output, output_dims);
+ }
+ return output;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h
new file mode 100644
index 0000000000..591e696f06
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace tensorflow {
+
+// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting
+// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling.
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_