author    Sanjoy Das <sanjoy@google.com>    2018-03-20 11:13:48 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-03-20 11:18:12 -0700
commit    1c4e42b39fd9ae2da14d7eb323bedc144a6e659b (patch)
tree      fa911d6aaf141acbbc1f7a8efc37c8c0165534a5
parent    3bca4298aacd9f89de2ac532bb7fedcdec1a5bb6 (diff)
Use 32-bit induction variable in gather expander

Right now this is unconditional (and we fail with Unimplemented() if a 32-bit induction variable is not large enough), but eventually we may want to be smarter about this.

PiperOrigin-RevId: 189773581
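The heart of the change is an overflow guard in GatherExpander::ExpandGather: the gather loop trip count is the product of every gather-indices dimension except the index vector dimension, and expansion now fails with Unimplemented() when that product does not fit in 32 bits. A minimal standalone sketch of the guard, with std::int64_t/std::int32_t standing in for XLA's int64/int32 aliases (the function names here are illustrative, not the XLA API):

#include <cstdint>
#include <vector>

// Same trick as the new xla::IsInt32 helper below: the value fits in
// 32 bits exactly when the int32 round-trip leaves it unchanged.
bool FitsInInt32(std::int64_t x) {
  return static_cast<std::int32_t>(x) == x;
}

// Illustrative: the trip count is the product of all gather-indices
// dimensions except the index vector dimension, matching the loop added
// to GatherExpander::ExpandGather in this change.
bool TripCountFitsInInt32(const std::vector<std::int64_t>& indices_dims,
                          std::int64_t index_vector_dim) {
  std::int64_t trip_count = 1;
  for (std::int64_t i = 0;
       i < static_cast<std::int64_t>(indices_dims.size()); ++i) {
    if (i != index_vector_dim) trip_count *= indices_dims[i];
  }
  return FitsInInt32(trip_count);
}

For the shape in the new regression test, s32[2147483647,5] with index_vector_dim=2 (one past the last dimension, so both dimensions count toward the trip count), the product is 5 * 2147483647, well above INT32_MAX, so the expander now reports Unimplemented instead of silently overflowing the 32-bit induction variable.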
-rw-r--r-- tensorflow/compiler/xla/service/BUILD                   | 12
-rw-r--r-- tensorflow/compiler/xla/service/gather_expander.cc      | 19
-rw-r--r-- tensorflow/compiler/xla/service/gather_expander_test.cc | 51
-rw-r--r-- tensorflow/compiler/xla/service/while_util.cc           | 21
-rw-r--r-- tensorflow/compiler/xla/service/while_util.h            |  2
-rw-r--r-- tensorflow/compiler/xla/tests/BUILD                     |  2
-rw-r--r-- tensorflow/compiler/xla/tests/gather_operation_test.cc  | 60
-rw-r--r-- tensorflow/compiler/xla/util.h                          |  9
8 files changed, 163 insertions(+), 13 deletions(-)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 43c56484ea..d4d67872cf 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1276,6 +1276,18 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "gather_expander_test",
+ srcs = ["gather_expander_test.cc"],
+ deps = [
+ ":gather_expander",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/tests:test_macros_header",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ ],
+)
+
cc_library(
name = "conditional_simplifier",
srcs = ["conditional_simplifier.cc"],
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 488bed35fe..221ff7900f 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -306,18 +306,33 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
HloInstruction* gather_indices = gather_instr->mutable_operand(1);
+ const Shape& gather_indices_shape = gather_indices->shape();
const Shape& output_shape = gather_instr->shape();
int64 output_rank = output_shape.dimensions_size();
const GatherDimensionNumbers& dim_numbers =
gather_instr->gather_dimension_numbers();
+ int64 gather_loop_trip_count = 1;
+ for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ if (i != dim_numbers.index_vector_dim()) {
+ gather_loop_trip_count *= gather_indices_shape.dimensions(i);
+ }
+ }
+
+ if (!IsInt32(gather_loop_trip_count)) {
+ return Unimplemented(
+ "Gather operations with more than 2147483647 gather indices are not "
+ "supported. This error occurred for %s.",
+ gather_instr->ToString().c_str());
+ }
+
TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices,
CanonicalizeGatherIndices(
gather_indices, dim_numbers.index_vector_dim()));
- const int64 gather_loop_trip_count =
- canonical_gather_indices->shape().dimensions(0);
+ CHECK_EQ(gather_loop_trip_count,
+ canonical_gather_indices->shape().dimensions(0));
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_init,
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
new file mode 100644
index 0000000000..ba41ee8428
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gather_expander.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
+ const string hlo_text = R"(
+HloModule TensorFlowGatherMultipleBatchDims
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2147483647,5] parameter(1)
+ ROOT gather = s32[2147483647,3,5] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=2,
+ window_bounds={3, 1}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_text));
+
+ Status status = GatherExpander{}.Run(module.get()).status();
+ EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
+
+ ASSERT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Gather operations with more than 2147483647 gather "
+ "indices are not supported."));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 8cd5882f32..bd07941843 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -142,23 +142,23 @@ WhileUtil::MakeInstructionsLiveIn(
static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
- int64 trip_count) {
+ int32 trip_count) {
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
- Shape scalar_s64 = ShapeUtil::MakeShape(S64, {});
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
CreateComputationWithSignature(
{&loop_state_shape}, scalar_pred, "while_cond"));
HloInstruction* trip_count_constant = cond_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(trip_count)));
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(trip_count)));
HloInstruction* param = cond_computation->parameter_instruction(0);
- TF_ASSIGN_OR_RETURN(HloInstruction * counter,
+ TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
MakeGetTupleElementHlo(param, 0));
+
TF_ASSIGN_OR_RETURN(
HloInstruction * compare,
- MakeBinaryHlo(HloOpcode::kLt, counter, trip_count_constant));
+ MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant));
cond_computation->set_root_instruction(compare);
return std::move(cond_computation);
}
@@ -171,8 +171,7 @@ static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
CreateComputationWithSignature(
{&loop_state_shape}, loop_state_shape, "while_body"));
HloInstruction* one = body_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(1)));
-
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
HloInstruction* param = body_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
MakeGetTupleElementHlo(param, 0));
@@ -200,7 +199,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
std::vector<HloInstruction*> init_values_with_indvar;
init_values_with_indvar.reserve(init_values.size() + 1);
HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(0)));
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
init_values_with_indvar.push_back(zero);
c_copy(init_values, std::back_inserter(init_values_with_indvar));
return computation->AddInstruction(
@@ -210,16 +209,18 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
std::vector<Shape> loop_state_shape_components;
loop_state_shape_components.reserve(init_values.size() + 1);
- loop_state_shape_components.push_back(ShapeUtil::MakeShape(S64, {}));
+ loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
c_transform(init_values, std::back_inserter(loop_state_shape_components),
[](HloInstruction* instr) { return instr->shape(); });
return ShapeUtil::MakeTupleShape(loop_state_shape_components);
}
/*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
- HloComputation* computation, int64 trip_count,
+ HloComputation* computation, int32 trip_count,
const WhileUtil::LoopStateTy& init_values,
const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) {
+ CHECK_GE(trip_count, 0);
+
Shape loop_state_shape = MakeLoopStateShape(init_values);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> cond,
diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h
index 80f7e16e64..1688d46742 100644
--- a/tensorflow/compiler/xla/service/while_util.h
+++ b/tensorflow/compiler/xla/service/while_util.h
@@ -71,7 +71,7 @@ class WhileUtil {
// return loop_state;
// }
static StatusOr<LoopStateTy> MakeCountedLoop(
- HloComputation* computation, int64 trip_count,
+ HloComputation* computation, int32 trip_count,
const LoopStateTy& init_values,
const LoopBodyGeneratorTy& loop_body_generator);
};
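The signature change above means every counted loop WhileUtil generates now carries its induction variable as a scalar S32 in slot 0 of the loop state tuple (see MakeLoopStateShape in while_util.cc), so the condition's compare and the body's increment stay in 32-bit arithmetic. A hedged sketch of a call site after this change, in the style of the gather expander; `computation`, `gather_loop_trip_count`, and `accumulator_init` are assumed context rather than part of this diff:

// Sketch only: builds a loop that runs gather_loop_trip_count times.
// The body generator receives the scalar S32 induction variable plus the
// current loop state and returns the next loop state.
StatusOr<WhileUtil::LoopStateTy> loop_result_or =
    WhileUtil::MakeCountedLoop(
        computation, /*trip_count=*/gather_loop_trip_count,
        /*init_values=*/{accumulator_init},
        [&](HloInstruction* indvar, const WhileUtil::LoopStateTy& state)
            -> StatusOr<WhileUtil::LoopStateTy> {
          // ... emit one iteration's HLO here, indexed by `indvar` ...
          return state;
        });

The narrowing from the caller's int64 trip count to the int32 parameter is safe only because the expander checked IsInt32 first; the new CHECK_GE(trip_count, 0) in MakeCountedLoop guards the other end of the range.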
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 025ac129d7..04a9c1ef79 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -676,7 +676,9 @@ xla_test(
name = "gather_operation_test",
srcs = ["gather_operation_test.cc"],
deps = [
+ ":client_library_test_base",
":hlo_test_base",
+ "//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 8ba91946c0..9db68ff7a6 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
@@ -397,5 +399,63 @@ ENTRY main {
RunTest(hlo_text, operand.get(), gather_indices.get());
}
+class GatherClientLibraryTest : public ClientLibraryTestBase {};
+
+// TODO(b/30671675): Asynchronous execution on stream is not yet supported on
+// GPU and CPU_PARALLEL.
+XLA_TEST_F(GatherClientLibraryTest,
+ DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) {
+ // We create this HLO, but using the ComputationBuilder API.
+ //
+ // ENTRY main {
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // ROOT gather = s32[2,3] gather(operand, indices),
+ // output_window_dims={1},
+ // elided_window_dims={0},
+ // gather_dims_to_operand_dims={0},
+ // index_vector_dim=1,
+ // window_bounds={1, 3}
+ // }
+
+ ComputationBuilder builder(client_, "gather_basic");
+
+ Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
+ Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
+
+ auto operand = builder.Parameter(0, operand_shape, "operand");
+ auto indices = builder.Parameter(1, indices_shape, "indices");
+ GatherDimensionNumbers dim_numbers;
+ dim_numbers.add_output_window_dims(1);
+ dim_numbers.add_elided_window_dims(0);
+ dim_numbers.add_gather_dims_to_operand_dims(0);
+ dim_numbers.set_index_vector_dim(1);
+ builder.Gather(operand, indices, dim_numbers, {1, 3});
+
+ std::vector<int32> expected = {};
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> operand_arg,
+ client_->TransferToServer(*Literal::CreateR2<int32>(
+ {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> indices_arg,
+ client_->TransferToServer(*Literal::CreateR1<int32>({0, 2})));
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
+ client_->GetDeviceHandles(1));
+ xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+ *execution_options.add_device_handles() = devices[0];
+ TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build());
+ std::vector<xla::Client::ComputationInstance> computation_instances = {
+ {computation,
+ {operand_arg.get(), indices_arg.get()},
+ execution_options,
+ /*execution_profile=*/nullptr}};
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::vector<std::unique_ptr<xla::GlobalData>> result_data,
+ client_->ExecuteParallel(computation_instances));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+ client_->Transfer(*(result_data[0])));
+ LiteralTestUtil::ExpectEqual(
+ *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}));
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index ff99d3728d..2da9f9ed6f 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -519,6 +519,15 @@ int64 FindIndex(const C& c, Value&& value) {
auto it = c_find(c, std::forward<Value>(value));
return std::distance(c.begin(), it);
}
+
+// Returns true if `x` fits in 32 bits.
+template <typename T>
+bool IsInt32(T x) {
+ // Following conversion rules: "the value is unchanged if it can be
+ // represented in the destination type (and bit-field width); otherwise, the
+ // value is implementation-defined."
+ return static_cast<int32>(x) == x;
+}
} // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \
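As the comment in IsInt32 notes, converting an out-of-range value to int32 is implementation-defined, but on the two's-complement platforms XLA targets the round-trip test behaves as intended. A small self-contained sketch of the helper's behavior at the boundaries, restating the template with standard fixed-width types; the values are illustrative:

#include <cassert>
#include <cstdint>

// Restates the new xla::IsInt32 helper with <cstdint> types.
template <typename T>
bool IsInt32(T x) {
  return static_cast<std::int32_t>(x) == x;
}

int main() {
  assert(IsInt32(std::int64_t{2147483647}));       // INT32_MAX round-trips.
  assert(!IsInt32(std::int64_t{2147483648}));      // INT32_MAX + 1 does not.
  assert(IsInt32(std::int64_t{-2147483647} - 1));  // INT32_MIN round-trips.
  assert(!IsInt32(std::int64_t{5} * 2147483647));  // Trip count from the new test.
}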