diff options
author | 2018-06-22 02:07:07 -0700 | |
---|---|---|
committer | 2018-06-22 02:09:55 -0700 | |
commit | 289be76f8ed6d40752f6ee5c64632f4624fa7cc2 (patch) | |
tree | 29532e99008ca8aba2a0d23391b5e71d915f1c13 | |
parent | 945d1a77aebb2071b571598cb1d02fac5b1370c1 (diff) |
Simplify GPU copy insertion.
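Previously, there were two almost identical code paths for inserting copies of constant operands: one for library calls and cross-replica-sum, and one for while and conditional instructions.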
This CL combines the two code paths.
PiperOrigin-RevId: 201655259
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc | 97 |
1 files changed, 49 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index c5ccdd4a7d..fbc1303085 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -52,60 +52,20 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) { HloDataflowAnalysis::Run(*module)); // Make sure all operands of a library call are in memory instead of constants - // in IR. - for (HloInstruction* hlo : - module->entry_computation()->MakeInstructionPostOrder()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); - const auto& values = dataflow->GetValueSet(operand).values(); - if (std::any_of(values.begin(), values.end(), [](const HloValue* value) { - return value->defining_instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy)); - changed = true; - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo) || - hlo->opcode() == HloOpcode::kCrossReplicaSum) { - // For all other library calls and cross-replica-sum, materialize all the - // operands into memory. (Cross-replica-sum gets its constant args - // materialized even if it's not implemented as a libcall to simplify the - // implementation. 
It's slower, but we can constant fold away constant - // args *anyway*, so we just need to make it work.) - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } - } - - // Init values of while and conditional nodes cannot be constants. Insert - // copies for any constants found at the operands of these nodes. + // in IR. Also, init values of while and conditional nodes cannot be + // constants. Insert copies for any constants found at the operands of these + // nodes. tensorflow::gtl::FlatSet<HloInstruction*> inserted_copies; for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kConditional) { - continue; - } - for (auto operand : instruction->operands()) { + for (HloInstruction* hlo : computation->instructions()) { + // Inserts a copy of hlo->operand(n) if it's a constant. + auto copy_operand_if_constant = [&](int64 n) -> Status { + HloInstruction* operand = hlo->mutable_operand(n); // Skip the operands that have already been replaced with a copy in a // previous iteration (which is possible when a constant is used as an // operand in multiple places). if (ContainsKey(inserted_copies, operand)) { - continue; + return Status::OK(); } for (auto& pair : dataflow->GetInstructionValueSet(operand)) { const HloValueSet& value_set = pair.second; @@ -121,6 +81,47 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) { } } } + return Status::OK(); + }; + + if (IsCustomCallToDnnBatchNorm(*hlo)) { + // The epsilon and feature_index operands to a CUDNN batchnorm op don't + // need to be materialized in memory -- in fact, they must be constants. + // These are the last two operands of all three batchnorm ops. 
+ for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } else if (ImplementedAsLibraryCall(*hlo) || + hlo->opcode() == HloOpcode::kCrossReplicaSum || + hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kConditional) { + // For all other library calls, cross-replica-sum, while and conditional + // ops materialize all the operands into memory. (Cross-replica-sum + // gets its constant args materialized even if it's not implemented as a + // libcall to simplify the implementation. It's slower, but we can + // constant fold away constant args *anyway*, so we just need to make it + // work.) + for (int64 i = 0; i < hlo->operand_count(); ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } + } + } + + if (changed) { + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + if (!IsCustomCallToDnnBatchNorm(*hlo)) { + continue; + } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); + ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); + } } } } |