aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-08-15 13:04:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 13:13:46 -0700
commit75399bba469508a43cb784f1c261af63580528c6 (patch)
tree8f53f967f796a3eb4ea883f4180d3b3db247614c /tensorflow/compiler/xla
parentbc646fd57603a58485076186d7797e6dc4853978 (diff)
Support multiple index-space constraints in synthetic input generator.
PiperOrigin-RevId: 208868489
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc71
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc55
3 files changed, 92 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index b8e700ae97..0f8cffd466 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2076,6 +2076,8 @@ tf_cc_test(
xla_test(
name = "test_utils_test",
srcs = ["test_utils_test.cc"],
+ # There is nothing backend specific in this test, so just pick an arbitrary backend.
+ backends = ["cpu"],
deps = [
":local_client_test_base",
":test_utils",
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2647937013..faeec657b6 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
- const Shape& input_shape, const Shape& slice_shape,
- std::minstd_rand0* engine) {
- const int64 rank = ShapeUtil::Rank(input_shape);
- std::vector<int32> start_indices(rank);
+std::unique_ptr<Literal> MakeRandomIndex(
+ tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
+ std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
- for (int i = 0; i < rank; ++i) {
- const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
- ShapeUtil::GetDimension(slice_shape, i);
- std::uniform_int_distribution<int32> generator(0, upper_bound);
+ for (int i = 0; i < index_space.size(); ++i) {
+ std::uniform_int_distribution<int32> generator(0, index_space[i]);
start_indices[i] = generator(*engine);
}
}
@@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
- HloInstruction* needs_index = nullptr;
- HloInstruction* needs_constant = nullptr;
+ std::vector<int64> index_space;
+ bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr) {
- auto needs_index_shape = needs_index->shape();
- auto use_shape = use->shape();
- if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
- needs_index_shape = needs_index->operand(0)->shape();
- }
- if (use->opcode() == HloOpcode::kDynamicSlice) {
- use_shape = use->operand(0)->shape();
+ case HloOpcode::kDynamicUpdateSlice: {
+ const Shape& indexed_shape = use->operand(0)->shape();
+ const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
+ ? use->shape()
+ : use->operand(1)->shape();
+ const int64 rank = ShapeUtil::Rank(indexed_shape);
+ if (!index_space.empty()) {
+ TF_RET_CHECK(rank == index_space.size());
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = std::min(
+ index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i));
}
- if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ } else {
+ index_space.resize(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i);
}
}
- needs_index = use;
break;
+ }
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->to_apply());
break;
case HloOpcode::kSelectAndScatter:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->scatter());
break;
@@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
- if (needs_index != nullptr && needs_constant != nullptr) {
+ if (!index_space.empty() && needs_constant) {
return Unimplemented(
- "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
- "constant: %s\n",
- needs_index->ToString().c_str(), needs_constant->ToString().c_str());
+ "Conflicting operand generation constraints. Dynamically indexes a "
+ "shape and is the init value of a reduction.");
}
- if (needs_index != nullptr) {
- return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
- needs_index->shape(), engine);
- } else if (needs_constant != nullptr) {
+ if (!index_space.empty()) {
+ return MakeRandomIndex(index_space, engine);
+ } else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
@@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
- *dataflow, *params[i], engine.get()));
+ arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
+ .ValueOrDie();
}
return std::move(arguments);
}
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index a2f0338e25..64d9e2031e 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) {
TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
}
+XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
+ auto module = ParseHloString(
+ R"(HloModule index_space_module
+
+ ENTRY IndexSpace {
+ index_param = s32[3]{0} parameter(0)
+ array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
+ array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
+ dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
+ ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 3);
+ const Literal& index_arg = *args[0];
+
+ EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+
+ EXPECT_GE(index_arg.Get<int32>({1}), 0);
+ EXPECT_LE(index_arg.Get<int32>({1}), 2);
+
+ EXPECT_GE(index_arg.Get<int32>({2}), 0);
+ EXPECT_LE(index_arg.Get<int32>({2}), 3);
+}
+
+XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
+ auto module = ParseHloString(
+ R"(HloModule index_space_module
+
+ ENTRY IndexSpace {
+ index_param = s32[3]{0} parameter(0)
+ array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
+ array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
+ update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
+ update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
+
+ dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
+ ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 5);
+ const Literal& index_arg = *args[0];
+
+ EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+
+ EXPECT_GE(index_arg.Get<int32>({1}), 0);
+ EXPECT_LE(index_arg.Get<int32>({1}), 2);
+
+ EXPECT_GE(index_arg.Get<int32>({2}), 0);
+ EXPECT_LE(index_arg.Get<int32>({2}), 3);
+}
+
} // namespace
} // namespace xla