aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-02-10 15:28:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-10 15:49:40 -0800
commit56eb0007c04fe47f61323a3014c9a7fb176b0d70 (patch)
tree49e735249dab837db24d01f5c604ce4a8bf19874 /tensorflow/compiler
parent674fdc3d90031da055767a2fda5ac8bfd2b2feb5 (diff)
Add bytes accessed to HLO profile output. Bytes accessed is a measure of the bytes read/written from memory during execution of an HLO op. It is typically the sum of the sizes of the operands and output. Sample line from profile table:
337 cycles ( 47.87%) :: 0.5 usec @ f_nom :: 263.80MFLOP/s :: 12.0KiB :: %multiply = ... The 12.0KiB is the change. As part of this change unconditionally gather bytes accessed information with HloCostAnalysis. This requires that the shape size computation be universally accessible so ShapeSizeBytes method was added to xla::Compiler which enabled some cleanup in various places. Change: 147206509
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc19
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.h2
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/compiler.h3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc54
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc15
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc36
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h19
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.cc34
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc3
-rw-r--r--tensorflow/compiler/xla/service/service.cc4
-rw-r--r--tensorflow/compiler/xla/service/service.h4
16 files changed, 127 insertions, 112 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index fa15b41a67..de9d9db15d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -411,19 +411,12 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
- int64 pointer_size, int64 alignment) {
- return BufferAssigner::Run(module, std::move(hlo_ordering),
- [pointer_size](const LogicalBuffer& buffer) {
- return ShapeUtil::IsOpaque(buffer.shape())
- ? 0
- : ShapeUtil::ByteSizeOf(
- buffer.shape(), pointer_size);
- },
- [alignment](const LogicalBuffer& buffer) {
- return alignment;
- },
- /*colocate_related_buffers=*/true,
- /*combine_temp-allocations=*/true);
+ LogicalBuffer::SizeFunction buffer_size, int64 alignment) {
+ return BufferAssigner::Run(
+ module, std::move(hlo_ordering), std::move(buffer_size),
+ [alignment](const LogicalBuffer& buffer) { return alignment; },
+ /*colocate_related_buffers=*/true,
+ /*combine_temp-allocations=*/true);
}
bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 29c479ac2b..b5e104b14d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -383,7 +383,7 @@ class BufferAssigner {
// and assigns buffers to all HLO instructions in the module.
static StatusOr<std::unique_ptr<BufferAssignment>> Run(
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
- int64 pointer_size, int64 alignment);
+ LogicalBuffer::SizeFunction buffer_size, int64 alignment);
private:
explicit BufferAssigner(LogicalBuffer::SizeFunction buffer_size,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 0f42c657fe..bb7342d508 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -74,6 +74,17 @@ class BufferAssignmentTest : public HloTestBase {
BufferAssignmentTest() : computation_tracker_() {}
~BufferAssignmentTest() override {}
+ std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
+ int64 alignment = 1) {
+ return BufferAssigner::Run(
+ module, MakeUnique<DependencyHloOrdering>(module),
+ [this](const LogicalBuffer& buffer) {
+ return backend_->compiler()->ShapeSizeBytes(buffer.shape());
+ },
+ alignment)
+ .ConsumeValueOrDie();
+ }
+
// Builds an x+1.0 computation to use in a Map.
std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
auto builder = HloComputation::Builder(name);
@@ -235,15 +246,6 @@ class BufferAssignmentTest : public HloTestBase {
Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
};
-namespace {
-std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
- int64 alignment = 1) {
- return BufferAssigner::Run(module, MakeUnique<DependencyHloOrdering>(module),
- /*pointer_size=*/sizeof(void*), alignment)
- .ConsumeValueOrDie();
-}
-}
-
// Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest, ScalarConstant) {
auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 85c2d03e1b..45cbe2b7ae 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -153,6 +153,9 @@ class Compiler {
static StatusOr<Compiler*> GetForPlatform(
const perftools::gputools::Platform* platform);
+ // Returns the size in bytes of the top-level buffer of a shape.
+ virtual int64 ShapeSizeBytes(const Shape& shape) const = 0;
+
private:
// Mutex that guards the platform-compiler map.
static tensorflow::mutex* platform_compiler_mutex_;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index af3fc39cb4..f09f842834 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -298,8 +298,6 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
CodeGenOptLevel());
llvm_module->setDataLayout(jit->data_layout());
llvm_module->setTargetTriple(jit->target_triple().getTriple());
- const llvm::DataLayout& data_layout = llvm_module->getDataLayout();
- int64 pointer_size = data_layout.getPointerSize();
TF_RETURN_IF_ERROR(
RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo));
@@ -325,7 +323,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(hlo_module.get(),
MakeUnique<DependencyHloOrdering>(hlo_module.get()),
- pointer_size, kMemoryAlignment));
+ [this](const LogicalBuffer& buffer) {
+ return ShapeSizeBytes(buffer.shape());
+ },
+ kMemoryAlignment));
// If we are using the parallel CPU backend, we need to create map from
// HloInstruction to the corresponding generated function name.
@@ -341,7 +342,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// Copy the constant out of the ProtocolBuffer so that we can give it a
// higher alignment.
const void* data = LiteralUtil::InternalData(instruction->literal());
- int64 size = llvm_ir::ByteSizeOf(instruction->shape(), data_layout);
+ int64 size = ShapeSizeBytes(instruction->shape());
auto iter = aligned_constants.emplace(
instruction, MakeUnique<unsigned char[]>(size));
CHECK_EQ(iter.second, true);
@@ -406,10 +407,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(
SequentialHloOrdering::HloModuleSequence module_sequence,
- CreateMemoryMinimizingSequence(
- *hlo_module, [&](const LogicalBuffer& buffer) {
- return llvm_ir::ByteSizeOf(buffer.shape(), data_layout);
- }));
+ CreateMemoryMinimizingSequence(*hlo_module,
+ [this](const LogicalBuffer& buffer) {
+ return ShapeSizeBytes(buffer.shape());
+ }));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
@@ -418,7 +419,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
BufferAssigner::Run(hlo_module.get(),
MakeUnique<SequentialHloOrdering>(hlo_module.get(),
module_sequence),
- pointer_size, kMemoryAlignment));
+ [this](const LogicalBuffer& buffer) {
+ return ShapeSizeBytes(buffer.shape());
+ },
+ kMemoryAlignment));
// Each computation is a single function. Emit all embedded computations
// before the entry computation. The order of computations returned from
@@ -559,8 +563,6 @@ CpuCompiler::CompileAheadOfTime(
if (pie_level != llvm::PIELevel::Default) {
llvm_module.setPIELevel(pie_level);
}
- const llvm::DataLayout& data_layout = llvm_module.getDataLayout();
- int64 pointer_size = data_layout.getPointerSize();
std::vector<std::unique_ptr<AotCompilationResult>> results;
for (int i = 0; i < hlo_modules.size(); ++i) {
@@ -571,18 +573,22 @@ CpuCompiler::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
SequentialHloOrdering::HloModuleSequence module_sequence,
- CreateMemoryMinimizingSequence(
- *hlo_module, [&](const LogicalBuffer& buffer) {
- return llvm_ir::ByteSizeOf(buffer.shape(), data_layout);
- }));
+ CreateMemoryMinimizingSequence(*hlo_module,
+ [this](const LogicalBuffer& buffer) {
+ return ShapeSizeBytes(buffer.shape());
+ }));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
- TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(hlo_module,
- MakeUnique<SequentialHloOrdering>(
- hlo_module, module_sequence),
- pointer_size, kMemoryAlignment));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferAssignment> assignment,
+ BufferAssigner::Run(
+ hlo_module,
+ MakeUnique<SequentialHloOrdering>(hlo_module, module_sequence),
+ [this](const LogicalBuffer& buffer) {
+ return ShapeSizeBytes(buffer.shape());
+ },
+ kMemoryAlignment));
IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module,
/*hlo_to_profile_idx=*/nullptr);
@@ -645,6 +651,14 @@ se::Platform::Id CpuCompiler::PlatformId() const {
return se::host::kHostPlatformId;
}
+int64 CpuCompiler::ShapeSizeBytes(const Shape& shape) const {
+ // On the cpu, opaques are pointers.
+ if (ShapeUtil::IsOpaque(shape)) {
+ return sizeof(void*);
+ }
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index d7d77ce58a..a32aa84ea5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -131,6 +131,8 @@ class CpuCompiler : public Compiler {
perftools::gputools::Platform::Id PlatformId() const override;
+ int64 ShapeSizeBytes(const Shape& shape) const override;
+
private:
// Initialize the LLVM target.
static void InitializeLLVMTarget();
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index fb053b62a7..3931e909a0 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -98,8 +98,9 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) {
double bytes = CalculateBytesReadByFusionInstruction(fusion);
// Add bytes written to root instructions buffer.
bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
- // Calculate flops for all fused instructions.
- HloCostAnalysis analysis;
+ // Calculate flops for all fused instructions. Use a null shape size function
+ // because we don't care about bytes accessed by the ops.
+ HloCostAnalysis analysis([](const Shape& shape) { return 0; });
TF_CHECK_OK(fusion->fused_expression_root()->Accept(&analysis));
// Return flops / bytes.
return bytes > 0.0 ? analysis.flop_count() / bytes : analysis.flop_count();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 16d4473fda..ad283d7e66 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -212,7 +212,9 @@ void DumpPtxasInfo(const string& ptx) {
} // namespace
-GpuCompiler::GpuCompiler() : libdevice_dir_(GetLibdeviceDir()) {}
+GpuCompiler::GpuCompiler()
+ : libdevice_dir_(GetLibdeviceDir()),
+ pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {}
StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
std::unique_ptr<HloModule> hlo_module,
@@ -240,8 +242,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
// Set the target triple and the data layout.
llvm_module.setTargetTriple(kTargetTriple);
llvm_module.setDataLayout(kDataLayout);
- const llvm::DataLayout& data_layout = llvm_module.getDataLayout();
- int64 pointer_size = data_layout.getPointerSize();
// Determine the HLO schedule, which is an ordering of HLO instructions. This
// is used by buffer assignment to enable buffer reuse, and the same ordering
@@ -256,7 +256,10 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> buffer_assignment,
BufferAssigner::Run(hlo_module.get(), hlo_schedule->ConsumeHloOrdering(),
- pointer_size, kMemoryAlignment));
+ [this](const LogicalBuffer& buffer) {
+ return ShapeSizeBytes(buffer.shape());
+ },
+ kMemoryAlignment));
IrEmitterContext ir_emitter_context(hlo_module.get(), buffer_assignment.get(),
&stream_exec->GetDeviceDescription(),
@@ -331,6 +334,10 @@ se::Platform::Id GpuCompiler::PlatformId() const {
return se::cuda::kCudaPlatformId;
}
+int64 GpuCompiler::ShapeSizeBytes(const Shape& shape) const {
+ return ShapeUtil::ByteSizeOf(shape, pointer_size_);
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index a074607760..22f492b422 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -60,6 +60,8 @@ class GpuCompiler : public Compiler {
perftools::gputools::Platform::Id PlatformId() const override;
+ int64 ShapeSizeBytes(const Shape& shape) const override;
+
private:
// The parent directory of libdevice IR libraries.
const string libdevice_dir_;
@@ -70,6 +72,9 @@ class GpuCompiler : public Compiler {
tensorflow::mutex mutex_;
std::vector<std::unique_ptr<string>> generated_ptxes_ GUARDED_BY(mutex_);
+ // The size in bytes of a pointer. Used for computing ShapeSizeBytes.
+ int64 pointer_size_;
+
TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler);
};
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index b99132665b..41be997584 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -25,10 +25,6 @@ limitations under the License.
namespace xla {
-int64 HloCostAnalysis::ShapeSizeBytes(const Shape& shape) const {
- return shape_size_ == nullptr ? 0 : shape_size_(shape);
-}
-
Status HloCostAnalysis::Preprocess(HloInstruction* hlo) {
// Set current instruction cost values to reasonable default values. Each
// handler can overwrite these values. In Postprocess, these value are
@@ -39,12 +35,9 @@ Status HloCostAnalysis::Preprocess(HloInstruction* hlo) {
// The default element count for an instruction is the sum of elements in the
// operands and output. The default ShapeUtil::ByteSizeOf does not handle
// opaque types.
- current_bytes_accessed_ =
- ShapeUtil::IsOpaque(hlo->shape()) ? 0 : ShapeSizeBytes(hlo->shape());
+ current_bytes_accessed_ = shape_size_(hlo->shape());
for (const HloInstruction* operand : hlo->operands()) {
- current_bytes_accessed_ += ShapeUtil::IsOpaque(operand->shape())
- ? 0
- : ShapeSizeBytes(operand->shape());
+ current_bytes_accessed_ += shape_size_(operand->shape());
}
return Status::OK();
@@ -161,7 +154,7 @@ Status HloCostAnalysis::HandleTuple(
// The tuple instruction only gathers pointers from inputs (it doesn't iterate
// through them). The memory touched is then only the size of the output
// buffer.
- current_bytes_accessed_ = ShapeSizeBytes(tuple->shape());
+ current_bytes_accessed_ = shape_size_(tuple->shape());
return Status::OK();
}
@@ -412,12 +405,8 @@ Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while,
body_visitor.flop_count() + condition_visitor.flop_count();
current_transcendental_count_ = body_visitor.transcendental_count() +
condition_visitor.transcendental_count();
- if (shape_size_ != nullptr) {
- TF_ASSIGN_OR_RETURN(int64 body_bytes, body_visitor.bytes_accessed());
- TF_ASSIGN_OR_RETURN(int64 condition_bytes,
- condition_visitor.bytes_accessed());
- current_bytes_accessed_ = body_bytes + condition_bytes;
- }
+ current_bytes_accessed_ =
+ body_visitor.bytes_accessed() + condition_visitor.bytes_accessed();
return Status::OK();
}
@@ -436,22 +425,9 @@ int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
return it == hlo_to_transcendental_count_.end() ? 0 : it->second;
}
-StatusOr<int64> HloCostAnalysis::bytes_accessed(
- const HloInstruction& hlo) const {
- if (shape_size_ == nullptr) {
- return FailedPrecondition(
- "Cannot determine number of bytes accessed; no size function given");
- }
+int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
auto it = hlo_to_bytes_accessed_.find(&hlo);
return it == hlo_to_bytes_accessed_.end() ? 0 : it->second;
}
-StatusOr<int64> HloCostAnalysis::bytes_accessed() const {
- if (shape_size_ == nullptr) {
- return FailedPrecondition(
- "Cannot determine number of bytes accessed; no size function given");
- }
- return bytes_accessed_;
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 87752ef335..afeb588c0c 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -36,11 +36,8 @@ namespace xla {
// operations separately from transcendental operations.
class HloCostAnalysis : public DfsHloVisitor {
public:
- HloCostAnalysis() {}
-
- // Constructor which accepts a function for computing the size in bytes of the
- // top-level buffer of a shape. This constructor must be used if
- // bytes_accessed methods are to be called.
+ // shape_size is a function which returns the size in bytes of the top-level
+ // buffer of a shape.
using ShapeSizeFunction = std::function<int64(const Shape&)>;
explicit HloCostAnalysis(const ShapeSizeFunction& shape_size)
: shape_size_(shape_size) {}
@@ -136,15 +133,11 @@ class HloCostAnalysis : public DfsHloVisitor {
int64 flop_count(const HloInstruction& hlo) const;
int64 transcendental_count(const HloInstruction& hlo) const;
- // Returns the number of bytes read/written. Returns an Status error if no
- // ShapeSizeFunction was given at construction time.
- StatusOr<int64> bytes_accessed(const HloInstruction& hlo) const;
- StatusOr<int64> bytes_accessed() const;
+ // Returns the number of bytes read/written.
+ int64 bytes_accessed(const HloInstruction& hlo) const;
+ int64 bytes_accessed() const { return bytes_accessed_; }
private:
- // Returns the size in bytes of the top-level buffer of a shape.
- int64 ShapeSizeBytes(const Shape& shape) const;
-
// An FMA counts as two floating point operations in these analyses.
static constexpr int64 kFmaFlops = 2;
@@ -154,7 +147,7 @@ class HloCostAnalysis : public DfsHloVisitor {
// Function which computes the size of the top-level of a given shape (not
// including nested elements, if any). If null then bytes_accessed methods
// return an error.
- ShapeSizeFunction shape_size_ = nullptr;
+ const ShapeSizeFunction shape_size_;
// The total number of floating point operations, transcendental operations,
// and bytes accesses (read or written) in the computation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index abdde619fb..9f1c91d41c 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -160,8 +160,8 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
EXPECT_EQ(analysis.transcendental_count(), 0);
// Bytes accessed is sum of inputs and output.
- TF_ASSIGN_OR_ASSERT_OK(int64 bytes_accessed, analysis.bytes_accessed());
- EXPECT_EQ(bytes_accessed, sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));
+ EXPECT_EQ(analysis.bytes_accessed(),
+ sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));
}
TEST_F(HloCostAnalysisTest, Map) {
@@ -178,8 +178,7 @@ TEST_F(HloCostAnalysisTest, Map) {
// add contributes to 10 flops and exp contributes to 10 transcendental ops.
EXPECT_EQ(analysis.flop_count(), 10);
EXPECT_EQ(analysis.transcendental_count(), 10);
- TF_ASSIGN_OR_ASSERT_OK(int64 bytes_accessed, analysis.bytes_accessed());
- EXPECT_EQ(bytes_accessed, 80);
+ EXPECT_EQ(analysis.bytes_accessed(), 80);
}
TEST_F(HloCostAnalysisTest, Convolution) {
@@ -205,8 +204,8 @@ TEST_F(HloCostAnalysisTest, Convolution) {
EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);
// Bytes accessed is sum of inputs and output.
- TF_ASSIGN_OR_ASSERT_OK(int64 bytes_accessed, analysis.bytes_accessed());
- EXPECT_EQ(bytes_accessed, sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
+ EXPECT_EQ(analysis.bytes_accessed(),
+ sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}
TEST_F(HloCostAnalysisTest, Reduce) {
@@ -391,8 +390,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
EXPECT_EQ(analysis.flop_count(), 0);
EXPECT_EQ(analysis.transcendental_count(), 0);
- TF_ASSIGN_OR_ASSERT_OK(int64 bytes_accessed, analysis.bytes_accessed());
- EXPECT_EQ(bytes_accessed, kPointerSize * 2);
+ EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index 82c85635f0..e2a81a052c 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -60,16 +60,32 @@ string HloExecutionProfile::ToString(
return cycles / clock_rate_ghz / 1000.0;
};
- auto append_item = [&](int64 cycles, int64 flops, const string& name) {
+ auto append_item = [&](int64 cycles, int64 flops, int64 bytes_accessed,
+ const string& name) {
double nsecs = cycles / clock_rate_ghz;
+ string bytes_per_sec;
+ string bytes_per_cycle;
+ if (bytes_accessed >= 0) {
+ bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
+ bytes_accessed / (nsecs / 1e9));
+ bytes_per_cycle =
+ tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
+ } else {
+ bytes_per_sec = "<unknown>";
+ bytes_per_cycle = "<unknown>";
+ }
+
tensorflow::strings::StrAppend(
&result,
tensorflow::strings::Printf(
- "%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s :: %s",
+ "%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s :: %12s/s "
+ ":: "
+ "%12s/cycle :: "
+ "%s",
cycles, cycles / static_cast<double>(total_cycles) * 100,
cycles_to_microseconds(cycles),
flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
- name.c_str()));
+ bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
};
tensorflow::strings::StrAppend(
&result,
@@ -77,13 +93,15 @@ string HloExecutionProfile::ToString(
tensorflow::strings::HumanReadableElapsedTime(
total_cycles / clock_rate_ghz / 1e9)
.c_str()));
- append_item(total_cycles, -1, "[total]");
+ append_item(total_cycles, -1, -1, "[total]");
for (const auto& item : items) {
+ const HloInstruction* hlo = item.first;
tensorflow::strings::StrAppend(&result, "\n\t");
- auto flops =
- item.first == nullptr ? -1 : cost_analysis.flop_count(*item.first);
- string display = item.first == nullptr ? "<none>" : item.first->ToString();
- append_item(item.second, flops, display);
+ int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo);
+ int64 bytes_accessed =
+ hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo);
+ string display = hlo == nullptr ? "<none>" : hlo->ToString();
+ append_item(item.second, flops, bytes_accessed, display);
}
MetricTableReport table;
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index c1c1946375..915dd4af24 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -365,8 +365,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// execution time per byte saved by rematerializing this instruction. In
// general, we want the most memory saving for the least latency penalty
// which is captured by this heuristic.
- TF_ASSIGN_OR_RETURN(int64 bytes_accessed,
- cost_analysis.bytes_accessed(*candidate));
+ int64 bytes_accessed = cost_analysis.bytes_accessed(*candidate);
int64 net_memory_reduced =
std::min(memory_reduced, memory_usage - memory_limit_bytes_);
int64 elements_accessed =
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 725808bc88..95bbb01e9e 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1203,7 +1203,9 @@ tensorflow::Status Service::GetComputationStats(
MakeHloDumper()(*module, "computation statistics subject");
// Run HLO analysis to get the computation statistics.
- HloCostAnalysis analysis;
+ HloCostAnalysis analysis([this](const Shape& shape) {
+ return execute_backend_->compiler()->ShapeSizeBytes(shape);
+ });
TF_RETURN_IF_ERROR(
module->entry_computation()->root_instruction()->Accept(&analysis));
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 9c4b0f44c8..ba609fe881 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -444,7 +444,9 @@ ReturnT Service::ExecuteOnStreamWrapper(
}
if (profile_ptr != nullptr) {
- HloCostAnalysis analysis;
+ HloCostAnalysis analysis([this](const Shape& shape) {
+ return execute_backend_->compiler()->ShapeSizeBytes(shape);
+ });
tensorflow::Status analysis_status =
executable->module().entry_computation()->root_instruction()->Accept(
&analysis);