-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD                 1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/binary_ops.cc        10
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc   63
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/cwise_ops.cc         57
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/cwise_ops.h           3
-rw-r--r--  tensorflow/compiler/tf2xla/lib/BUILD                    16
-rw-r--r--  tensorflow/compiler/tf2xla/lib/broadcast.cc             93
-rw-r--r--  tensorflow/compiler/tf2xla/lib/broadcast.h              32
8 files changed, 165 insertions(+), 110 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 95a010a119..224e5ea123 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -121,6 +121,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
+ "//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index a988d3c33e..47e517a657 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -64,7 +64,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
// }
static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto y_equals_0 = xla::Eq(y, zero);
auto zeros = xla::ZerosLike(x);
@@ -84,7 +84,7 @@ XLA_MAKE_BINARY(DivNoNan,
// }
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
if (DataTypeIsUnsigned(dtype)) {
return xla::Div(x, y);
}
@@ -105,7 +105,7 @@ XLA_MAKE_BINARY(FloorDiv,
static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto is_zero = xla::Eq(x, zero);
return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
@@ -114,7 +114,7 @@ XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto is_zero = xla::Eq(x, zero);
return xla::Select(is_zero, zero, xla::Div(x, y));
@@ -126,7 +126,7 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
auto trunc_mod = xla::Rem(x, y);
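[Editor's note] The helpers above all change in the same mechanical way: XlaBinaryOp::Broadcast no longer takes the XlaBuilder, since an xla::XlaOp already carries its builder. A minimal sketch of the new call pattern, using a hypothetical SafeDivImpl helper modeled on DivNoNanImpl above:

  // Sketch only: follows the pattern of the *Impl helpers in this file.
  // The builder is no longer passed to Broadcast; it is recovered from the
  // operands when needed.
  static xla::XlaOp SafeDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
    // Broadcast both operands to the common output shape first.
    std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
    auto zero = XlaHelpers::Zero(b, dtype);
    auto y_is_zero = xla::Eq(y, zero);
    // Replace divisions by zero with zero, like DivNoNan above.
    return xla::Select(y_is_zero, xla::ZerosLike(x), xla::Div(x, y));
  }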
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 696c1c39be..9bb11fb67e 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "absl/algorithm/container.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
@@ -37,59 +32,9 @@ class BroadcastToOp : public XlaOpKernel {
TensorShape output_shape;
OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
- OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
- errors::InvalidArgument(
- "Input rank (", input_shape.dims(),
- ") must be less than or equal to the output rank (",
- output_shape.dims(), ")"));
-
- auto input_dims = input_shape.dim_sizes();
- auto output_dims = output_shape.dim_sizes();
-
- // Broadcasting is done right-to-left on right-aligned dimensions; reverse
- // the two vectors so elements to be broadcast are aligned.
- absl::c_reverse(input_dims);
- absl::c_reverse(output_dims);
-
- std::vector<int64> broadcast_dims;
- std::vector<int64> broadcast_shape;
- for (int i = 0; i < output_shape.dims(); ++i) {
- if (i < input_shape.dims()) {
- OP_REQUIRES(
- context,
- (output_dims[i] == 0 && input_dims[i] == 0) ||
- (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
- errors::InvalidArgument("invalid shape to broadcast from ",
- input_shape.DebugString(), " to ",
- output_shape.DebugString()));
-
- broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i]) {
- broadcast_shape.push_back(output_dims[i]);
- } else if (output_dims[i] != input_dims[i]) {
- // Add dimensions [I, O/I], which we will later flatten to just
- // [O]. We must do this in two phases since XLA broadcasting does not
- // support tiling.
- broadcast_shape.push_back(input_dims[i]);
- broadcast_shape.push_back(output_dims[i] / input_dims[i]);
- }
- } else {
- broadcast_shape.push_back(output_dims[i]);
- }
- }
- absl::c_reverse(broadcast_dims);
- int broadcast_shape_size = broadcast_shape.size();
- for (int64& broadcast_dim : broadcast_dims) {
- broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
- }
- absl::c_reverse(broadcast_shape);
- xla::XlaOp output = xla::Reshape(
- xla::BroadcastInDim(context->Input(0),
- xla::ShapeUtil::MakeShape(
- context->input_xla_type(0), broadcast_shape),
- broadcast_dims),
- output_shape.dim_sizes());
- context->SetOutput(0, output);
+ auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
+ OP_REQUIRES_OK(context, output.status());
+ context->SetOutput(0, output.ValueOrDie());
}
};
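[Editor's note] With the shape arithmetic moved into tf2xla/lib, the kernel's Compile method reduces to a single BroadcastTo call. A condensed sketch of the resulting method, assuming the class boilerplate shown in the surrounding context:

  // Condensed sketch of BroadcastToOp::Compile after this change.
  void Compile(XlaOpKernelContext* context) override {
    TensorShape output_shape;
    OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
    // BroadcastTo returns a StatusOr; surface any shape error through the
    // op kernel context, then unwrap the value.
    auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
    OP_REQUIRES_OK(context, output.status());
    context->SetOutput(0, output.ValueOrDie());
  }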
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index ef1015552d..234f7b4a01 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// compute valid broadcast shapes, but rely below on XLA to
// automatically perform the broadcast assuming its valid shapes are
// a superset of TensorFlow's valid shapes.
- BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+ BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
+ /*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
lhs_shape.DebugString(), " vs. ",
@@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper) {
- // Manually construct the broadcasting since MapN does not do
- // automatic broadcasting. The bcast helper ensures that
- // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
- // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
- // the same shape, so can be operated on by MapN.
-
- // First reshape the inputs, which should be a metadata-only
- // operation since we are flattening the dimensions in order.
- auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
-
- // Next broadcast the necessary input dimensions. We rely on the
- // XLA optimizer to be smart about the fact that we are asking
- // it to broadcast size 1 on some of these dimensions, to avoid
- // adding complexity to this code.
- auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
- int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
- int rhs_size = broadcast_helper.y_bcast().size();
-
- // Now reshape them to the correct output shape. After the
- // broadcast each side is twice as wide as it should be, since the
- // broadcast dimensions were prepended to the shape. Reshape
- // flattening each original dimension with the prepended broadcast
- // dimension. E.g. if we started out with lhs_shaped with shape
- // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
- // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
- std::vector<int64> lhs_reorder;
- for (int i = 0; i < lhs_size; ++i) {
- lhs_reorder.push_back(i);
- lhs_reorder.push_back(i + lhs_size);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
+ auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
+ if (!lhs_output.ok()) {
+ xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
+ return {error, error};
}
- auto lhs_output =
- xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
- std::vector<int64> rhs_reorder;
- for (int i = 0; i < rhs_size; ++i) {
- rhs_reorder.push_back(i);
- rhs_reorder.push_back(i + rhs_size);
+ auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
+ if (!rhs_output.ok()) {
+ xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
+ return {error, error};
}
- auto rhs_output =
- xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
-
- return {lhs_output, rhs_output};
+ return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
} // namespace tensorflow
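[Editor's note] XlaBinaryOp::Broadcast cannot return a Status, so a failed BroadcastTo is converted into an error op via XlaBuilder::ReportError and surfaces when the computation is built. The BCast is also constructed with fewer_dims_optimization disabled, presumably so that bcast.output_shape() keeps the op's full-rank output shape rather than a dimension-coalesced one. A sketch of the error-propagation pattern in isolation, using a hypothetical BroadcastOrError wrapper:

  // Sketch only: the pattern used by Broadcast above, pulled out into a
  // hypothetical helper. A StatusOr failure becomes an "error" XlaOp, so a
  // caller that cannot return a Status still reports the problem.
  static xla::XlaOp BroadcastOrError(xla::XlaOp v, const BCast& bcast) {
    auto result = BroadcastTo(v, bcast.output_shape());
    if (!result.ok()) {
      return v.builder()->ReportError(result.status());
    }
    return result.ValueOrDie();
  }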
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 6653944a91..516ead4bfe 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel {
// 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
// shape.
static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper);
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 8597e7f139..1ce3930fd1 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -32,6 +32,22 @@ cc_library(
)
cc_library(
+ name = "broadcast",
+ srcs = ["broadcast.cc"],
+ hdrs = ["broadcast.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc
new file mode 100644
index 0000000000..3e402ef855
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace tensorflow {
+
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims) {
+ xla::XlaBuilder* builder = input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+ absl::Span<int64 const> input_dims =
+ xla::AsInt64Slice(input_shape.dimensions());
+
+ if (input_dims == output_dims) {
+ return input;
+ }
+
+ if (input_dims.size() > output_dims.size()) {
+ return errors::InvalidArgument(
+ "Input shape (", xla::ShapeUtil::HumanString(input_shape),
+ ") must have rank less than or equal to the output shape [",
+ absl::StrJoin(output_dims, ","), "]");
+ }
+
+ std::vector<int64> broadcast_dims;
+ std::vector<int64> broadcast_shape;
+ auto input_it = input_dims.rbegin();
+ for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend();
+ ++output_it) {
+ if (input_it != input_dims.rend()) {
+ if (!(*output_it == 0 && *input_it == 0) &&
+ !(*input_it != 0 && *output_it % *input_it == 0)) {
+ return errors::InvalidArgument("Invalid shape broadcast from ",
+ xla::ShapeUtil::HumanString(input_shape),
+ " to [", absl::StrJoin(output_dims, ","),
+ "]");
+ }
+
+ broadcast_dims.push_back(broadcast_shape.size());
+ if (*output_it == *input_it) {
+ broadcast_shape.push_back(*output_it);
+ } else if (*output_it != *input_it) {
+ // Add dimensions [I, O/I], which we will later flatten to just
+ // [O]. We must do this in two phases since XLA broadcasting does not
+ // support tiling.
+ broadcast_shape.push_back(*input_it);
+ broadcast_shape.push_back(*output_it / *input_it);
+ }
+ ++input_it;
+ } else {
+ broadcast_shape.push_back(*output_it);
+ }
+ }
+ TF_RET_CHECK(input_it == input_dims.rend());
+
+ absl::c_reverse(broadcast_dims);
+ int broadcast_shape_size = broadcast_shape.size();
+ for (int64& broadcast_dim : broadcast_dims) {
+ broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+ }
+ absl::c_reverse(broadcast_shape);
+ xla::XlaOp output = xla::BroadcastInDim(
+ input,
+ xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape),
+ broadcast_dims);
+ if (broadcast_shape != output_dims) {
+ output = xla::Reshape(output, output_dims);
+ }
+ return output;
+}
+
+} // namespace tensorflow
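[Editor's note] A worked trace of the two-phase scheme above for a case that needs tiling; this is an illustrative sketch (assuming the tensorflow namespace and an F32 parameter), not part of the change. Broadcasting a [2, 3] input to [4, 6]:

  // Right-to-left pass: dim 3 -> 6 pushes [3, 2]; dim 2 -> 4 pushes [2, 2].
  // After the final reverse:
  //   broadcast_shape = [2, 2, 2, 3]
  //   broadcast_dims  = [1, 3]   (the input's dims land at positions 1 and 3)
  // BroadcastInDim yields a [2, 2, 2, 3] array holding the input in dims 1
  // and 3; the trailing Reshape collapses [2, 2] -> 4 and [2, 3] -> 6, i.e.
  // the input tiled twice along each axis.
  xla::XlaBuilder b("broadcast_to_example");
  xla::XlaOp x = xla::Parameter(
      &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "x");
  auto tiled = tensorflow::BroadcastTo(x, {4, 6});  // StatusOr<xla::XlaOp>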
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h
new file mode 100644
index 0000000000..591e696f06
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace tensorflow {
+
+// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting
+// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling.
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
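[Editor's note] Usage sketch for the new helper, contrasting ordinary size-1 broadcasting with the tiling extension described in the comment above; parameter names are illustrative.

  xla::XlaBuilder b("broadcast_to_usage");
  xla::XlaOp ones = xla::Parameter(
      &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {1, 3}), "ones");
  xla::XlaOp pair = xla::Parameter(
      &b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "pair");
  // Ordinary size-1 broadcasting: [1, 3] -> [4, 3].
  auto a = tensorflow::BroadcastTo(ones, {4, 3});
  // Tiling extension: a dimension of size 2 broadcast to 4 = 2 * 2.
  auto t = tensorflow::BroadcastTo(pair, {4, 3});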