about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.cc | 31
1 files changed, 21 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 5b93fb128f..5c5bdad1cb 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -47,6 +47,7 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
result.execution_time += right.execution_time;
result.compute_time += right.compute_time;
result.memory_time += right.memory_time;
+ result.intermediate_memory_time += right.intermediate_memory_time;
result.num_ops_total += right.num_ops_total;
if (right.inaccurate) result.inaccurate = true;
@@ -825,23 +826,29 @@ Costs VirtualScheduler::Summary() const {
VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
+ VLOG(1) << "Expected intermediate memory time: "
+ << graph_costs_.intermediate_memory_time.count();
VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
VLOG(1) << "Expected max per-op streaming buffers: "
<< graph_costs_.max_per_op_streaming;
- VLOG(1) << "Per-op execution time / compute time / memory time:";
+ VLOG(1) << "Per-op execution time / compute time / memory time"
+ << " / intermediate memory time:";
for (const auto& op_cost_pair : op_to_cost_) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
const auto& compute_cost = op_cost_pair.second.compute_time.count();
const auto& memory_cost = op_cost_pair.second.memory_time.count();
+ const auto& intermediate_memory_cost =
+ op_cost_pair.second.intermediate_memory_time.count();
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (cost) { // Skip printing out zero-cost ops.
VLOG(1) << strings::Printf(
- " + %30s : %c %10lld / %10lld / %10lld", op.c_str(),
+ " + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(),
(is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
- static_cast<int64>(compute_cost), static_cast<int64>(memory_cost));
+ static_cast<int64>(compute_cost), static_cast<int64>(memory_cost),
+ static_cast<int64>(intermediate_memory_cost));
}
}
@@ -894,7 +901,8 @@ Costs VirtualScheduler::Summary() const {
<< " having unknown shapes";
VLOG(1) << "Per-op execution time / compute time / memory time "
- "(and memory usage at peak memory usage):";
+ << " / intermediate memory time"
+ << " (and memory usage at peak memory usage):";
// Profile non-persistent op memory usage.
for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
@@ -910,6 +918,8 @@ Costs VirtualScheduler::Summary() const {
const auto& cost = op_cost_pair.second.execution_time.count();
const auto& compute_cost = op_cost_pair.second.compute_time.count();
const auto& memory_cost = op_cost_pair.second.memory_time.count();
+ const auto& intermediate_memory_cost =
+ op_cost_pair.second.intermediate_memory_time.count();
total_compute_time_ns += op_cost_pair.second.execution_time;
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (!is_op_cost_accurate) {
@@ -927,12 +937,13 @@ Costs VirtualScheduler::Summary() const {
: 0.0;
if (cost || mem_usage_percent > 1.0) {
// Print out only non-zero cost ops or ops with > 1% memory usage.
- VLOG(1) << strings::Printf(" + %30s : %c %10lld / %10lld / %10lld",
- op.c_str(),
- (is_op_cost_accurate ? ' ' : '~'),
- static_cast<int64>(cost),
- static_cast<int64>(compute_cost),
- static_cast<int64>(memory_cost))
+ VLOG(1) << strings::Printf(
+ " + %30s : %c %10lld / %10lld / %10lld / %10lld",
+ op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
+ static_cast<int64>(cost),
+ static_cast<int64>(compute_cost),
+ static_cast<int64>(memory_cost),
+ static_cast<int64>(intermediate_memory_cost))
<< " (" << strings::HumanReadableNumBytes(op_mem_usage) << " ["
<< mem_usage_percent << "%] "
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");