aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-08-15 13:04:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 13:13:46 -0700
commit75399bba469508a43cb784f1c261af63580528c6 (patch)
tree8f53f967f796a3eb4ea883f4180d3b3db247614c /tensorflow/compiler/xla/tests/test_utils.cc
parentbc646fd57603a58485076186d7797e6dc4853978 (diff)
Support multiple index-space constraints in synthetic input generator.
PiperOrigin-RevId: 208868489
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc71
1 files changed, 35 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2647937013..faeec657b6 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
- const Shape& input_shape, const Shape& slice_shape,
- std::minstd_rand0* engine) {
- const int64 rank = ShapeUtil::Rank(input_shape);
- std::vector<int32> start_indices(rank);
+std::unique_ptr<Literal> MakeRandomIndex(
+ tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
+ std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
- for (int i = 0; i < rank; ++i) {
- const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
- ShapeUtil::GetDimension(slice_shape, i);
- std::uniform_int_distribution<int32> generator(0, upper_bound);
+ for (int i = 0; i < index_space.size(); ++i) {
+ std::uniform_int_distribution<int32> generator(0, index_space[i]);
start_indices[i] = generator(*engine);
}
}
@@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
- HloInstruction* needs_index = nullptr;
- HloInstruction* needs_constant = nullptr;
+ std::vector<int64> index_space;
+ bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr) {
- auto needs_index_shape = needs_index->shape();
- auto use_shape = use->shape();
- if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
- needs_index_shape = needs_index->operand(0)->shape();
- }
- if (use->opcode() == HloOpcode::kDynamicSlice) {
- use_shape = use->operand(0)->shape();
+ case HloOpcode::kDynamicUpdateSlice: {
+ const Shape& indexed_shape = use->operand(0)->shape();
+ const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
+ ? use->shape()
+ : use->operand(1)->shape();
+ const int64 rank = ShapeUtil::Rank(indexed_shape);
+ if (!index_space.empty()) {
+ TF_RET_CHECK(rank == index_space.size());
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = std::min(
+ index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i));
}
- if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ } else {
+ index_space.resize(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i);
}
}
- needs_index = use;
break;
+ }
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->to_apply());
break;
case HloOpcode::kSelectAndScatter:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->scatter());
break;
@@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
- if (needs_index != nullptr && needs_constant != nullptr) {
+ if (!index_space.empty() && needs_constant) {
return Unimplemented(
- "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
- "constant: %s\n",
- needs_index->ToString().c_str(), needs_constant->ToString().c_str());
+ "Conflicting operand generation constraints. Dynamically indexes a "
+ "shape and is the init value of a reduction.");
}
- if (needs_index != nullptr) {
- return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
- needs_index->shape(), engine);
- } else if (needs_constant != nullptr) {
+ if (!index_space.empty()) {
+ return MakeRandomIndex(index_space, engine);
+ } else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
@@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
- *dataflow, *params[i], engine.get()));
+ arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
+ .ValueOrDie();
}
return std::move(arguments);
}