author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-05-01 16:17:21 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-05-01 16:19:55 -0700
commit     210abebd3febdd2c44ab5021bcebf8f1f5d451c4 (patch)
tree       22810ab30037169b4df8fec07016a8d8f69b7a6e
parent     b25e6fe32cccd29ec4cb4014bbb45d62b75835b4 (diff)
[TF:XLA] Separate on-host and on-device shape and layout in HloModule.
Previously, only one layout was stored with an HLO module. This CL allows HLO passes to modify the on-device layouts without affecting the on-host layout (provided by the client).

PiperOrigin-RevId: 195014875
-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc                     | 36
-rw-r--r--  tensorflow/compiler/xla/client/local_client.h                      |  9
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc                |  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.cc              |  5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h        |  3
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc  |  4
-rw-r--r--  tensorflow/compiler/xla/service/executable.h                       |  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc                |  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h        |  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc  |  8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc                      | 17
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h                       | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.cc               | 17
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.h                | 43
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/compiler.cc            |  2
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc               | 15
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.h                |  4
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc          |  8
-rw-r--r--  tensorflow/compiler/xla/service/service.cc                         | 52
-rw-r--r--  tensorflow/compiler/xla/service/service.h                          |  3
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                                |  1
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h                      | 20
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.cc                 | 10
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc            |  2
24 files changed, 195 insertions, 91 deletions
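
As the commit message notes, the entry computation layout is now tracked separately for host and device. The sketch below is illustrative only: the accessor names (host_entry_computation_layout, device_entry_computation_layout, and their mutable_ variants) come from this CL, while the surrounding pass function is hypothetical.

// Illustrative sketch of the new split (hypothetical pass, accessors from this CL).
Status RunSomeDeviceLayoutPass(HloModule* module) {
  // The on-host layout is owned by the client and must stay untouched.
  const ComputationLayout& host_layout =
      module->host_entry_computation_layout();
  // The on-device layout starts out as a copy of the host layout (see
  // HloModule::CreateModuleConfigFromProto below) and may be rewritten freely.
  ComputationLayout* device_layout =
      module->mutable_device_entry_computation_layout();
  device_layout->mutable_result_layout()->SetToDefaultLayout();
  CHECK_EQ(host_layout.parameter_count(), device_layout->parameter_count());
  return tensorflow::Status::OK();
}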
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 1c12705903..1acc6f8686 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -51,27 +51,49 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
tensorflow::Status LocalExecutable::ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
- const ComputationLayout& computation_layout =
- executable_->module_config().entry_computation_layout();
+ const ComputationLayout& host_computation_layout =
+ executable_->module_config().host_entry_computation_layout();
+ const ComputationLayout& device_computation_layout =
+ executable_->module_config().device_entry_computation_layout();
// Check argument number, shapes, and layouts.
- if (arguments.size() != computation_layout.parameter_count()) {
+ if (arguments.size() != host_computation_layout.parameter_count()) {
return InvalidArgument(
"invalid number of arguments for computation: expected %d, got %zu",
- computation_layout.parameter_count(), arguments.size());
+ host_computation_layout.parameter_count(), arguments.size());
+ }
+ if (arguments.size() != device_computation_layout.parameter_count()) {
+ return InvalidArgument(
+ "invalid number of arguments for computation: expected %d, got %zu",
+ device_computation_layout.parameter_count(), arguments.size());
}
for (int i = 0; i < arguments.size(); ++i) {
- if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape(
arguments[i]->on_host_shape())) {
return InvalidParameterArgument(
executable_.get(), i,
- "Argument does not match shape or layout of computation parameter "
+ "Argument does not match host shape or layout of computation "
+ "parameter "
"%d: want %s, got %s",
i,
- ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
+ ShapeUtil::HumanString(
+ host_computation_layout.parameter_layout(i).shape())
.c_str(),
ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
}
+ if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ arguments[i]->on_device_shape())) {
+ return InvalidParameterArgument(
+ executable_.get(), i,
+ "Argument does not match device shape or layout of computation "
+ "parameter "
+ "%d: want %s, got %s",
+ i,
+ ShapeUtil::HumanString(
+ device_computation_layout.parameter_layout(i).shape())
+ .c_str(),
+ ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str());
+ }
}
if (run_options.stream() != nullptr) {
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 4ce7059f7e..d8fd7a5623 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -43,15 +43,6 @@ class LocalExecutable {
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
ExecutableRunOptions run_options);
- // Return the layout (contained in a shape) of the result produced by the
- // computation.
- const Shape& result_layout() const {
- return executable_->module_config()
- .entry_computation_layout()
- .result_layout()
- .shape();
- }
-
// Return the options used to build the executable.
const ExecutableBuildOptions& build_options() const { return build_options_; }
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ec2bb6c762..d8ba289f29 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -294,7 +294,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout());
+ module->device_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index aabf4d5161..32613b8690 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -249,8 +249,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
std::vector<bool>* buffers_in_result) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
- /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(),
- run_options->allocator(), stream->parent()->device_ordinal());
+ /*on_host_shape=*/host_result_shape(),
+ /*on_device_shape=*/host_result_shape(), run_options->allocator(),
+ stream->parent()->device_ordinal());
// Copy DeviceMemoryBase values which contain the array(s) of the result into
// the respective location in ShapedBuffer which is returned to the caller.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index c8edbb9e15..09adb5cb02 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -27,7 +27,8 @@ namespace cpu {
// layout constraints for operands and results of library calls.
class CpuLayoutAssignment : public LayoutAssignment {
public:
- explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+ explicit CpuLayoutAssignment(
+ const ComputationLayout& entry_computation_layout)
: LayoutAssignment(entry_computation_layout) {}
~CpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 6ba030fff3..ba4c5a23d3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -49,7 +49,7 @@ class CpuLayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout) {
- cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout);
+ cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -311,7 +311,7 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
result.addend_fusion_param = fusion_instruction->operand(
fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());
- cpu::CpuLayoutAssignment layout_assignment(&computation_layout);
+ cpu::CpuLayoutAssignment layout_assignment(computation_layout);
TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
layout_assignment.Run(module));
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 99762f4586..4f0466c544 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -140,8 +140,8 @@ class Executable {
// The shape (including layout) that results from this execution. This is the
// shape of the DeviceMemoryBase result value in ExecuteOnStream above.
- const Shape& result_shape() const {
- return hlo_module_->config().entry_computation_layout().result_shape();
+ const Shape& host_result_shape() const {
+ return hlo_module_->config().host_entry_computation_layout().result_shape();
}
// TODO(b/74197823): Delete the session module dumping helpers.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 796c3070f2..4fdc4c8961 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -248,7 +248,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
{
HloPassPipeline pipeline("layout_assignment");
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_entry_computation_layout());
+ hlo_module->device_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 86a3a7111f..51aae79c3d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -27,7 +27,8 @@ namespace gpu {
// layout constraints for operands and results of library calls.
class GpuLayoutAssignment : public LayoutAssignment {
public:
- explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+ explicit GpuLayoutAssignment(
+ const ComputationLayout& entry_computation_layout)
: LayoutAssignment(entry_computation_layout) {}
~GpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 4c45d2e94a..7c80195594 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -69,7 +69,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
*computation_layout.mutable_result_layout() =
ShapeLayout(result_shape_with_layout);
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -156,7 +156,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
*computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -225,7 +225,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
{result_shape, offset_scale_shape, offset_scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -305,7 +305,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
{result_shape, scale_shape, scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index d4bad16f79..987c4b2719 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -55,7 +55,7 @@ HloComputation* HloModule::AddComputationInternal(
// If the module configuration has no entry layout computation set, create a
// default one based on the program shape.
- if (!config_.has_entry_computation_layout()) {
+ if (!config_.has_host_entry_computation_layout()) {
config_.SetDefaultComputationLayout(
entry_computation_->ComputeProgramShape());
}
@@ -229,11 +229,14 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
TF_RET_CHECK(proto.has_program_shape())
<< "No program shape found in the proto";
const auto& expected_program_shape = proto.program_shape();
- TF_RET_CHECK(expected_program_shape.parameters_size() ==
- module_config.entry_computation_layout().parameter_count());
+ TF_RET_CHECK(
+ expected_program_shape.parameters_size() ==
+ module_config.device_entry_computation_layout().parameter_count());
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
const Shape& parameter_shape =
- module_config.entry_computation_layout().parameter_layout(i).shape();
+ module_config.device_entry_computation_layout()
+ .parameter_layout(i)
+ .shape();
TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
parameter_shape))
<< "HloModuleConfig has different shape for parameter " << i
@@ -243,7 +246,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
<< ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
}
const Shape& result_shape =
- module_config.entry_computation_layout().result_layout().shape();
+ module_config.device_entry_computation_layout().result_layout().shape();
TF_RET_CHECK(
ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
<< "HloModuleConfig has different result shape than the HLO module. "
@@ -303,7 +306,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
// The module config is constructed with default layouts regardless of what is
// passed in via the ProgramShape. Set the layouts to the appropriate values.
ComputationLayout* entry_layout =
- module_config.mutable_entry_computation_layout();
+ module_config.mutable_host_entry_computation_layout();
for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
TF_RETURN_IF_ERROR(
entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
@@ -311,6 +314,8 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
}
TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
program_shape.result()));
+ *module_config.mutable_device_entry_computation_layout() =
+ module_config.host_entry_computation_layout();
return module_config;
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index aa843ead51..82d790ec3b 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -98,12 +98,20 @@ class HloModule {
return entry_computation_;
}
- ComputationLayout* mutable_entry_computation_layout() {
- return config_.mutable_entry_computation_layout();
+ ComputationLayout* mutable_host_entry_computation_layout() {
+ return config_.mutable_host_entry_computation_layout();
}
- const ComputationLayout& entry_computation_layout() const {
- return config_.entry_computation_layout();
+ const ComputationLayout& host_entry_computation_layout() const {
+ return config_.host_entry_computation_layout();
+ }
+
+ ComputationLayout* mutable_device_entry_computation_layout() {
+ return config_.mutable_device_entry_computation_layout();
+ }
+
+ const ComputationLayout& device_entry_computation_layout() const {
+ return config_.device_entry_computation_layout();
}
const VersionedComputationHandle& entry_computation_handle() const {
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 4205b0402c..dae5578a31 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -31,11 +31,13 @@ using tensorflow::strings::StrAppend;
HloModuleConfig::HloModuleConfig() {}
HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
- : entry_computation_layout_(program_shape) {}
+ : host_entry_computation_layout_(program_shape),
+ device_entry_computation_layout_(program_shape) {}
void HloModuleConfig::SetDefaultComputationLayout(
const ProgramShape& program_shape) {
- entry_computation_layout_ = ComputationLayout(program_shape);
+ host_entry_computation_layout_ = ComputationLayout(program_shape);
+ device_entry_computation_layout_ = ComputationLayout(program_shape);
}
string HloModuleConfig::compilation_cache_key() const {
@@ -44,11 +46,18 @@ string HloModuleConfig::compilation_cache_key() const {
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
- entry_computation_layout_->parameter_layouts()) {
+ host_entry_computation_layout_->parameter_layouts()) {
params.push_back(param_layout.shape().DebugString());
}
StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
- entry_computation_layout_->result_shape().SerializeAsString());
+ host_entry_computation_layout_->result_shape().SerializeAsString());
+ for (const ShapeLayout& param_layout :
+ device_entry_computation_layout_->parameter_layouts()) {
+ params.push_back(param_layout.shape().DebugString());
+ }
+ StrAppend(
+ &key, tensorflow::str_util::Join(params, ", "), ") => ",
+ device_entry_computation_layout_->result_shape().SerializeAsString());
if (seed() != 0) {
// TODO(b/32083678): force recompilation to reset global state.
static std::atomic<int> counter{0};
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 586a03d412..cdb0b29a23 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -41,26 +41,44 @@ class HloModuleConfig {
explicit HloModuleConfig(const ProgramShape& program_shape);
// Checks if this config has an entry computation layout already.
- bool has_entry_computation_layout() const {
- return entry_computation_layout_.has_value();
+ bool has_host_entry_computation_layout() const {
+ return host_entry_computation_layout_.has_value();
+ }
+
+ bool has_device_entry_computation_layout() const {
+ return device_entry_computation_layout_.has_value();
}
// Sets the entry computation layout for this config. If the entry computation
// layout already exists, it is silently replaced.
void SetDefaultComputationLayout(const ProgramShape& program_shape);
- // Returns a constant reference to the layout of the entry computation.
+ // Returns a constant reference to the on-host layout of the entry
+ // computation. Assumes the layout was set.
+ const ComputationLayout& host_entry_computation_layout() const {
+ CHECK(host_entry_computation_layout_.has_value());
+ return *host_entry_computation_layout_;
+ }
+
+ // Returns a mutable pointer to the layout of the on-host entry computation.
// Assumes the layout was set.
- const ComputationLayout& entry_computation_layout() const {
- CHECK(entry_computation_layout_.has_value());
- return *entry_computation_layout_;
+ ComputationLayout* mutable_host_entry_computation_layout() {
+ CHECK(host_entry_computation_layout_.has_value());
+ return &(*host_entry_computation_layout_);
}
- // Returns a mutable pointer to the layout of the entry computation. Assumes
- // the layout was set.
- ComputationLayout* mutable_entry_computation_layout() {
- CHECK(entry_computation_layout_.has_value());
- return &(*entry_computation_layout_);
+ // Returns a constant reference to the on-device layout of the entry
+ // computation. Assumes the layout was set.
+ const ComputationLayout& device_entry_computation_layout() const {
+ CHECK(device_entry_computation_layout_.has_value());
+ return *device_entry_computation_layout_;
+ }
+
+ // Returns a mutable pointer to the layout of the on-device entry computation.
+ // Assumes the layout was set.
+ ComputationLayout* mutable_device_entry_computation_layout() {
+ CHECK(device_entry_computation_layout_.has_value());
+ return &(*device_entry_computation_layout_);
}
// Returns whether to enable HLO-level profiling.
@@ -109,7 +127,8 @@ class HloModuleConfig {
private:
// If you add new members, be sure to update compilation_cache_key.
- tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
+ tensorflow::gtl::optional<ComputationLayout> host_entry_computation_layout_;
+ tensorflow::gtl::optional<ComputationLayout> device_entry_computation_layout_;
// Whether this is a 'host module'.
bool is_host_module_ = false;
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 76b3ecad26..eecbbcb93d 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -45,7 +45,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<LayoutAssignment>(
- hlo_module->mutable_entry_computation_layout());
+ hlo_module->device_entry_computation_layout());
return pipeline.Run(hlo_module).status();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 2494569db5..cfa7ba5e81 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -909,22 +909,19 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
}
LayoutAssignment::LayoutAssignment(
- ComputationLayout* entry_computation_layout,
+ const ComputationLayout& entry_computation_layout,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
VLOG(1) << "entry computation layout given to layout assignment: "
- << entry_computation_layout_->ToString();
+ << entry_computation_layout_.ToString();
// Layouts of all parameter instructions must be set.
for (const ShapeLayout& parameter_layout :
- entry_computation_layout_->parameter_layouts()) {
+ entry_computation_layout_.parameter_layouts()) {
CHECK(parameter_layout.LayoutIsSet());
}
- // If the result layout is not set, then choose the default.
- // TODO(b/29118294): Choose a better layout in this case.
- if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
- entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
- }
+ // TODO(b/29118294): Choose a better layout if the result layout is not set.
+ CHECK(entry_computation_layout_.result_layout().LayoutIsSet());
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1597,7 +1594,7 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
}
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(
- *entry_computation_layout_, *points_to_analysis,
+ entry_computation_layout_, *points_to_analysis,
module->entry_computation(), channel_layout_constraints_));
} else {
ComputationLayout computation_layout(computation->ComputeProgramShape());
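
With this change LayoutAssignment takes the entry computation layout by const reference and CHECKs that the result layout is already set, rather than defaulting it inside the constructor. Callers therefore have to populate the layout up front; a minimal sketch of a hypothetical caller, using only helpers that appear elsewhere in this CL:

// Hypothetical caller: fully set the layout before constructing the pass,
// since LayoutAssignment no longer fills in a default result layout.
ComputationLayout entry_layout(
    module->entry_computation()->ComputeProgramShape());
if (!entry_layout.result_layout().LayoutIsSet()) {
  entry_layout.mutable_result_layout()->SetToDefaultLayout();
}
LayoutAssignment layout_assignment(entry_layout);
TF_RETURN_IF_ERROR(layout_assignment.Run(module).status());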
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index ae4986d6ad..c83ae0388b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -288,7 +288,7 @@ class LayoutAssignment : public HloPassInterface {
// If channel_constraints is nullptr, no kSend or kRecvs must be contained
// within any module passed to `Run`.
explicit LayoutAssignment(
- ComputationLayout* entry_computation_layout,
+ const ComputationLayout& entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment() override {}
tensorflow::StringPiece name() const override { return "layout-assignment"; }
@@ -402,7 +402,7 @@ class LayoutAssignment : public HloPassInterface {
// necessary conditions.
Status CheckLayouts(HloModule* module);
- ComputationLayout* entry_computation_layout_;
+ const ComputationLayout& entry_computation_layout_;
protected:
// Sets up the copy instruction according to the characteristic (sharding,
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 4b1c9bad41..7e1bb11eaa 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -53,7 +53,7 @@ class LayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout) {
- LayoutAssignment layout_assignment(entry_computation_layout);
+ LayoutAssignment layout_assignment(*entry_computation_layout);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -285,7 +285,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- LayoutAssignment layout_assignment(&computation_layout);
+ LayoutAssignment layout_assignment(computation_layout);
AssignLayouts(module.get(), &computation_layout);
// Layout assignment should have deep copied the result of the computation to
@@ -488,7 +488,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
public:
explicit OperandsMustBeTheSameLayoutAssignment(
ComputationLayout* entry_computation_layout)
- : LayoutAssignment(entry_computation_layout) {}
+ : LayoutAssignment(*entry_computation_layout) {}
protected:
Status PropagateBufferConstraint(
@@ -808,7 +808,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- LayoutAssignment layout_assignment(&computation_layout);
+ LayoutAssignment layout_assignment(computation_layout);
Status error_status = layout_assignment.Run(module.get()).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 175ee96bbc..6ce03ab39d 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -296,8 +296,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ExecutionOptions* execution_options,
const UserComputation* user_computation) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
- auto* computation_layout = config->mutable_entry_computation_layout();
-
+ ComputationLayout* host_computation_layout =
+ config->mutable_host_entry_computation_layout();
+ ComputationLayout* device_computation_layout =
+ config->mutable_device_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
return InvalidArgument("computation takes %d parameters, but %zu given",
program_shape.parameters_size(),
@@ -322,9 +324,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
- TF_RETURN_IF_ERROR(
- computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
- *argument_shapes[i]));
+ TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i)
+ ->CopyLayoutFromShape(*argument_shapes[i]));
+ TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i)
+ ->CopyLayoutFromShape(*argument_shapes[i]));
}
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
@@ -333,10 +336,17 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
program_shape.result()));
TF_RETURN_IF_ERROR(
- computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ host_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ shape_with_output_layout));
+ TF_RETURN_IF_ERROR(
+ device_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
- computation_layout->mutable_result_layout()->Clear();
+ // If the result layout is not set, then choose the default.
+ // TODO(b/29118294): Allow the compiler to choose a better layout in this
+ // case.
+ host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
config->set_replica_count(options_.number_of_replicas());
@@ -488,6 +498,22 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
return std::move(executables);
}
+Status Service::ValidateEntryComputationLayout(HloModule* module) {
+ const ComputationLayout& on_device =
+ module->device_entry_computation_layout();
+ for (int64 i = 0; i < on_device.parameter_count(); ++i) {
+ TF_RET_CHECK(ShapeUtil::Equal(
+ on_device.parameter_shape(i),
+ execute_backend_->transfer_manager()->HostShapeToDeviceShape(
+ module->host_entry_computation_layout().parameter_shape(i))));
+ }
+ TF_RET_CHECK(ShapeUtil::Equal(
+ module->device_entry_computation_layout().result_shape(),
+ execute_backend_->transfer_manager()->HostShapeToDeviceShape(
+ module->host_entry_computation_layout().result_shape())));
+ return tensorflow::Status::OK();
+}
+
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const VersionedComputationHandle& versioned_handle,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@@ -526,6 +552,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
+ // Check that on-host and on-device shapes are consistent.
+ TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
@@ -889,7 +917,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
CreateModuleConfig(*program_shape, replicated_arguments.front(),
request.execution_options(), user_computation));
VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
// Adds to the vectors to build and execute the computations after the loop.
all_arguments.push_back(replicated_arguments);
@@ -992,7 +1020,7 @@ tensorflow::Status Service::ExecuteGraphParallel(
/*user_computation=*/nullptr));
VLOG(3)
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
// Adds to the vectors to build and execute the computations after the loop.
all_arguments.push_back(replicated_arguments);
@@ -1142,7 +1170,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
arg->execution_options(), user_computation));
VLOG(3) << "Execute created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
TF_ASSIGN_OR_RETURN(
std::shared_ptr<Executable> executable,
@@ -1212,6 +1240,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
+ // Check that on-host and on-device shapes are consistent.
+ TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
@@ -1313,7 +1343,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
arg->execution_options(), user_computation));
VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
ExecutionProfile profile;
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 476bd0597d..f84fe407e0 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -295,6 +295,9 @@ class Service : public ServiceInterface {
const ExecutionOptions& execution_options,
tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
+ // Assert that host- and device-shapes are in a consistent state.
+ Status ValidateEntryComputationLayout(HloModule* module);
+
protected:
friend class LocalExecutable;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 840292010d..54cf0543b8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -632,6 +632,7 @@ xla_test(
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 6491208895..9539ae0680 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -177,9 +177,13 @@ class HloTestBase : public ::testing::Test {
// 'layout'.
void ForceParameterLayout(HloModule* module, int64 param_no,
const Layout& layout) {
- ASSERT_LT(param_no,
- module->mutable_entry_computation_layout()->parameter_count());
- module->mutable_entry_computation_layout()
+ ASSERT_LT(
+ param_no,
+ module->mutable_host_entry_computation_layout()->parameter_count());
+ module->mutable_host_entry_computation_layout()
+ ->mutable_parameter_layout(param_no)
+ ->ResetLayout(layout);
+ module->mutable_device_entry_computation_layout()
->mutable_parameter_layout(param_no)
->ResetLayout(layout);
}
@@ -187,7 +191,10 @@ class HloTestBase : public ::testing::Test {
// Convenience method to force the layout of the computation result in a
// module. The result layout of 'module' is set to 'layout'.
void ForceResultLayout(HloModule* module, const Layout& layout) {
- module->mutable_entry_computation_layout()
+ module->mutable_host_entry_computation_layout()
+ ->mutable_result_layout()
+ ->ResetLayout(layout);
+ module->mutable_device_entry_computation_layout()
->mutable_result_layout()
->ResetLayout(layout);
}
@@ -195,7 +202,10 @@ class HloTestBase : public ::testing::Test {
// Convenience method to clear the layout of the computation result in
// 'module'.
void ForceClearResultLayout(HloModule* module) {
- module->mutable_entry_computation_layout()
+ module->mutable_host_entry_computation_layout()
+ ->mutable_result_layout()
+ ->Clear();
+ module->mutable_device_entry_computation_layout()
->mutable_result_layout()
->Clear();
}
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index fdbfc0210e..1bb31ddb7b 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -303,12 +303,18 @@ bool HloParser::ParseComputations() {
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
+ ->mutable_parameter_layout(p)
+ ->CopyLayoutFromShape(param_shape));
+ TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
+ ->mutable_result_layout()
+ ->CopyLayoutFromShape(result_shape));
+ TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index adc8b1d620..4e085bc89c 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -1239,7 +1239,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
auto module = Parse(original);
TF_ASSERT_OK(module.status());
- auto program_layout = module.ValueOrDie()->entry_computation_layout();
+ auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
ASSERT_EQ(program_layout.parameter_count(), 1);
auto param_layout = program_layout.parameter_layout(0).layout();
auto result_layout = program_layout.result_layout().layout();