diff options
author | Mark Heffernan <meheff@google.com> | 2018-08-15 13:04:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 13:13:46 -0700 |
commit | 75399bba469508a43cb784f1c261af63580528c6 (patch) | |
tree | 8f53f967f796a3eb4ea883f4180d3b3db247614c /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | bc646fd57603a58485076186d7797e6dc4853978 (diff) |
Support multiple index-space constraints in synthetic input generator.
PiperOrigin-RevId: 208868489
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 71 |
1 files changed, 35 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2647937013..faeec657b6 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape, - std::minstd_rand0* engine) { - const int64 rank = ShapeUtil::Rank(input_shape); - std::vector<int32> start_indices(rank); +std::unique_ptr<Literal> MakeRandomIndex( + tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) { + std::vector<int32> start_indices(index_space.size()); if (engine != nullptr) { - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution<int32> generator(0, upper_bound); + for (int i = 0; i < index_space.size(); ++i) { + std::uniform_int_distribution<int32> generator(0, index_space[i]); start_indices[i] = generator(*engine); } } @@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses( StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - HloInstruction* needs_index = nullptr; - HloInstruction* needs_constant = nullptr; + std::vector<int64> index_space; + bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr) { - auto needs_index_shape = needs_index->shape(); - auto use_shape = use->shape(); - if (needs_index->opcode() == HloOpcode::kDynamicSlice) { - needs_index_shape = needs_index->operand(0)->shape(); - } - if (use->opcode() == HloOpcode::kDynamicSlice) { - use_shape = use->operand(0)->shape(); + case HloOpcode::kDynamicUpdateSlice: { + const Shape& indexed_shape = use->operand(0)->shape(); + const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice + ? use->shape() + : use->operand(1)->shape(); + const int64 rank = ShapeUtil::Rank(indexed_shape); + if (!index_space.empty()) { + TF_RET_CHECK(rank == index_space.size()); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = std::min( + index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i)); } - if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + } else { + index_space.resize(rank); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); } } - needs_index = use; break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->to_apply()); break; case HloOpcode::kSelectAndScatter: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->scatter()); break; @@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( use->ToString().c_str()); } } - if (needs_index != nullptr && needs_constant != nullptr) { + if (!index_space.empty() && needs_constant) { return Unimplemented( - "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "constant: %s\n", - needs_index->ToString().c_str(), needs_constant->ToString().c_str()); + "Conflicting operand generation constraints. Dynamically indexes a " + "shape and is the init value of a reduction."); } - if (needs_index != nullptr) { - return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape(), engine); - } else if (needs_constant != nullptr) { + if (!index_space.empty()) { + return MakeRandomIndex(index_space, engine); + } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); @@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr; std::vector<std::unique_ptr<Literal>> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( - *dataflow, *params[i], engine.get())); + arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get()) + .ValueOrDie(); } return std::move(arguments); } |