aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/cwise_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc57
1 files changed, 13 insertions, 44 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index ef1015552d..234f7b4a01 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// compute valid broadcast shapes, but rely below on XLA to
// automatically perform the broadcast assuming its valid shapes are
// a superset of TensorFlow's valid shapes.
- BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+ BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
+ /*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
lhs_shape.DebugString(), " vs. ",
@@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper) {
- // Manually construct the broadcasting since MapN does not do
- // automatic broadcasting. The bcast helper ensures that
- // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
- // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
- // the same shape, so can be operated on by MapN.
-
- // First reshape the inputs, which should be a metadata-only
- // operation since we are flattening the dimensions in order.
- auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
-
- // Next broadcast the necessary input dimensions. We rely on the
- // XLA optimizer to be smart about the fact that we are asking
- // it to broadcast size 1 on some of these dimensions, to avoid
- // adding complexity to this code.
- auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
- int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
- int rhs_size = broadcast_helper.y_bcast().size();
-
- // Now reshape them to the correct output shape. After the
- // broadcast each side is twice as wide as it should be, since the
- // broadcast dimensions were prepended to the shape. Reshape
- // flattening each original dimension with the prepended broadcast
- // dimension. E.g. if we started out with lhs_shaped with shape
- // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
- // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
- std::vector<int64> lhs_reorder;
- for (int i = 0; i < lhs_size; ++i) {
- lhs_reorder.push_back(i);
- lhs_reorder.push_back(i + lhs_size);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
+ auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
+ if (!lhs_output.ok()) {
+ xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
+ return {error, error};
}
- auto lhs_output =
- xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
- std::vector<int64> rhs_reorder;
- for (int i = 0; i < rhs_size; ++i) {
- rhs_reorder.push_back(i);
- rhs_reorder.push_back(i + rhs_size);
+ auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
+ if (!rhs_output.ok()) {
+ xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
+ return {error, error};
}
- auto rhs_output =
- xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
-
- return {lhs_output, rhs_output};
+ return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
} // namespace tensorflow