aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/gather_op.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-10-11 15:18:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-11 15:23:18 -0700
commitccfa8f4f1492c5cf1a7db35b2dba1f7b5424f0e2 (patch)
tree192d0e1243e4c5453aaeb9880e3150ff351ba440 /tensorflow/compiler/tf2xla/kernels/gather_op.cc
parent10d0ae696c7b5618cae9e3845af8300fe62870a2 (diff)
[XLA:CPU] Switch TF gather's HLO implementation to use dynamic-update-slice in a "while" loop.
Benchmarks results (times in ms): nontrivial_gather.axis0_cpu: 0.110 nontrivial_gather.axis0_xla_cpu: 0.139 nontrivial_gather.axis1_cpu: 0.093 nontrivial_gather.axis1_xla_cpu: 0.142 nontrivial_gather.axis4_cpu: 1.183 nontrivial_gather.axis4_xla_cpu: 2.658 slice_gather.axis0_cpu: 0.00388 slice_gather.axis0_xla_cpu: 0.00397 slice_gather.axis1_cpu: 0.00421 slice_gather.axis1_xla_cpu: 0.00427 slice_gather.axis4_cpu: 0.252 slice_gather.axis4_xla_cpu: 0.114 As you can see, the pure-XLA implementation is slower in all the nontrivial cases and as-fast or faster in the slice-gather cases. The slice-gather cases are gathers that can be implemented as a single XLA dynamic-slice, and so the speedup here is likely understated: Once we can simplify the gather to a single dynamic-slice, we should be able to do many other optimizations to it, ideally fusing it so it has zero cost. The nontrivial gathers all gather more than one element, and are implemented with an XLA while loop. The most important one is the axis 0 gather -- gathering from an inner dimension is so slow no matter what you do that it's probably not worth optimizing. It's possible to make this XLA implementation faster -- one option I've considered is "unrolling" the gather into a series of dynamic-slice's that are then concat'ed together. This would be totally fusable, unlike the implementation in this CL. Another option would be adding a notion of uninitialized memory into XLA -- part of what makes us slow is that we have to initialize the memset our output to 0 before we overwrite it. But given that the shape we're benchmarking here is totally arbitrary, and given that we're getting decent performance, I think this is good enough to start with. PiperOrigin-RevId: 171883273
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/gather_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op.cc165
1 files changed, 23 insertions, 142 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 2c7d445600..db449ec345 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -30,7 +30,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
XlaOpKernelContext* context, const xla::ComputationDataHandle& input,
const TensorShape& input_shape, const xla::ComputationDataHandle& indices,
const TensorShape& indices_shape, int64 axis, DataType dtype,
- xla::ComputationBuilder* builder) {
+ DataType index_type, xla::ComputationBuilder* builder) {
// Although the indices Tensor is flattened into rank 1 during the lookup,
// and each scalar entry is used as an index into the first dimension of the
// input, the output is returned with shape:
@@ -80,22 +80,23 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
// Specify the shape of the loop-carried Tensor tuple.
xla::PrimitiveType ptype;
TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype));
+ xla::PrimitiveType idxtype;
+ TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype));
std::vector<xla::Shape> tuple_shapes(
{// The iteration counter i is a scalar, incremented each iteration.
- xla::ShapeUtil::MakeShape(xla::S32, {}),
+ xla::ShapeUtil::MakeShape(idxtype, {}),
// The input array has shape input_shape. Loop invariant.
xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()),
// The gather indices are reshaped to rank 1. Loop invariant.
- xla::ShapeUtil::MakeShape(xla::S32, {num_indices}),
+ xla::ShapeUtil::MakeShape(idxtype, {num_indices}),
// The output array is rank >= 3, and is updated on each loop iteration.
xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())});
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
// Construct the initial values of the loop-carried Tensors.
- auto init_i = builder->ConstantR0<int32>(0);
- auto init_out =
- builder->Broadcast(builder->ConstantLiteral(xla::Literal::Zero(ptype)),
- loop_out_shape.dim_sizes());
+ auto init_i = XlaHelpers::Zero(builder, index_type);
+ auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
+ loop_out_shape.dim_sizes());
// Flatten the indices into 1-D for ease of iteration.
auto indices_1d = builder->Reshape(indices, {num_indices});
auto init = builder->Tuple({init_i, input, indices_1d, init_out});
@@ -105,7 +106,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
"GatherWhileCond");
condb.Lt(condb.GetTupleElement(
condb.Parameter(0, tuple_shape, "GatherWhileTuple"), 0),
- condb.ConstantR0<int32>(num_indices));
+ XlaHelpers::IntegerLiteral(&condb, index_type, num_indices));
auto cond_status = condb.Build();
auto cond = cond_status.ConsumeValueOrDie();
@@ -127,7 +128,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
// Slice from the input array.
auto index = bodyb.DynamicSlice(indices, bodyb.Reshape(i, {1}), {1});
auto start_indices = bodyb.Pad(
- bodyb.Reshape(index, {1}), bodyb.ConstantR0<int32>(0),
+ bodyb.Reshape(index, {1}), XlaHelpers::Zero(&bodyb, index_type),
xla::MakeEdgePaddingConfig(
{{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}}));
auto slice_i = bodyb.Reshape(
@@ -136,7 +137,8 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
// Construct the index into the R3+ output Tensor 0, ..., <index>, 0, ...
std::vector<xla::ComputationDataHandle> out_index_vals(
- loop_out_shape.dims(), bodyb.ConstantR1<int32>({0}));
+ loop_out_shape.dims(),
+ bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1}));
out_index_vals[input_shape_pre_axis.dims() + extra_dims] =
bodyb.Reshape(i, {1});
auto out_index = bodyb.ConcatInDim(out_index_vals, 0);
@@ -144,8 +146,8 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
// Update the output Tensor
auto updated_output = bodyb.DynamicUpdateSlice(output, slice_i, out_index);
- bodyb.Tuple({bodyb.Add(i, bodyb.ConstantR0<int32>(1)), input, indices,
- updated_output});
+ bodyb.Tuple({bodyb.Add(i, XlaHelpers::One(&bodyb, index_type)), input,
+ indices, updated_output});
}
auto body_status = bodyb.Build();
auto body = body_status.ConsumeValueOrDie();
@@ -156,124 +158,6 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
return builder->Reshape(gather_output, out_shape.dim_sizes());
}
-namespace {
-
-class GatherOpCustomCall : public XlaOpKernel {
- public:
- explicit GatherOpCustomCall(OpKernelConstruction* context)
- : XlaOpKernel(context) {}
-
- void Compile(XlaOpKernelContext* context) override {
- const TensorShape params_shape = context->InputShape(0);
- const auto params_dims = params_shape.dims();
- const TensorShape indices_shape = context->InputShape(1);
- OP_REQUIRES(
- context, TensorShapeUtils::IsVectorOrHigher(params_shape),
- errors::InvalidArgument("params must be at least 1 dimensional"));
-
- DataType index_type = input_type(1);
- OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
- errors::InvalidArgument("index must be int32 or int64"));
-
- // GatherV2 added an axis argument. We support both Gather and GatherV2 in
- // this kernel by defaulting axis to 0 if there are 2 inputs.
- int64 axis = 0;
- if (context->num_inputs() == 3) {
- const TensorShape axis_shape = context->InputShape(2);
- OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
- errors::InvalidArgument("axis must be scalar"));
- DataType axis_type = input_type(2);
- OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
- errors::InvalidArgument("axis must be int32 or int64"));
-
- xla::Literal literal;
- OP_REQUIRES_OK(context, context->ConstantInput(2, &literal));
- int64 axis_input = axis_type == DT_INT32 ? literal.Get<int32>({})
- : literal.Get<int64>({});
- axis = axis_input < 0 ? axis_input + params_dims : axis_input;
- OP_REQUIRES(context, 0 <= axis && axis < params_dims,
- errors::InvalidArgument("Expected axis in the range [",
- -params_dims, ", ", params_dims,
- "), but got ", axis_input));
- }
-
- // Check that we have enough index space.
- const int64 limit = index_type == DT_INT32
- ? std::numeric_limits<int32>::max()
- : std::numeric_limits<int64>::max();
- OP_REQUIRES(context, params_shape.dim_size(axis) <= limit,
- errors::InvalidArgument(
- "params.shape[", axis, "] too large for ",
- DataTypeString(index_type),
- " indexing: ", params_shape.dim_size(axis), " > ", limit));
-
- // The result shape is params.shape[0:axis] + indices.shape +
- // params.shape[axis + 1:].
- TensorShape result_shape;
- int64 outer_size = 1;
- int64 inner_size = 1;
- for (int i = 0; i < axis; i++) {
- result_shape.AddDim(params_shape.dim_size(i));
- outer_size *= params_shape.dim_size(i);
- }
- result_shape.AppendShape(indices_shape);
- for (int i = axis + 1; i < params_dims; i++) {
- result_shape.AddDim(params_shape.dim_size(i));
- inner_size *= params_shape.dim_size(i);
- }
-
- XlaContext& tc = XlaContext::Get(context);
- OP_REQUIRES(
- context, tc.allow_cpu_custom_calls(),
- errors::InvalidArgument("Gather op requires CustomCall on CPU"));
-
- xla::ComputationBuilder& b = *context->builder();
-
- // Call gather_xla_float_kernel (from gather_op_kernel_float.cc).
- // XLA passes <out> to the function, so it is not included here.
- std::vector<xla::ComputationDataHandle> args;
- args.push_back(tc.GetOrCreateRuntimeContextParameter());
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR0<int64>(indices_shape.num_elements())));
- args.push_back(
- b.ConstantLiteral(*xla::Literal::CreateR0<int64>(outer_size)));
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR0<int64>(params_shape.dim_size(axis))));
- args.push_back(
- b.ConstantLiteral(*xla::Literal::CreateR0<int64>(inner_size)));
- args.push_back(context->Input(0));
- args.push_back(context->Input(1));
-
- xla::Shape xla_out_shape;
- OP_REQUIRES_OK(
- context, TensorShapeToXLAShape(DT_FLOAT, result_shape, &xla_out_shape));
-
- // Call the custom code with args:
- xla::ComputationDataHandle output;
- if (index_type == DT_INT32) {
- output = b.CustomCall("gather_float_int32_xla_impl", args, xla_out_shape);
- } else {
- output = b.CustomCall("gather_float_int64_xla_impl", args, xla_out_shape);
- }
-
- context->SetOutput(0, output);
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(GatherOpCustomCall);
-};
-
-REGISTER_XLA_OP(Name("Gather")
- .TypeConstraint("Tparams", DT_FLOAT)
- .Device(DEVICE_CPU_XLA_JIT),
- GatherOpCustomCall);
-REGISTER_XLA_OP(Name("GatherV2")
- .TypeConstraint("Tparams", DT_FLOAT)
- .Device(DEVICE_CPU_XLA_JIT),
- GatherOpCustomCall);
-
-} // namespace
-
GatherOpDynamicSlice::GatherOpDynamicSlice(OpKernelConstruction* context)
: XlaOpKernel(context) {}
@@ -303,20 +187,17 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) {
", ", params_dims, "), but got ", axis));
}
- xla::ComputationDataHandle gather =
- XlaComputeGatherDynamicSlice(context, input, input_shape, indices,
- indices_shape, axis, DT_FLOAT, builder);
+ DataType index_type = input_type(1);
+ OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
+ errors::InvalidArgument("indices must be int32 or int64"));
+
+ xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
+ context, input, input_shape, indices, indices_shape, axis, DT_FLOAT,
+ index_type, builder);
context->SetOutput(0, gather);
}
-REGISTER_XLA_OP(Name("Gather")
- .TypeConstraint("Tparams", DT_FLOAT)
- .Device(DEVICE_GPU_XLA_JIT),
- GatherOpDynamicSlice);
-
-REGISTER_XLA_OP(Name("GatherV2")
- .TypeConstraint("Tparams", DT_FLOAT)
- .Device(DEVICE_GPU_XLA_JIT),
- GatherOpDynamicSlice);
+REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice);
+REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice);
} // namespace tensorflow