From 210abebd3febdd2c44ab5021bcebf8f1f5d451c4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 1 May 2018 16:17:21 -0700
Subject: [TF:XLA] Separate on-host and on-device shape and layout in HloModule.

Previously, only one layout was stored with an HLO module. This CL allows HLO
passes to modify the on-device layouts without affecting the on-host layout
(provided by the client).

PiperOrigin-RevId: 195014875
---
 tensorflow/compiler/xla/client/local_client.cc      | 36 ++++++++++++---
 tensorflow/compiler/xla/client/local_client.h       |  9 ----
 .../compiler/xla/service/cpu/cpu_compiler.cc        |  2 +-
 .../compiler/xla/service/cpu/cpu_executable.cc      |  5 ++-
 .../xla/service/cpu/cpu_layout_assignment.h         |  3 +-
 .../xla/service/cpu/cpu_layout_assignment_test.cc   |  4 +-
 tensorflow/compiler/xla/service/executable.h        |  4 +-
 .../compiler/xla/service/gpu/gpu_compiler.cc        |  2 +-
 .../xla/service/gpu/gpu_layout_assignment.h         |  3 +-
 .../xla/service/gpu/gpu_layout_assignment_test.cc   |  8 ++--
 tensorflow/compiler/xla/service/hlo_module.cc       | 17 ++++---
 tensorflow/compiler/xla/service/hlo_module.h        | 16 +++++--
 .../compiler/xla/service/hlo_module_config.cc       | 17 +++++--
 .../compiler/xla/service/hlo_module_config.h        | 43 +++++++++++++-----
 .../compiler/xla/service/interpreter/compiler.cc    |  2 +-
 .../compiler/xla/service/layout_assignment.cc       | 15 +++----
 .../compiler/xla/service/layout_assignment.h        |  4 +-
 .../compiler/xla/service/layout_assignment_test.cc  |  8 ++--
 tensorflow/compiler/xla/service/service.cc          | 52 +++++++++++++++++-----
 tensorflow/compiler/xla/service/service.h           |  3 ++
 tensorflow/compiler/xla/tests/BUILD                 |  1 +
 tensorflow/compiler/xla/tests/hlo_test_base.h       | 20 ++++++---
 tensorflow/compiler/xla/tools/parser/hlo_parser.cc  | 10 ++++-
 .../compiler/xla/tools/parser/hlo_parser_test.cc    |  2 +-
 24 files changed, 195 insertions(+), 91 deletions(-)

diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 1c12705903..1acc6f8686 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -51,27 +51,49 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
 tensorflow::Status LocalExecutable::ValidateExecutionOptions(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
     const ExecutableRunOptions& run_options, const Backend& backend) {
-  const ComputationLayout& computation_layout =
-      executable_->module_config().entry_computation_layout();
+  const ComputationLayout& host_computation_layout =
+      executable_->module_config().host_entry_computation_layout();
+  const ComputationLayout& device_computation_layout =
+      executable_->module_config().device_entry_computation_layout();
 
   // Check argument number, shapes, and layouts.
-  if (arguments.size() != computation_layout.parameter_count()) {
+  if (arguments.size() != host_computation_layout.parameter_count()) {
     return InvalidArgument(
         "invalid number of arguments for computation: expected %d, got %zu",
-        computation_layout.parameter_count(), arguments.size());
+        host_computation_layout.parameter_count(), arguments.size());
+  }
+  if (arguments.size() != device_computation_layout.parameter_count()) {
+    return InvalidArgument(
+        "invalid number of arguments for computation: expected %d, got %zu",
+        device_computation_layout.parameter_count(), arguments.size());
   }
   for (int i = 0; i < arguments.size(); ++i) {
-    if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
+    if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape(
             arguments[i]->on_host_shape())) {
       return InvalidParameterArgument(
           executable_.get(), i,
-          "Argument does not match shape or layout of computation parameter "
+          "Argument does not match host shape or layout of computation "
+          "parameter "
           "%d: want %s, got %s",
           i,
-          ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
+          ShapeUtil::HumanString(
+              host_computation_layout.parameter_layout(i).shape())
              .c_str(),
          ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
     }
+    if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape(
+            arguments[i]->on_device_shape())) {
+      return InvalidParameterArgument(
+          executable_.get(), i,
+          "Argument does not match device shape or layout of computation "
+          "parameter "
+          "%d: want %s, got %s",
+          i,
+          ShapeUtil::HumanString(
+              device_computation_layout.parameter_layout(i).shape())
+              .c_str(),
+          ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str());
+    }
   }
 
   if (run_options.stream() != nullptr) {
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 4ce7059f7e..d8fd7a5623 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -43,15 +43,6 @@ class LocalExecutable {
       const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
       ExecutableRunOptions run_options);
 
-  // Return the layout (contained in a shape) of the result produced by the
-  // computation.
-  const Shape& result_layout() const {
-    return executable_->module_config()
-        .entry_computation_layout()
-        .result_layout()
-        .shape();
-  }
-
   // Return the options used to build the executable.
   const ExecutableBuildOptions& build_options() const { return build_options_; }
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ec2bb6c762..d8ba289f29 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -294,7 +294,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
       ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
 
   pipeline.AddPass<CpuLayoutAssignment>(
-      module->mutable_entry_computation_layout());
+      module->device_entry_computation_layout());
   // The LayoutAssignment pass may leave behind kCopy instructions which are
   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index aabf4d5161..32613b8690 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -249,8 +249,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
     std::vector<se::DeviceMemoryBase>* buffers_in_result) {
   se::Stream* stream = run_options->stream();
   ScopedShapedBuffer result_buffer(
-      /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(),
-      run_options->allocator(), stream->parent()->device_ordinal());
+      /*on_host_shape=*/host_result_shape(),
+      /*on_device_shape=*/host_result_shape(), run_options->allocator(),
+      stream->parent()->device_ordinal());
 
   // Copy DeviceMemoryBase values which contain the array(s) of the result into
   // the respective location in ShapedBuffer which is returned to the caller.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index c8edbb9e15..09adb5cb02 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -27,7 +27,8 @@ namespace cpu {
 // layout constraints for operands and results of library calls.
 class CpuLayoutAssignment : public LayoutAssignment {
  public:
-  explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+  explicit CpuLayoutAssignment(
+      const ComputationLayout& entry_computation_layout)
       : LayoutAssignment(entry_computation_layout) {}
   ~CpuLayoutAssignment() override {}
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 6ba030fff3..ba4c5a23d3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -49,7 +49,7 @@ class CpuLayoutAssignmentTest : public HloTestBase {
  protected:
   void AssignLayouts(HloModule* module,
                      ComputationLayout* entry_computation_layout) {
-    cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout);
+    cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout);
     EXPECT_IS_OK(layout_assignment.Run(module).status());
   }
 };
@@ -311,7 +311,7 @@ static StatusOr<DotOutputFusionResult> RunDotOutputFusion(
   result.addend_fusion_param = fusion_instruction->operand(
       fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());
 
-  cpu::CpuLayoutAssignment layout_assignment(&computation_layout);
+  cpu::CpuLayoutAssignment layout_assignment(computation_layout);
   TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                       layout_assignment.Run(module));
 
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 99762f4586..4f0466c544 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -140,8 +140,8 @@ class Executable {
 
   // The shape (including layout) that results from this execution. This is the
   // shape of the DeviceMemoryBase result value in ExecuteOnStream above.
-  const Shape& result_shape() const {
-    return hlo_module_->config().entry_computation_layout().result_shape();
+  const Shape& host_result_shape() const {
+    return hlo_module_->config().host_entry_computation_layout().result_shape();
   }
 
   // TODO(b/74197823): Delete the session module dumping helpers.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 796c3070f2..4fdc4c8961 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -248,7 +248,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
   {
     HloPassPipeline pipeline("layout_assignment");
     pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->mutable_entry_computation_layout());
+        hlo_module->device_entry_computation_layout());
 
     // The LayoutAssignment pass may leave behind kCopy instructions which are
     // duplicate or NOPs, so remove them with algebraic simplification and CSE.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 86a3a7111f..51aae79c3d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -27,7 +27,8 @@ namespace gpu {
 // layout constraints for operands and results of library calls.
 class GpuLayoutAssignment : public LayoutAssignment {
  public:
-  explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+  explicit GpuLayoutAssignment(
+      const ComputationLayout& entry_computation_layout)
       : LayoutAssignment(entry_computation_layout) {}
   ~GpuLayoutAssignment() override {}
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 4c45d2e94a..7c80195594 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -69,7 +69,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
         *computation_layout.mutable_result_layout() =
             ShapeLayout(result_shape_with_layout);
 
-        GpuLayoutAssignment layout_assignment(&computation_layout);
+        GpuLayoutAssignment layout_assignment(computation_layout);
         EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         for (const HloInstruction* operand : add->operands()) {
@@ -156,7 +156,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
      *computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
     }
 
-    GpuLayoutAssignment layout_assignment(&computation_layout);
+    GpuLayoutAssignment layout_assignment(computation_layout);
     EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
     // The first operand to batchnorm should have the same layout as the
@@ -225,7 +225,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
           {result_shape, offset_scale_shape, offset_scale_shape}));
     }
 
-    GpuLayoutAssignment layout_assignment(&computation_layout);
+    GpuLayoutAssignment layout_assignment(computation_layout);
     EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
     // The first operand to batchnorm should have the same layout as the
@@ -305,7 +305,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
           {result_shape, scale_shape, scale_shape}));
     }
 
-    GpuLayoutAssignment layout_assignment(&computation_layout);
+    GpuLayoutAssignment layout_assignment(computation_layout);
     EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
     // The first and fourth operands to the batchnorm call should have the
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index d4bad16f79..987c4b2719 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -55,7 +55,7 @@ HloComputation* HloModule::AddComputationInternal(
   // If the module configuration has no entry layout computation set, create a
   // default one based on the program shape.
-  if (!config_.has_entry_computation_layout()) {
+  if (!config_.has_host_entry_computation_layout()) {
     config_.SetDefaultComputationLayout(
         entry_computation_->ComputeProgramShape());
   }
@@ -229,11 +229,14 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
   TF_RET_CHECK(proto.has_program_shape())
       << "No program shape found in the proto";
   const auto& expected_program_shape = proto.program_shape();
-  TF_RET_CHECK(expected_program_shape.parameters_size() ==
-               module_config.entry_computation_layout().parameter_count());
+  TF_RET_CHECK(
+      expected_program_shape.parameters_size() ==
+      module_config.device_entry_computation_layout().parameter_count());
   for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
     const Shape& parameter_shape =
-        module_config.entry_computation_layout().parameter_layout(i).shape();
+        module_config.device_entry_computation_layout()
+            .parameter_layout(i)
+            .shape();
     TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                        parameter_shape))
         << "HloModuleConfig has different shape for parameter " << i
         << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
   }
   const Shape& result_shape =
-      module_config.entry_computation_layout().result_layout().shape();
+      module_config.device_entry_computation_layout().result_layout().shape();
   TF_RET_CHECK(
       ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
       << "HloModuleConfig has different result shape than the HLO module. "
@@ -303,7 +306,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
   // The module config is constructed with default layouts regardless of what is
   // passed in via the ProgramShape. Set the layouts to the appropriate values.
   ComputationLayout* entry_layout =
-      module_config.mutable_entry_computation_layout();
+      module_config.mutable_host_entry_computation_layout();
   for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
     TF_RETURN_IF_ERROR(
         entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
             program_shape.parameters(i)));
   }
   TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
       program_shape.result()));
+  *module_config.mutable_device_entry_computation_layout() =
+      module_config.host_entry_computation_layout();
 
   return module_config;
 }
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index aa843ead51..82d790ec3b 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -98,12 +98,20 @@ class HloModule {
     return entry_computation_;
   }
 
-  ComputationLayout* mutable_entry_computation_layout() {
-    return config_.mutable_entry_computation_layout();
+  ComputationLayout* mutable_host_entry_computation_layout() {
+    return config_.mutable_host_entry_computation_layout();
   }
 
-  const ComputationLayout& entry_computation_layout() const {
-    return config_.entry_computation_layout();
+  const ComputationLayout& host_entry_computation_layout() const {
+    return config_.host_entry_computation_layout();
+  }
+
+  ComputationLayout* mutable_device_entry_computation_layout() {
+    return config_.mutable_device_entry_computation_layout();
+  }
+
+  const ComputationLayout& device_entry_computation_layout() const {
+    return config_.device_entry_computation_layout();
   }
 
   const VersionedComputationHandle& entry_computation_handle() const {
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 4205b0402c..dae5578a31 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -31,11 +31,13 @@ using tensorflow::strings::StrAppend;
 HloModuleConfig::HloModuleConfig() {}
 
 HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
-    : entry_computation_layout_(program_shape) {}
+    : host_entry_computation_layout_(program_shape),
+      device_entry_computation_layout_(program_shape) {}
 
 void HloModuleConfig::SetDefaultComputationLayout(
     const ProgramShape& program_shape) {
-  entry_computation_layout_ = ComputationLayout(program_shape);
+  host_entry_computation_layout_ = ComputationLayout(program_shape);
+  device_entry_computation_layout_ = ComputationLayout(program_shape);
 }
 
 string HloModuleConfig::compilation_cache_key() const {
@@ -44,11 +46,18 @@ string HloModuleConfig::compilation_cache_key() const {
   StrAppend(&key, "::(");
   std::vector<string> params;
   for (const ShapeLayout& param_layout :
-       entry_computation_layout_->parameter_layouts()) {
+       host_entry_computation_layout_->parameter_layouts()) {
     params.push_back(param_layout.shape().DebugString());
   }
   StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
-            entry_computation_layout_->result_shape().SerializeAsString());
+            host_entry_computation_layout_->result_shape().SerializeAsString());
+  for (const ShapeLayout& param_layout :
+       device_entry_computation_layout_->parameter_layouts()) {
+    params.push_back(param_layout.shape().DebugString());
+  }
+  StrAppend(
+      &key, tensorflow::str_util::Join(params, ", "), ") => ",
+      device_entry_computation_layout_->result_shape().SerializeAsString());
   if (seed() != 0) {
     // TODO(b/32083678): force recompilation to reset global state.
     static std::atomic<int> counter{0};
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 586a03d412..cdb0b29a23 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -41,26 +41,44 @@ class HloModuleConfig {
   explicit HloModuleConfig(const ProgramShape& program_shape);
 
   // Checks if this config has an entry computation layout already.
-  bool has_entry_computation_layout() const {
-    return entry_computation_layout_.has_value();
+  bool has_host_entry_computation_layout() const {
+    return host_entry_computation_layout_.has_value();
+  }
+
+  bool has_device_entry_computation_layout() const {
+    return device_entry_computation_layout_.has_value();
   }
 
   // Sets the entry computation layout for this config. If the entry computation
   // layout already exists, it is silently replaced.
   void SetDefaultComputationLayout(const ProgramShape& program_shape);
 
-  // Returns a constant reference to the layout of the entry computation.
+  // Returns a constant reference to the on-host layout of the entry
+  // computation. Assumes the layout was set.
+  const ComputationLayout& host_entry_computation_layout() const {
+    CHECK(host_entry_computation_layout_.has_value());
+    return *host_entry_computation_layout_;
+  }
+
+  // Returns a mutable pointer to the layout of the on-host entry computation.
   // Assumes the layout was set.
-  const ComputationLayout& entry_computation_layout() const {
-    CHECK(entry_computation_layout_.has_value());
-    return *entry_computation_layout_;
+  ComputationLayout* mutable_host_entry_computation_layout() {
+    CHECK(host_entry_computation_layout_.has_value());
+    return &(*host_entry_computation_layout_);
   }
 
-  // Returns a mutable pointer to the layout of the entry computation. Assumes
-  // the layout was set.
-  ComputationLayout* mutable_entry_computation_layout() {
-    CHECK(entry_computation_layout_.has_value());
-    return &(*entry_computation_layout_);
+  // Returns a constant reference to the on-device layout of the entry
+  // computation. Assumes the layout was set.
+  const ComputationLayout& device_entry_computation_layout() const {
+    CHECK(device_entry_computation_layout_.has_value());
+    return *device_entry_computation_layout_;
+  }
+
+  // Returns a mutable pointer to the layout of the on-device entry computation.
+  // Assumes the layout was set.
+  ComputationLayout* mutable_device_entry_computation_layout() {
+    CHECK(device_entry_computation_layout_.has_value());
+    return &(*device_entry_computation_layout_);
   }
 
   // Returns whether to enable HLO-level profiling.
@@ -109,7 +127,8 @@ class HloModuleConfig {
 
  private:
   // If you add new members, be sure to update compilation_cache_key.
-  tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
+  tensorflow::gtl::optional<ComputationLayout> host_entry_computation_layout_;
+  tensorflow::gtl::optional<ComputationLayout> device_entry_computation_layout_;
 
   // Whether this is a 'host module'.
   bool is_host_module_ = false;
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 76b3ecad26..eecbbcb93d 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -45,7 +45,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
   HloPassPipeline pipeline("Interpreter");
 
   pipeline.AddPass<LayoutAssignment>(
-      hlo_module->mutable_entry_computation_layout());
+      hlo_module->device_entry_computation_layout());
   return pipeline.Run(hlo_module).status();
 }
 
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 2494569db5..cfa7ba5e81 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -909,22 +909,19 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
 }
 
 LayoutAssignment::LayoutAssignment(
-    ComputationLayout* entry_computation_layout,
+    const ComputationLayout& entry_computation_layout,
     ChannelLayoutConstraints* channel_constraints)
     : entry_computation_layout_(entry_computation_layout),
       channel_layout_constraints_(channel_constraints) {
   VLOG(1) << "entry computation layout given to layout assignment: "
-          << entry_computation_layout_->ToString();
+          << entry_computation_layout_.ToString();
   // Layouts of all parameter instructions must be set.
   for (const ShapeLayout& parameter_layout :
-       entry_computation_layout_->parameter_layouts()) {
+       entry_computation_layout_.parameter_layouts()) {
     CHECK(parameter_layout.LayoutIsSet());
   }
-  // If the result layout is not set, then choose the default.
-  // TODO(b/29118294): Choose a better layout in this case.
-  if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
-    entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
-  }
+  // TODO(b/29118294): Choose a better layout if the result layout is not set.
+  CHECK(entry_computation_layout_.result_layout().LayoutIsSet());
 }
 
 std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1597,7 +1594,7 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
     }
     if (computation == module->entry_computation()) {
       TF_RETURN_IF_ERROR(RunOnComputation(
-          *entry_computation_layout_, *points_to_analysis,
+          entry_computation_layout_, *points_to_analysis,
          module->entry_computation(), channel_layout_constraints_));
     } else {
       ComputationLayout computation_layout(computation->ComputeProgramShape());
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index ae4986d6ad..c83ae0388b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -288,7 +288,7 @@ class LayoutAssignment : public HloPassInterface {
   // If channel_constraints is nullptr, no kSend or kRecvs must be contained
   // within any module passed to `Run`.
   explicit LayoutAssignment(
-      ComputationLayout* entry_computation_layout,
+      const ComputationLayout& entry_computation_layout,
       ChannelLayoutConstraints* channel_constraints = nullptr);
   ~LayoutAssignment() override {}
   tensorflow::StringPiece name() const override { return "layout-assignment"; }
@@ -402,7 +402,7 @@ class LayoutAssignment : public HloPassInterface {
   // necessary conditions.
   Status CheckLayouts(HloModule* module);
 
-  ComputationLayout* entry_computation_layout_;
+  const ComputationLayout& entry_computation_layout_;
 
  protected:
   // Sets up the copy instruction according to the characteristic (sharding,
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 4b1c9bad41..7e1bb11eaa 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -53,7 +53,7 @@ class LayoutAssignmentTest : public HloTestBase {
  protected:
   void AssignLayouts(HloModule* module,
                      ComputationLayout* entry_computation_layout) {
-    LayoutAssignment layout_assignment(entry_computation_layout);
+    LayoutAssignment layout_assignment(*entry_computation_layout);
     EXPECT_IS_OK(layout_assignment.Run(module).status());
   }
 };
@@ -285,7 +285,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
   TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
       result_shape));
 
-  LayoutAssignment layout_assignment(&computation_layout);
+  LayoutAssignment layout_assignment(computation_layout);
   AssignLayouts(module.get(), &computation_layout);
 
   // Layout assignment should have deep copied the result of the computation to
@@ -488,7 +488,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
  public:
   explicit OperandsMustBeTheSameLayoutAssignment(
       ComputationLayout* entry_computation_layout)
-      : LayoutAssignment(entry_computation_layout) {}
+      : LayoutAssignment(*entry_computation_layout) {}
 
  protected:
   Status PropagateBufferConstraint(
@@ -808,7 +808,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape());
 
-  LayoutAssignment layout_assignment(&computation_layout);
+  LayoutAssignment layout_assignment(computation_layout);
   Status error_status = layout_assignment.Run(module.get()).status();
   EXPECT_FALSE(error_status.ok());
   EXPECT_THAT(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 175ee96bbc..6ce03ab39d 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -296,8 +296,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     const ExecutionOptions* execution_options,
     const UserComputation* user_computation) {
   auto config = MakeUnique<HloModuleConfig>(program_shape);
-  auto* computation_layout = config->mutable_entry_computation_layout();
-
+  ComputationLayout* host_computation_layout =
+      config->mutable_host_entry_computation_layout();
+  ComputationLayout* device_computation_layout =
+      config->mutable_device_entry_computation_layout();
   if (program_shape.parameters_size() != argument_shapes.size()) {
     return InvalidArgument("computation takes %d parameters, but %zu given",
                            program_shape.parameters_size(),
                            argument_shapes.size());
   }
           i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
           ShapeUtil::HumanString(*argument_shapes[i]).c_str());
     }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-            *argument_shapes[i]));
+    TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i)
+                           ->CopyLayoutFromShape(*argument_shapes[i]));
+    TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i)
+                           ->CopyLayoutFromShape(*argument_shapes[i]));
   }
   if (execution_options != nullptr &&
       execution_options->has_shape_with_output_layout()) {
@@ -333,10 +336,17 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
                                                      program_shape.result()));
     TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+        host_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+            shape_with_output_layout));
+    TF_RETURN_IF_ERROR(
+        device_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
             shape_with_output_layout));
   } else {
-    computation_layout->mutable_result_layout()->Clear();
+    // If the result layout is not set, then choose the default.
+    // TODO(b/29118294): Allow the compiler to choose a better layout in this
+    // case.
+    host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
+    device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
   }
 
   config->set_replica_count(options_.number_of_replicas());
@@ -488,6 +498,22 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
   return std::move(executables);
 }
 
+Status Service::ValidateEntryComputationLayout(HloModule* module) {
+  const ComputationLayout& on_device =
+      module->device_entry_computation_layout();
+  for (int64 i = 0; i < on_device.parameter_count(); ++i) {
+    TF_RET_CHECK(ShapeUtil::Equal(
+        on_device.parameter_shape(i),
+        execute_backend_->transfer_manager()->HostShapeToDeviceShape(
+            module->host_entry_computation_layout().parameter_shape(i))));
+  }
+  TF_RET_CHECK(ShapeUtil::Equal(
+      module->device_entry_computation_layout().result_shape(),
+      execute_backend_->transfer_manager()->HostShapeToDeviceShape(
+          module->host_entry_computation_layout().result_shape())));
+  return tensorflow::Status::OK();
+}
+
 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
     const VersionedComputationHandle& versioned_handle,
     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@@ -526,6 +552,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   TF_ASSIGN_OR_RETURN(
       module, backend->compiler()->RunHloPasses(std::move(module), executor,
                                                 device_allocator));
+  // Check that on-host and on-device shapes are consistent.
+  TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       backend->compiler()->RunBackend(
@@ -889,7 +917,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
         CreateModuleConfig(*program_shape, replicated_arguments.front(),
                            request.execution_options(), user_computation));
     VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
-            << module_config->entry_computation_layout().ToString();
+            << module_config->host_entry_computation_layout().ToString();
 
     // Adds to the vectors to build and execute the computations after the loop.
     all_arguments.push_back(replicated_arguments);
@@ -992,7 +1020,7 @@ tensorflow::Status Service::ExecuteGraphParallel(
                            /*user_computation=*/nullptr));
     VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: "
-            << module_config->entry_computation_layout().ToString();
+            << module_config->host_entry_computation_layout().ToString();
 
     // Adds to the vectors to build and execute the computations after the loop.
     all_arguments.push_back(replicated_arguments);
@@ -1142,7 +1170,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
                                         arg->execution_options(),
                                         user_computation));
 
   VLOG(3) << "Execute created HloModuleConfig computation layout: "
-          << module_config->entry_computation_layout().ToString();
+          << module_config->host_entry_computation_layout().ToString();
 
   TF_ASSIGN_OR_RETURN(
       std::shared_ptr<Executable> executable,
@@ -1212,6 +1240,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   TF_ASSIGN_OR_RETURN(
       module, backend->compiler()->RunHloPasses(std::move(module), executor,
                                                 device_allocator));
+  // Check that on-host and on-device shapes are consistent.
+  TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       backend->compiler()->RunBackend(
@@ -1313,7 +1343,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
                                         arg->execution_options(),
                                         user_computation));
 
   VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
-          << module_config->entry_computation_layout().ToString();
+          << module_config->host_entry_computation_layout().ToString();
 
   ExecutionProfile profile;
 
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 476bd0597d..f84fe407e0 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -295,6 +295,9 @@ class Service : public ServiceInterface {
       const ExecutionOptions& execution_options,
       tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
 
+  // Assert that host- and device-shapes are in a consistent state.
+  Status ValidateEntryComputationLayout(HloModule* module);
+
  protected:
   friend class LocalExecutable;
 
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 840292010d..54cf0543b8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -632,6 +632,7 @@ xla_test(
         "//tensorflow/compiler/xla/client/xla_client:xla_builder",
         "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 6491208895..9539ae0680 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -177,9 +177,13 @@ class HloTestBase : public ::testing::Test {
   // 'layout'.
   void ForceParameterLayout(HloModule* module, int64 param_no,
                             const Layout& layout) {
-    ASSERT_LT(param_no,
-              module->mutable_entry_computation_layout()->parameter_count());
-    module->mutable_entry_computation_layout()
+    ASSERT_LT(
+        param_no,
+        module->mutable_host_entry_computation_layout()->parameter_count());
+    module->mutable_host_entry_computation_layout()
+        ->mutable_parameter_layout(param_no)
+        ->ResetLayout(layout);
+    module->mutable_device_entry_computation_layout()
         ->mutable_parameter_layout(param_no)
         ->ResetLayout(layout);
   }
 
   // Convenience method to force the layout of the computation result in a
   // module. The result layout of 'module' is set to 'layout'.
   void ForceResultLayout(HloModule* module, const Layout& layout) {
-    module->mutable_entry_computation_layout()
+    module->mutable_host_entry_computation_layout()
+        ->mutable_result_layout()
+        ->ResetLayout(layout);
+    module->mutable_device_entry_computation_layout()
         ->mutable_result_layout()
         ->ResetLayout(layout);
   }
 
@@ -195,7 +202,10 @@ class HloTestBase : public ::testing::Test {
   // Convenience method to clear the layout of the computation result in
   // 'module'.
   void ForceClearResultLayout(HloModule* module) {
-    module->mutable_entry_computation_layout()
+    module->mutable_host_entry_computation_layout()
+        ->mutable_result_layout()
+        ->Clear();
+    module->mutable_device_entry_computation_layout()
         ->mutable_result_layout()
         ->Clear();
   }
 
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index fdbfc0210e..1bb31ddb7b 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -303,12 +303,18 @@ bool HloParser::ParseComputations() {
     // set the layouts to what the hlo text says.
     for (int p = 0; p < computation->num_parameters(); p++) {
       const Shape& param_shape = computation->parameter_instruction(p)->shape();
-      TF_CHECK_OK(module_->mutable_entry_computation_layout()
+      TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
+                      ->mutable_parameter_layout(p)
+                      ->CopyLayoutFromShape(param_shape));
+      TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
                       ->mutable_parameter_layout(p)
                       ->CopyLayoutFromShape(param_shape));
     }
     const Shape& result_shape = computation->root_instruction()->shape();
-    TF_CHECK_OK(module_->mutable_entry_computation_layout()
+    TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
+                    ->mutable_result_layout()
+                    ->CopyLayoutFromShape(result_shape));
+    TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
                     ->mutable_result_layout()
                     ->CopyLayoutFromShape(result_shape));
   }
 
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index adc8b1d620..4e085bc89c 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -1239,7 +1239,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
   auto module = Parse(original);
   TF_ASSERT_OK(module.status());
 
-  auto program_layout = module.ValueOrDie()->entry_computation_layout();
+  auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
   ASSERT_EQ(program_layout.parameter_count(), 1);
   auto param_layout = program_layout.parameter_layout(0).layout();
   auto result_layout = program_layout.result_layout().layout();
-- 
cgit v1.2.3
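
The core of the change, in miniature: HloModuleConfig now carries two ComputationLayouts that start out identical, and only the device-side one is handed to layout-assignment-style passes for mutation, so the client-visible host layout stays stable. The standalone sketch below illustrates that contract with simplified stand-in types; these are not the real XLA classes, and the layout strings are made-up examples.

#include <cassert>
#include <string>
#include <vector>

// Stand-in for xla::ComputationLayout: one layout per parameter, plus the
// result layout. Layouts are modeled as human-readable strings here.
struct ComputationLayout {
  std::vector<std::string> parameter_layouts;
  std::string result_layout;
};

// Stand-in for xla::HloModuleConfig after this patch: host and device
// layouts are stored separately and initialized to the same program shape.
class ModuleConfig {
 public:
  explicit ModuleConfig(const ComputationLayout& program_layout)
      : host_entry_computation_layout_(program_layout),
        device_entry_computation_layout_(program_layout) {}

  // The host layout (provided by the client) is only exposed read-only.
  const ComputationLayout& host_entry_computation_layout() const {
    return host_entry_computation_layout_;
  }

  // Passes may rewrite the device layout without touching the host layout.
  ComputationLayout* mutable_device_entry_computation_layout() {
    return &device_entry_computation_layout_;
  }

 private:
  ComputationLayout host_entry_computation_layout_;
  ComputationLayout device_entry_computation_layout_;
};

int main() {
  ComputationLayout client_layout{{"f32[2,3]{1,0}"}, "f32[3,2]{1,0}"};
  ModuleConfig config(client_layout);

  // A layout-assignment-style pass picks a device-friendly result layout.
  config.mutable_device_entry_computation_layout()->result_layout =
      "f32[3,2]{0,1}";

  // The client-visible host layout is unaffected by the pass.
  assert(config.host_entry_computation_layout().result_layout ==
         "f32[3,2]{1,0}");
  return 0;
}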
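Splitting the layouts also creates a new failure mode: the two copies can drift apart in ways the backend cannot bridge. That is what the new Service::ValidateEntryComputationLayout guards against after RunHloPasses: each on-device entry shape must equal what TransferManager::HostShapeToDeviceShape derives from the corresponding on-host shape. A minimal sketch of that invariant, with a hypothetical identity HostShapeToDeviceShape standing in for the real backend-specific mapping:

#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for TransferManager::HostShapeToDeviceShape(): maps
// a client-visible (host) shape to the shape the backend materializes.
// A backend that performs no shape rewriting uses the identity mapping.
std::string HostShapeToDeviceShape(const std::string& host_shape) {
  return host_shape;
}

// Mirrors the spirit of Service::ValidateEntryComputationLayout: every
// on-device parameter shape must match what the transfer manager derives
// from its on-host counterpart (the result shape is checked the same way).
bool EntryLayoutsConsistent(const std::vector<std::string>& host_shapes,
                            const std::vector<std::string>& device_shapes) {
  if (host_shapes.size() != device_shapes.size()) return false;
  for (std::size_t i = 0; i < host_shapes.size(); ++i) {
    if (device_shapes[i] != HostShapeToDeviceShape(host_shapes[i])) {
      return false;
    }
  }
  return true;
}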