diff options
author | 2017-05-22 18:19:49 -0700 | |
---|---|---|
committer | 2017-05-22 18:23:35 -0700 | |
commit | e1b85cdd85471d7f6f19ce9eeeb4db616a5cc402 (patch) | |
tree | f1022d4fdeb6a29bdca3768c7f32894b6d1f0362 | |
parent | 7440479a1422157cfd330d7c2cab03d8e072bbfd (diff) |
[XLA] Teach Executable to do its own profiling (patch 2/4).
This follows patch 1 by removing the ShapeSizeBytes function from xla::Compiler, and changing its callers to use Executable::ShapeSizeBytesFunction.
Many of the callers are reproducing the same pattern of making a lambda that computes the size of a buffer, and so we condense this into xla::Compiler::BufferSizeBytesFunction.
PiperOrigin-RevId: 156814087
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment_test.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/compiler.h | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 35 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_compiler.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_compiler.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.cc | 5 |
8 files changed, 23 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 104491ecf6..96f8a21733 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -506,6 +506,7 @@ cc_library( ":executable", ":hlo", ":hlo_module_config", + ":logical_buffer", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index ac1d769010..9f922ad3e0 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -84,10 +84,7 @@ class BufferAssignmentTest : public HloTestBase { int64 alignment = 1) { return BufferAssigner::Run( module, MakeUnique<DependencyHloOrdering>(module), - [this](const LogicalBuffer& buffer) { - return backend_->compiler()->ShapeSizeBytes(buffer.shape()); - }, - alignment) + backend_->compiler()->BufferSizeBytesFunction(), alignment) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 551a5f091c..7ae285170e 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -148,13 +149,19 @@ class Compiler { static StatusOr<Compiler*> GetForPlatform( const perftools::gputools::Platform* platform); - // Returns the size in bytes of the top-level buffer of a shape. 
- virtual int64 ShapeSizeBytes(const Shape& shape) const = 0; - - // Returns a function that computes the size in bytes of the top-level - // buffer of a shape. + // Returns a function that computes the size in bytes of the logical + // buffer that contains a shape. virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0; + // Returns a function that computes the size in bytes of a given + // logical buffer. + std::function<int64(const LogicalBuffer&)> BufferSizeBytesFunction() { + HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); + return [shape_size](const LogicalBuffer& buffer) { + return shape_size(buffer.shape()); + }; + } + private: // Mutex that guards the platform-compiler map. static tensorflow::mutex* platform_compiler_mutex_; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 6abc9fd326..881b00f084 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -341,10 +341,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( std::unique_ptr<BufferAssignment> assignment, BufferAssigner::Run(module.get(), MakeUnique<DependencyHloOrdering>(module.get()), - [this](const LogicalBuffer& buffer) { - return ShapeSizeBytes(buffer.shape()); - }, - kMemoryAlignment)); + BufferSizeBytesFunction(), kMemoryAlignment)); // If we are using the parallel CPU backend, we need to create map from // HloInstruction to the corresponding generated function name. @@ -360,7 +357,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( // Copy the constant out of the ProtocolBuffer so that we can give it a // higher alignment. 
const void* data = LiteralUtil::InternalData(instruction->literal()); - int64 size = ShapeSizeBytes(instruction->shape()); + int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( instruction, MakeUnique<unsigned char[]>(size)); CHECK_EQ(iter.second, true); @@ -425,10 +422,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, - [this](const LogicalBuffer& buffer) { - return ShapeSizeBytes(buffer.shape()); - })); + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -437,10 +431,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( BufferAssigner::Run( module.get(), MakeUnique<SequentialHloOrdering>(module.get(), module_sequence), - [this](const LogicalBuffer& buffer) { - return ShapeSizeBytes(buffer.shape()); - }, - kMemoryAlignment)); + BufferSizeBytesFunction(), kMemoryAlignment)); // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from @@ -586,10 +577,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, - [this](const LogicalBuffer& buffer) { - return ShapeSizeBytes(buffer.shape()); - })); + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
@@ -597,10 +585,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, std::unique_ptr<BufferAssignment> assignment, BufferAssigner::Run( module, MakeUnique<SequentialHloOrdering>(module, module_sequence), - [this](const LogicalBuffer& buffer) { - return ShapeSizeBytes(buffer.shape()); - }, - kMemoryAlignment)); + BufferSizeBytesFunction(), kMemoryAlignment)); IrEmitter ir_emitter(*module, *assignment, &llvm_module, /*hlo_to_profile_idx=*/nullptr); @@ -663,14 +648,6 @@ se::Platform::Id CpuCompiler::PlatformId() const { return se::host::kHostPlatformId; } -int64 CpuCompiler::ShapeSizeBytes(const Shape& shape) const { - // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { - return sizeof(void*); - } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); -} - HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { return CpuExecutable::ShapeSizeBytes; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index d91365cd4d..f0dcab8c8a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -126,8 +126,6 @@ class CpuCompiler : public Compiler { perftools::gputools::Platform::Id PlatformId() const override; - int64 ShapeSizeBytes(const Shape& shape) const override; - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 430b45f8d5..693ec2ad45 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -268,10 +268,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile( TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> buffer_assignment, BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - [this](const LogicalBuffer& 
buffer) { - return ShapeSizeBytes(buffer.shape()); - }, - kMemoryAlignment)); + BufferSizeBytesFunction(), kMemoryAlignment)); IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), &stream_exec->GetDeviceDescription(), @@ -352,10 +349,6 @@ se::Platform::Id GpuCompiler::PlatformId() const { return se::cuda::kCudaPlatformId; } -int64 GpuCompiler::ShapeSizeBytes(const Shape& shape) const { - return ShapeUtil::ByteSizeOf(shape, pointer_size_); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 253c10182b..da52f5ab1f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -55,8 +55,6 @@ class GpuCompiler : public Compiler { perftools::gputools::Platform::Id PlatformId() const override; - int64 ShapeSizeBytes(const Shape& shape) const override; - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { // Capture just the pointer size, not the entire GpuCompiler object. int64 pointer_size = pointer_size_; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 16babac861..98080e3b57 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1168,9 +1168,8 @@ tensorflow::Status Service::GetComputationStats( MakeHloDumper()(*module, "computation statistics subject"); // Run HLO analysis to get the computation statistics. - HloCostAnalysis analysis([this](const Shape& shape) { - return execute_backend_->compiler()->ShapeSizeBytes(shape); - }); + HloCostAnalysis analysis( + execute_backend_->compiler()->ShapeSizeBytesFunction()); TF_RETURN_IF_ERROR( module->entry_computation()->root_instruction()->Accept(&analysis)); |