aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-07-31 10:17:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 10:21:35 -0700
commitcd4fccda8b60bda2b36859b4a4e7eae37ed54cce (patch)
tree0695becd1d665623837deede9836c6e2e825c6e8
parent3be04971716fcaf0c11ad9262e60efa428553e14 (diff)
Simplify sort in Algebraic simplifier if no sorting is needed.
PiperOrigin-RevId: 206782323
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc17
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc34
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.cc6
3 files changed, 51 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 505c0e8dff..946ef6f0d6 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -150,6 +150,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) override;
+ Status HandleSort(HloInstruction* sort) override;
+
Status HandleTranspose(HloInstruction* transpose) override;
Status HandleSubtract(HloInstruction* sub) override;
@@ -2105,6 +2107,21 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
/*reduce_computation=*/function));
}
+Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
+ auto operand = sort->mutable_operand(0);
+ int64 dimension_to_sort = sort->dimensions(0);
+ if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
+ operand->shape().dimensions(dimension_to_sort) <= 1) {
+ if (sort->operand_count() == 1) {
+ return ReplaceInstruction(sort, operand);
+ }
+ // If it is key/value sort, the output of sort is a tuple.
+ return ReplaceWithNewInstruction(
+ sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)}));
+ }
+ return Status::OK();
+}
+
Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
auto operand = transpose->mutable_operand(0);
if (std::is_sorted(transpose->dimensions().begin(),
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 8b81b4c97e..ad14fe6f2c 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1941,6 +1941,40 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4);
}
+TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ EXPECT_THAT(computation->root_instruction(), keys);
+}
+
+TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0});
+ Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ auto values = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, values_shape, "values"));
+ builder.AddInstruction(HloInstruction::CreateSort(
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
+}
+
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
struct ConvTestOptions {
int in_batch = 10;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index 5187948e29..e546f5cc4a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -93,12 +93,6 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
const gpu::LaunchDimensions* launch_dimensions) {
const Shape& keys_shape = keys_array.GetShape();
- // TODO(b/26783907): This case can probably be avoided with the Algebraic
- // Simplifier.
- if (ShapeUtil::IsScalar(keys_shape)) {
- return Status::OK();
- }
-
// Create loop nests which loop through the operand dimensions. The sort
// dimension is handled in the innermost loop which performs the sorting.
ForLoopNest loop_nest(name, b);