author     A. Unique TensorFlower <gardener@tensorflow.org>    2018-01-25 15:01:30 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-01-25 15:05:41 -0800
commit     dfb59da4ede1daf163a167da590ac70c447eb41a (patch)
tree       20e99ae0b40ed8dd153f730823cdeb496c87fd56
parent     022890f6ac03bb87cc7b4f1a5b722cd6b058e616 (diff)
[XLA:GPU] Implement conditional as a sequence of thunks in the GPU backend.
This also includes the following fixes:

(1) Update buffer assignment for conditionals so that the buffers corresponding to the true operand and the true computation parameter are colocated, and similarly, the buffers corresponding to the false operand and the false computation parameter are colocated.

(2) Update GPU copy insertion pass to insert copies when constants appear as operands of conditional instructions.

PiperOrigin-RevId: 183297282
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc         37
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                      2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/conditional_thunk.cc      72
-rw-r--r--  tensorflow/compiler/xla/service/gpu/conditional_thunk.h       65
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc     43
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h       4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc             31
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h               6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc   120
-rw-r--r--  tensorflow/compiler/xla/service/gpu/thunk.h                    1
-rw-r--r--  tensorflow/compiler/xla/tests/conditional_test.cc             36
11 files changed, 334 insertions, 83 deletions
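
For orientation, here is a minimal hypothetical sketch (not part of this change) of a conditional built through the ComputationBuilder client API that conditional_test.cc below also uses. Operand 0 is the predicate, operand 1 feeds parameter 0 of the true computation, and operand 2 feeds parameter 0 of the false computation; fix (1) above colocates each operand buffer with the matching parameter buffer, and fix (2) inserts copies when, as here, those operands are constants. The fixture members (client_, error_spec_) and the CreateR0CeilComputation/CreateR0FloorComputation helpers are assumed to come from the ConditionalOpTest fixture, not new API.

  // Hypothetical sketch in the style of ConditionalOpTest.
  ComputationBuilder builder(client_, "conditional_with_constant_operands");
  auto pred = builder.ConstantR0<bool>(false);            // operand 0: predicate
  auto true_operand = builder.ConstantR0<float>(1.1f);    // operand 1 -> true computation parameter 0
  auto false_operand = builder.ConstantR0<float>(12.2f);  // operand 2 -> false computation parameter 0
  builder.Conditional(pred, true_operand, CreateR0CeilComputation(),
                      false_operand, CreateR0FloorComputation());
  // pred is false, so the false computation runs: floor(12.2f) == 12.0f.
  ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
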
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 323620c131..d5594dc07c 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1358,6 +1358,43 @@ void BufferAssigner::BuildColocatedBufferSets(
index, points_to_analysis, &colocated_set);
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
});
+
+ // Add true_operand and conditional.true_computation.parameter(0) as a
+ // colocated buffer set. Note that this has to be done for each subshape
+ // in the true_operand of the conditional.
+ ShapeUtil::ForEachSubshape(
+ conditional_hlo->operand(1)->shape(),
+ [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ std::vector<const LogicalBuffer*> true_set;
+ // Add conditional.true_operand.
+ AddBufferToColocatedSet(conditional_hlo->operand(1), index,
+ points_to_analysis, &true_set);
+ // Add conditional.true_computation.parameter_instruction(0).
+ AddBufferToColocatedSet(
+ conditional_hlo->true_computation()->parameter_instruction(0),
+ index, points_to_analysis, &true_set);
+ AddSetToColocatedBufferSets(true_set, colocated_buffer_sets);
+ });
+
+ // Add false_operand and conditional.false_computation.parameter(0) as a
+ // colocated buffer set. Note that this has to be done for each subshape
+ // in the false_operand of the conditional.
+ ShapeUtil::ForEachSubshape(
+ conditional_hlo->operand(2)->shape(),
+ [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ std::vector<const LogicalBuffer*> false_set;
+ // Add conditional.false_operand.
+ AddBufferToColocatedSet(conditional_hlo->operand(2), index,
+ points_to_analysis, &false_set);
+ // Add conditional.false_computation.parameter_instruction(0).
+ AddBufferToColocatedSet(
+ conditional_hlo->false_computation()->parameter_instruction(
+ 0),
+ index, points_to_analysis, &false_set);
+ AddSetToColocatedBufferSets(false_set, colocated_buffer_sets);
+ });
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index df5e2e35f8..3c3328b9cd 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -228,6 +228,7 @@ cc_library(
cc_library(
name = "gpu_executable",
srcs = [
+ "conditional_thunk.cc",
"convolution_thunk.cc",
"copy_thunk.cc",
"cudnn_batchnorm_thunk.cc",
@@ -243,6 +244,7 @@ cc_library(
"while_thunk.cc",
],
hdrs = [
+ "conditional_thunk.h",
"convolution_thunk.h",
"copy_thunk.h",
"cudnn_batchnorm_thunk.h",
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
new file mode 100644
index 0000000000..790ca535b1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+ConditionalThunk::ConditionalThunk(
+ const BufferAllocation::Slice& predicate_buffer_index,
+ const BufferAllocation::Slice& true_operand_buffer_index,
+ const BufferAllocation::Slice& false_operand_buffer_index,
+ ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kConditional, hlo),
+ predicate_buffer_index_(predicate_buffer_index),
+ true_operand_buffer_index_(true_operand_buffer_index),
+ false_operand_buffer_index_(false_operand_buffer_index),
+ true_thunk_(std::move(true_thunk_sequence), hlo),
+ false_thunk_(std::move(false_thunk_sequence), hlo) {}
+
+Status ConditionalThunk::Initialize(const GpuExecutable& executable) {
+ TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable));
+ TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable));
+ return Status::OK();
+}
+
+Status ConditionalThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) {
+ // Copy the predicate value from device.
+ bool predicate;
+ perftools::gputools::DeviceMemoryBase predicate_address =
+ buffer_allocations.GetDeviceAddress(predicate_buffer_index_);
+ stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool));
+
+ Status block_status = stream->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ return InternalError("Failed to retrieve predicate value on stream %p: %s.",
+ stream, block_status.error_message().c_str());
+ }
+
+ // Execute the true or the false computation depending on the value of the
+ // predicate.
+ if (predicate) {
+ TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ } else {
+ TF_RETURN_IF_ERROR(
+ false_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ }
+
+ return Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
new file mode 100644
index 0000000000..7725c46a3b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// ConditionalThunk implements the conditional instruction on GPU by reading the
+// predicate of the conditional and executing the true or the false computation
+// depending on the value of the predicate.
+//
+// ConditionalThunk assumes that the buffers of the conditional result and the
+// result of the true and false computations share the same allocation. Also,
+// the buffers of the true operand of the conditional and that of the parameter
+// instruction of the true computation share the same allocation. Similarly, the
+// buffers of the false operand and that of the parameter instruction of the
+// false computation share the same allocation.
+class ConditionalThunk : public Thunk {
+ public:
+ ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index,
+ const BufferAllocation::Slice& true_operand_buffer_index,
+ const BufferAllocation::Slice& false_operand_buffer_index,
+ ThunkSequence true_thunk_sequence,
+ ThunkSequence false_thunk_sequence,
+ const HloInstruction* hlo);
+
+ ConditionalThunk(const ConditionalThunk&) = delete;
+ ConditionalThunk& operator=(const ConditionalThunk&) = delete;
+
+ Status Initialize(const GpuExecutable& executable) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ BufferAllocation::Slice predicate_buffer_index_;
+ BufferAllocation::Slice true_operand_buffer_index_;
+ BufferAllocation::Slice false_operand_buffer_index_;
+ SequentialThunk true_thunk_;
+ SequentialThunk false_thunk_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index e67087d822..e3b493c663 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -36,7 +36,7 @@ namespace gpu {
StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
HloInstruction* hlo) {
- HloInstruction*& copy = inserted_copies_[hlo];
+ HloInstruction*& copy = hlo_to_copy_map_[hlo];
if (copy == nullptr) {
TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
}
@@ -86,27 +86,34 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
}
}
- // Init values of a while node cannot be constants. Insert copies for any
- // constants found at the operand of a while.
- tensorflow::gtl::FlatSet<HloInstruction*> copied_constants;
+ // Init values of while and conditional nodes cannot be constants. Insert
+ // copies for any constants found at the operands of these nodes.
+ tensorflow::gtl::FlatSet<HloInstruction*> inserted_copies;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kWhile) {
+ if (instruction->opcode() != HloOpcode::kWhile &&
+ instruction->opcode() != HloOpcode::kConditional) {
continue;
}
- for (auto& pair :
- dataflow->GetInstructionValueSet(instruction->operand(0))) {
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (value->defining_instruction()->opcode() ==
- HloOpcode::kConstant &&
- !ContainsKey(copied_constants, value->defining_instruction())) {
- HloInstruction* constant = value->defining_instruction();
- TF_ASSIGN_OR_RETURN(HloInstruction * copy,
- FindOrInsertCopy(constant));
- TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
- copied_constants.insert(constant);
- changed = true;
+ for (auto operand : instruction->operands()) {
+ // Skip the operands that have already been replaced with a copy in a
+ // previous iteration (which is possible when a constant is used as an
+ // operand in multiple places).
+ if (ContainsKey(inserted_copies, operand)) {
+ continue;
+ }
+ for (auto& pair : dataflow->GetInstructionValueSet(operand)) {
+ const HloValueSet& value_set = pair.second;
+ for (const HloValue* value : value_set.values()) {
+ if (value->defining_instruction()->IsConstant() &&
+ !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) {
+ HloInstruction* constant = value->defining_instruction();
+ TF_ASSIGN_OR_RETURN(HloInstruction * copy,
+ FindOrInsertCopy(constant));
+ TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
+ inserted_copies.insert(copy);
+ changed = true;
+ }
}
}
}
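
A minimal sketch, assuming a standalone HloModule* named module, of how this pass would be invoked; the helper name is hypothetical. GpuCopyInsertion derives from HloPassInterface (see gpu_copy_insertion.h below), so Run() returns a StatusOr<bool> indicating whether any copies were inserted.

  // Hypothetical driver, not part of this change: run GpuCopyInsertion on a module.
  Status RunGpuCopyInsertionForTest(HloModule* module) {
    GpuCopyInsertion copy_insertion;
    // Run() inserts copies for constants feeding while/conditional nodes (and for
    // the library-call operands tracked in hlo_to_copy_map_) and reports whether
    // the module changed.
    TF_ASSIGN_OR_RETURN(bool changed, copy_insertion.Run(module));
    VLOG(1) << "GpuCopyInsertion changed module: " << changed;
    return Status::OK();
  }
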
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 4d77f337e6..0c6f9b511f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -32,13 +32,13 @@ class GpuCopyInsertion : public HloPassInterface {
StatusOr<bool> Run(HloModule* module) override;
protected:
- // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
+ // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
// duplicate copies.
StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
// A map containing all copies inserted to materialize operands of library
// calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
+ tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 095c3df3bf..23b72c3f71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -758,37 +758,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-Status IrEmitter::HandleConditional(HloInstruction* conditional) {
- auto pred = conditional->operand(0);
- auto true_arg = conditional->operand(1);
- auto false_arg = conditional->operand(2);
-
- llvm::Value* conditional_result = GetBasePointer(*conditional);
-
- llvm::LoadInst* pred_value = ir_builder_.CreateLoad(
- GetBasePointer(*pred),
- llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value")));
- llvm::Value* pred_cond = ir_builder_.CreateICmpNE(
- pred_value,
- llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
- llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate")));
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- pred_cond, IrName(conditional, "if_then_else"), &ir_builder_);
-
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *conditional->true_computation(), {GetBasePointer(*true_arg)},
- conditional_result));
-
- SetToFirstInsertPoint(if_data.false_block, &ir_builder_);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *conditional->false_computation(), {GetBasePointer(*false_arg)},
- conditional_result));
-
- SetToFirstInsertPoint(if_data.after_block, &ir_builder_);
- return Status::OK();
-}
-
llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 39bafaa346..3aa178410f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -96,7 +96,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleRng(HloInstruction* random) override;
- Status HandleConditional(HloInstruction* conditional) override;
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
@@ -367,6 +366,11 @@ class IrEmitterUnnested : public IrEmitter {
std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
const int64 loop_limit);
+ // Returns a ConditionalThunk that executes the thunk sequence for
+ // 'true_computation' or 'false_computation' depending on the value of the
+ // predicate in the given conditional instruction.
+ std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
+
Status Postprocess(HloInstruction* hlo) override;
// Returns the last generated thunk.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index be35351e87..fc8783e753 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
@@ -272,8 +273,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
}
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
- thunk_sequence_->push_back(BuildKernelThunk(conditional));
- return IrEmitter::HandleConditional(conditional);
+ thunk_sequence_->emplace_back(BuildConditionalThunk(conditional));
+ return Status::OK();
}
Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
@@ -2102,6 +2103,24 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
namespace {
+// Checks that the buffers corresponding to the given two HLOs share the same
+// allocation.
+Status CheckHloBuffersShareAllocation(
+ const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
+ const BufferAssignment& buffer_assignment) {
+ const BufferAllocation::Slice slice_a =
+ buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
+ const BufferAllocation::Slice slice_b =
+ buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
+ if (slice_a != slice_b) {
+ return InternalError(
+ "instruction %s %s does not share allocation with instruction %s %s",
+ a->ToString().c_str(), slice_a.ToString().c_str(),
+ b->ToString().c_str(), slice_b.ToString().c_str());
+ }
+ return Status::OK();
+}
+
// Checks that all buffers used during while loop iteration share the same
// buffer allocation. This includes buffers for while result, while init
// operand, condition parameter, body parameter and body result.
@@ -2111,37 +2130,65 @@ Status CheckWhileBuffersShareAllocation(
const BufferAssignment& buffer_assignment) {
return ShapeUtil::ForEachSubshapeWithStatus(
xla_while->shape(),
- [&buffer_assignment, &xla_while](const Shape& /*subshape*/,
- const ShapeIndex& index) -> Status {
- auto check = [&buffer_assignment](const HloInstruction* a,
- const HloInstruction* b,
- const ShapeIndex& index) -> Status {
- const BufferAllocation::Slice slice_a =
- buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
- const BufferAllocation::Slice slice_b =
- buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
- if (slice_a != slice_b) {
- return InternalError(
- "instruction %s %s does not share allocation with "
- "instruction %s %s",
- a->ToString().c_str(), slice_a.ToString().c_str(),
- b->ToString().c_str(), slice_b.ToString().c_str());
- }
- return Status::OK();
- };
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
const HloInstruction* condition_parameter =
xla_while->while_condition()->parameter_instruction(0);
const HloComputation* body = xla_while->while_body();
const HloInstruction* body_parameter = body->parameter_instruction(0);
const HloInstruction* body_result = body->root_instruction();
- TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
- TF_RETURN_IF_ERROR(check(xla_while, condition_parameter, index));
- TF_RETURN_IF_ERROR(check(xla_while, body_parameter, index));
- TF_RETURN_IF_ERROR(check(xla_while, body_result, index));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, xla_while->operand(0), index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, condition_parameter, index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, body_parameter, index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, body_result, index, buffer_assignment));
return Status::OK();
});
}
+// Checks that the buffers used in a conditional instruction are shared with the
+// operands and result as follows:
+// * The result buffer of the conditional should share the allocation with the
+// result buffers of the true and false computations.
+// * The buffer of operand 1 should share the allocation with the buffer of
+// the parameter 0 instruction of the true computation.
+// * The buffer of operand 2 should share the allocation with the buffer of
+// the parameter 0 instruction of the false computation.
+Status CheckConditionalBuffersShareAllocation(
+ const HloInstruction* conditional,
+ const BufferAssignment& buffer_assignment) {
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ conditional, conditional->true_computation()->root_instruction(),
+ index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ conditional, conditional->false_computation()->root_instruction(),
+ index, buffer_assignment));
+ return Status::OK();
+ }));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->operand(1)->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ return CheckHloBuffersShareAllocation(
+ conditional->operand(1),
+ conditional->true_computation()->parameter_instruction(0), index,
+ buffer_assignment);
+ }));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->operand(2)->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ return CheckHloBuffersShareAllocation(
+ conditional->operand(2),
+ conditional->false_computation()->parameter_instruction(0), index,
+ buffer_assignment);
+ }));
+ return Status::OK();
+}
+
} // namespace
std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
@@ -2184,6 +2231,31 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_body.ConsumeThunkSequence(), hlo);
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
+ const HloInstruction* hlo) {
+ // Check that the buffers used in conditional are shared with the operands and
+ // result appropriately.
+ TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
+ hlo, ir_emitter_context_->buffer_assignment()));
+
+ HloComputation* true_computation = hlo->true_computation();
+ IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation,
+ ir_emitter_context_);
+ TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true));
+
+ HloComputation* false_computation = hlo->false_computation();
+ IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation,
+ ir_emitter_context_);
+ TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false));
+
+ return MakeUnique<ConditionalThunk>(
+ GetAllocationSlice(*hlo->operand(0)),
+ GetAllocationSlice(*hlo->operand(1)),
+ GetAllocationSlice(*hlo->operand(2)),
+ std::move(*ir_emitter_true.ConsumeThunkSequence()),
+ std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo);
+}
+
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 625c3f8bea..2c3032d79b 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -41,6 +41,7 @@ class GpuExecutable;
class Thunk {
public:
enum class Kind {
+ kConditional,
kConvolution,
kCopy,
kCudnnBatchNormBackward,
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 0016b6cc61..bc82167482 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -355,8 +355,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
}
// Test true and false computations that return a tuple of arrays.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
+XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
ComputationBuilder builder(client_, TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}),
@@ -373,9 +372,7 @@ XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
// Test true and false computations that return a tuple of a predicate, a
// scalar, and an array.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest,
- DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) {
+XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
ComputationBuilder true_builder(client_, TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
@@ -413,8 +410,7 @@ XLA_TEST_F(ConditionalOpTest,
}
// Test true and false computations that return a nested tuple.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) {
+XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputationBuilder true_builder(client_, TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
@@ -532,6 +528,32 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
+XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
+ ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
+ {
+ Shape r0bool = ShapeUtil::MakeShape(PRED, {});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
+ auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
+ auto pred_cond = inner_builder.GetTupleElement(param0, 0);
+ auto true_operand = inner_builder.GetTupleElement(param0, 1);
+ auto false_operand = inner_builder.GetTupleElement(param0, 2);
+ inner_builder.Conditional(pred_cond, true_operand,
+ CreateR0CeilComputation(), false_operand,
+ CreateR0FloorComputation());
+ }
+ auto inner_builder_result = inner_builder.Build();
+ EXPECT_IS_OK(inner_builder_result.status());
+
+ ComputationBuilder builder(client_, TestName());
+ auto pred2 = builder.ConstantR0<bool>(false);
+ auto operand1 = builder.ConstantR0<float>(1.1f);
+ auto operand2 = builder.ConstantR0<float>(12.2f);
+ auto tuple_operand = builder.Tuple({pred2, operand1, operand2});
+ builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand});
+
+ ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+}
+
// Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
ComputationBuilder builder(client_, TestName());