aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-10-05 08:03:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 08:06:48 -0700
commit388ed2929ea024adcfb76ea9ddd78a38a87470b7 (patch)
tree47bc22ef16b7630c8df784bec6e796a393bb47e5 /tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
parent92c8a77ba480bf4aeddea412cc1d2988f6ad81cd (diff)
[TF:XLA] Move broadcasting code out of BroadcastTo op into a common helper library.
Change XlaBinaryOp::Broadcast to use the BroadcastTo lowering, since it produces fewer extraneous reshapes and transposes. Even if the reshapes and transposes would later optimize away, this yields more readable output and makes life easier for HLO rewrites that run early. Change in preparation for removing reshapes from SoftmaxCrossEntropyWithLogits. PiperOrigin-RevId: 215906847
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/cwise_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc57
1 files changed, 13 insertions, 44 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index ef1015552d..234f7b4a01 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// compute valid broadcast shapes, but rely below on XLA to
// automatically perform the broadcast assuming its valid shapes are
// a superset of TensorFlow's valid shapes.
- BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+ BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
+ /*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
lhs_shape.DebugString(), " vs. ",
@@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper) {
- // Manually construct the broadcasting since MapN does not do
- // automatic broadcasting. The bcast helper ensures that
- // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
- // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
- // the same shape, so can be operated on by MapN.
-
- // First reshape the inputs, which should be a metadata-only
- // operation since we are flattening the dimensions in order.
- auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
-
- // Next broadcast the necessary input dimensions. We rely on the
- // XLA optimizer to be smart about the fact that we are asking
- // it to broadcast size 1 on some of these dimensions, to avoid
- // adding complexity to this code.
- auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
- int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
- int rhs_size = broadcast_helper.y_bcast().size();
-
- // Now reshape them to the correct output shape. After the
- // broadcast each side is twice as wide as it should be, since the
- // broadcast dimensions were prepended to the shape. Reshape
- // flattening each original dimension with the prepended broadcast
- // dimension. E.g. if we started out with lhs_shaped with shape
- // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
- // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
- std::vector<int64> lhs_reorder;
- for (int i = 0; i < lhs_size; ++i) {
- lhs_reorder.push_back(i);
- lhs_reorder.push_back(i + lhs_size);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
+ auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
+ if (!lhs_output.ok()) {
+ xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
+ return {error, error};
}
- auto lhs_output =
- xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
- std::vector<int64> rhs_reorder;
- for (int i = 0; i < rhs_size; ++i) {
- rhs_reorder.push_back(i);
- rhs_reorder.push_back(i + rhs_size);
+ auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
+ if (!rhs_output.ok()) {
+ xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
+ return {error, error};
}
- auto rhs_output =
- xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
-
- return {lhs_output, rhs_output};
+ return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
} // namespace tensorflow