about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: Adrian Kuegel <akuegel@google.com> 2018-06-22 02:07:07 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-22 02:09:55 -0700
commit: 289be76f8ed6d40752f6ee5c64632f4624fa7cc2 (patch)
tree: 29532e99008ca8aba2a0d23391b5e71d915f1c13
parent: 945d1a77aebb2071b571598cb1d02fac5b1370c1 (diff)
Simplify GPU copy insertion.
Previously, there was almost identical code for inserting copies. This CL combines the two code paths. PiperOrigin-RevId: 201655259
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc | 97
1 file changed, 49 insertions(+), 48 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index c5ccdd4a7d..fbc1303085 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -52,60 +52,20 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
HloDataflowAnalysis::Run(*module));
// Make sure all operands of a library call are in memory instead of constants
- // in IR.
- for (HloInstruction* hlo :
- module->entry_computation()->MakeInstructionPostOrder()) {
- // Inserts a copy of hlo->operand(n) if it's a constant.
- auto copy_operand_if_constant = [&](int64 n) -> Status {
- HloInstruction* operand = hlo->mutable_operand(n);
- TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
- const auto& values = dataflow->GetValueSet(operand).values();
- if (std::any_of(values.begin(), values.end(), [](const HloValue* value) {
- return value->defining_instruction()->opcode() ==
- HloOpcode::kConstant;
- })) {
- TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand));
- TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy));
- changed = true;
- }
- return Status::OK();
- };
-
- if (IsCustomCallToDnnBatchNorm(*hlo)) {
- // The epsilon and feature_index operands to a CUDNN batchnorm op don't
- // need to be materialized in memory -- in fact, they must be constants.
- // These are the last two operands of all three batchnorm ops.
- for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
- TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
- }
- } else if (ImplementedAsLibraryCall(*hlo) ||
- hlo->opcode() == HloOpcode::kCrossReplicaSum) {
- // For all other library calls and cross-replica-sum, materialize all the
- // operands into memory. (Cross-replica-sum gets its constant args
- // materialized even if it's not implemented as a libcall to simplify the
- // implementation. It's slower, but we can constant fold away constant
- // args *anyway*, so we just need to make it work.)
- for (int64 i = 0; i < hlo->operand_count(); ++i) {
- TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
- }
- }
- }
-
- // Init values of while and conditional nodes cannot be constants. Insert
- // copies for any constants found at the operands of these nodes.
+ // in IR. Also, init values of while and conditional nodes cannot be
+ // constants. Insert copies for any constants found at the operands of these
+ // nodes.
tensorflow::gtl::FlatSet<HloInstruction*> inserted_copies;
for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kWhile &&
- instruction->opcode() != HloOpcode::kConditional) {
- continue;
- }
- for (auto operand : instruction->operands()) {
+ for (HloInstruction* hlo : computation->instructions()) {
+ // Inserts a copy of hlo->operand(n) if it's a constant.
+ auto copy_operand_if_constant = [&](int64 n) -> Status {
+ HloInstruction* operand = hlo->mutable_operand(n);
// Skip the operands that have already been replaced with a copy in a
// previous iteration (which is possible when a constant is used as an
// operand in multiple places).
if (ContainsKey(inserted_copies, operand)) {
- continue;
+ return Status::OK();
}
for (auto& pair : dataflow->GetInstructionValueSet(operand)) {
const HloValueSet& value_set = pair.second;
@@ -121,6 +81,47 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
}
}
}
+ return Status::OK();
+ };
+
+ if (IsCustomCallToDnnBatchNorm(*hlo)) {
+ // The epsilon and feature_index operands to a CUDNN batchnorm op don't
+ // need to be materialized in memory -- in fact, they must be constants.
+ // These are the last two operands of all three batchnorm ops.
+ for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
+ TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
+ }
+ } else if (ImplementedAsLibraryCall(*hlo) ||
+ hlo->opcode() == HloOpcode::kCrossReplicaSum ||
+ hlo->opcode() == HloOpcode::kWhile ||
+ hlo->opcode() == HloOpcode::kConditional) {
+ // For all other library calls, cross-replica-sum, while and conditional
+ // ops materialize all the operands into memory. (Cross-replica-sum
+ // gets its constant args materialized even if it's not implemented as a
+ // libcall to simplify the implementation. It's slower, but we can
+ // constant fold away constant args *anyway*, so we just need to make it
+ // work.)
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
+ }
+ }
+ }
+ }
+
+ if (changed) {
+ // Check the assumption that the epsilon and feature_index constants of the
+ // CUDNN batchnorm op are not shared with other ops where we would replace
+ // them with a copy. These custom op calls are generated with the
+ // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them.
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* hlo : computation->instructions()) {
+ if (!IsCustomCallToDnnBatchNorm(*hlo)) {
+ continue;
+ }
+ for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count();
+ ++i) {
+ CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant);
+ }
}
}
}