about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-10-08 14:26:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 14:34:02 -0700
commit396a8a4105edd409d0821c4d5d0b920b315ffb72 (patch)
tree428350d427ffb29470e284077a2734b273b7cc4d /tensorflow/compiler/tf2xla
parentbc5635dc3ac78007caee88fabd81d23ad945b637 (diff)
Add custom call with layout constraints.
Add a variant of CustomCall which specifies arbitrary layout constraints on the operands and result. The existing non-layout-constrained CustomCall is changed to have no layout preference and can now be assigned arbitrary layouts by layout assignment.

PiperOrigin-RevId: 216249615
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc22
1 file changed, 16 insertions(+), 6 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 3d81ae9eb8..f210bfbd88 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -88,20 +88,30 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
- xla::Shape xla_shape =
- xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes());
+ // The argmax function expects row-major layout.
+ xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
+ xla::S64, output_shape.dim_sizes());
+ std::vector<xla::Shape> arg_shapes;
+ for (const xla::XlaOp& arg : args) {
+ auto shape_status = b.GetShape(arg);
+ OP_REQUIRES_OK(ctx, shape_status.status());
+ xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
+ *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout(
+ xla::ShapeUtil::Rank(arg_shape));
+ arg_shapes.push_back(std::move(arg_shape));
+ }
// Tell XLA to call the custom code, defined in
// index_ops_kernel_argmax_float_1d.cc.
xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
- output =
- xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
+ output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
+ xla_shape, arg_shapes);
break;
case 2:
- output =
- xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
+ output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
+ xla_shape, arg_shapes);
break;
default:
OP_REQUIRES(ctx, false,