author    Nick Desaulniers <ndesaulniers@google.com> 2017-12-05 14:00:19 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-12-05 14:03:44 -0800
commit    b352b38aabd33404e7ae987778caa6e4b44d86d1 (patch)
tree      36909cf74ee8089ce448cc54d6ab0ded245047bf /tensorflow/compiler/xla/tests/test_utils.cc
parent    248176bbc74127e26a15b7b5c63c3f9c114123ba (diff)
Rather than making potentially complex modifications to the HLO graph, simply generate input data that is constrained for certain entry computation parameters.
Generate fake literals that are within bounds for DynamicSlice and other operations that accept dynamically computed indices.
PiperOrigin-RevId: 178006866
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 162
1 file changed, 109 insertions(+), 53 deletions(-)
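
For context, a minimal sketch of how a test might consume the updated entry point after this change. It assumes a module whose entry computation feeds a parameter into a DynamicSlice start index; only MakeFakeArguments and its new HloModule* signature come from the diff below, and the helper name is illustrative:

// Sketch only: with this commit the caller no longer rewrites the graph.
// MakeFakeArguments runs dataflow analysis, finds constrained uses
// (DynamicSlice/DynamicUpdateSlice indices, reduction init values), and
// produces in-bounds or zero literals for those parameters.
#include "tensorflow/compiler/xla/tests/test_utils.h"

xla::StatusOr<std::vector<std::unique_ptr<xla::Literal>>> MakeArgsForTest(
    xla::HloModule* module) {
  // Old flow (removed below): MakeFakeArguments(*module) plus a separate
  // ReplaceInitsWithConstants(module) graph rewrite.
  return xla::MakeFakeArguments(module);  // constrained literals, no rewrite
}
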
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 0d56c9f483..93bce97a3e 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -47,42 +48,113 @@ void PopulateWithRandomIntegralData(Literal* literal) {
}));
}
-bool LooksLikeSum(const HloInstruction& instruction) {
- return instruction.opcode() == HloOpcode::kAdd &&
- instruction.operand(0)->opcode() == HloOpcode::kParameter &&
- instruction.operand(1)->opcode() == HloOpcode::kParameter &&
- instruction.operand(0) != instruction.operand(1);
+// Matches binary addition computations.
+bool LooksLikeSum(const HloComputation& computation) {
+ const HloInstruction* const root = computation.root_instruction();
+ return root->opcode() == HloOpcode::kAdd &&
+ computation.num_parameters() == 2 &&
+ root->operand(0)->opcode() == HloOpcode::kParameter &&
+ root->operand(1)->opcode() == HloOpcode::kParameter &&
+ root->operand(0) != root->operand(1);
}
-// Given an instruction and operand number, replace the given operand with
-// a Literal Constant Zero. Handle the case of a fusion instruction by
-// replacing the fusion's parent's parameter with a Literal Constant Zero,
-// unless the fusion's parent is itself a fusion.
-Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction,
- const int64 operand_number) {
- CHECK_LT(operand_number, instruction->operand_count());
- if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) {
- return Status::OK();
- }
+// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition,
+// which requires an init_value of 0 rather than a random value.
+bool NeedsZeroInitValue(const HloUse& use) {
+ const HloInstruction* const instruction = use.instruction;
+ const HloOpcode opcode = instruction->opcode();
+ const int64 op_num = use.operand_number;
+ return (
+ ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
+ op_num == 1 && LooksLikeSum(*instruction->to_apply())) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2 &&
+ LooksLikeSum(*instruction->scatter())));
+}
- HloComputation* const computation = instruction->parent();
- std::unique_ptr<HloInstruction> zero = HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(instruction->shape().element_type())));
+// Generate random values that are constrained to the input_shape minus the
+// output_shape so as not to produce wrapping slices, for instance.
+std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
+ const Shape& input_shape, const Shape& slice_shape) {
+ const int64 rank = ShapeUtil::Rank(input_shape);
+ std::vector<int32> start_indices(rank);
+ std::minstd_rand0 engine;
+ for (int i = 0; i < rank; ++i) {
+ const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i);
+ std::uniform_int_distribution<int32> generator(0, upper_bound);
+ start_indices[i] = generator(engine);
+ }
+ return Literal::CreateR1<int32>(start_indices);
+}
- if (computation->IsFusionComputation()) {
- HloInstruction* const fusion_instruction = computation->FusionInstruction();
- if (fusion_instruction->IsFused()) {
- return Unimplemented(
- "Unable to replace fused parameter of fusion instruction");
+// Use dataflow analysis on each parameter to see if there are uses that would
+// be problematic when generating input data. Returns the list of instructions
+// that correspond to their uses.
+//
+// Should be paired with the CreateLiteralForConstrainedUses() function below.
+std::vector<HloInstruction*> FindConstrainedUses(
+ const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
+ std::vector<HloInstruction*> constrained_uses;
+ for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
+ const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
+ for (const HloUse& use : value.uses()) {
+ HloInstruction* instruction = use.instruction;
+ const HloOpcode opcode = instruction->opcode();
+ const int64 op_num = use.operand_number;
+ if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) ||
+ (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) {
+ constrained_uses.push_back(instruction);
+ } else if (opcode == HloOpcode::kFusion) {
+ const HloInstruction* const to_analyze =
+ instruction->fused_parameter(op_num);
+ auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
+ constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
+ fused_uses.end());
+ } else if (NeedsZeroInitValue(use)) {
+ constrained_uses.push_back(instruction);
+ }
}
- TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith(
- instruction->operand(operand_number)->parameter_number(),
- fusion_instruction->parent()->AddInstruction(std::move(zero))));
- } else {
- TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(
- operand_number, computation->AddInstruction(std::move(zero))));
}
- return Status::OK();
+ return constrained_uses;
+}
+
+// Given a parameter, generate a random Literal to use as input if there exist
+// no constrained uses in the dataflow graph. If such constraints exist,
+// generate a constrained literal (either bounded in the case of indices, or
+// zero in the case of init_values for reductions).
+StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
+ const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
+ const HloInstruction& param) {
+ const auto count = constrained_uses.size();
+ if (count > 1) {
+ return Unimplemented("multiple constrained uses not yet supported");
+ }
+
+ if (count == 0) {
+ return MakeFakeLiteral(param.shape());
+ }
+
+ const HloInstruction* const use = constrained_uses[0];
+ switch (use->opcode()) {
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ return MakeRandomNonwrappingSliceIndex(use->operand(0)->shape(),
+ use->shape());
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kSelectAndScatter:
+ return Literal::CreateFromShape(param.shape());
+ default:
+ return Unimplemented("constrained use given; no equivalent literal");
+ }
+}
+
+// Given a module entry parameter, use the dataflow analysis to see if a
+// special case literal must be created, or if we can generate fake data.
+StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
+ const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
+ const auto constrained_uses = FindConstrainedUses(dataflow, param);
+ return CreateLiteralForConstrainedUses(constrained_uses, param);
}
} // namespace
@@ -146,33 +218,17 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
}
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- const HloModule& module) {
- std::vector<std::unique_ptr<Literal>> arguments;
- for (const ShapeLayout& shape_layout :
- module.config().entry_computation_layout().parameter_layouts()) {
- TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
- arguments.push_back(std::move(literal));
+ HloModule* const module) {
+ TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module));
+ const auto params = module->entry_computation()->parameter_instructions();
+ std::vector<std::unique_ptr<Literal>> arguments(params.size());
+ for (int i = 0; i < params.size(); ++i) {
+ TF_ASSIGN_OR_RETURN(arguments[i],
+ MakeConstrainedArgument(*dataflow, *params[i]));
}
return std::move(arguments);
}
-Status ReplaceInitsWithConstants(HloModule* const module) {
- for (HloComputation* const computation : module->computations()) {
- for (HloInstruction* const instruction : computation->instructions()) {
- const HloOpcode opcode = instruction->opcode();
- if ((opcode == HloOpcode::kReduce ||
- opcode == HloOpcode::kReduceWindow) &&
- LooksLikeSum(*instruction->to_apply()->root_instruction())) {
- TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1));
- } else if (opcode == HloOpcode::kSelectAndScatter &&
- LooksLikeSum(*instruction->scatter()->root_instruction())) {
- TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2));
- }
- }
- }
- return Status::OK();
-}
-
Status VerifyHloModule(const perftools::gputools::Platform& platform,
HloModule* const module) {
return HloVerifier(