diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 31 |
1 files changed, 21 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 5b93fb128f..5c5bdad1cb 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -47,6 +47,7 @@ Costs CombineCosts(const Costs& left, const Costs& right) { result.execution_time += right.execution_time; result.compute_time += right.compute_time; result.memory_time += right.memory_time; + result.intermediate_memory_time += right.intermediate_memory_time; result.num_ops_total += right.num_ops_total; if (right.inaccurate) result.inaccurate = true; @@ -825,23 +826,29 @@ Costs VirtualScheduler::Summary() const { VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count(); VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count(); VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count(); + VLOG(1) << "Expected intermediate memory time: " + << graph_costs_.intermediate_memory_time.count(); VLOG(1) << "Expected max memory: " << graph_costs_.max_memory; VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers; VLOG(1) << "Expected max per-op streaming buffers: " << graph_costs_.max_per_op_streaming; - VLOG(1) << "Per-op execution time / compute time / memory time:"; + VLOG(1) << "Per-op execution time / compute time / memory time" + << " / intermediate memory time:"; for (const auto& op_cost_pair : op_to_cost_) { const auto& op = op_cost_pair.first; const auto& cost = op_cost_pair.second.execution_time.count(); const auto& compute_cost = op_cost_pair.second.compute_time.count(); const auto& memory_cost = op_cost_pair.second.memory_time.count(); + const auto& intermediate_memory_cost = + op_cost_pair.second.intermediate_memory_time.count(); const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; if (cost) { // Skip printing out zero-cost ops. VLOG(1) << strings::Printf( - " + %30s : %c %10lld / %10lld / %10lld", op.c_str(), + " + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(), (is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost), - static_cast<int64>(compute_cost), static_cast<int64>(memory_cost)); + static_cast<int64>(compute_cost), static_cast<int64>(memory_cost), + static_cast<int64>(intermediate_memory_cost)); } } @@ -894,7 +901,8 @@ Costs VirtualScheduler::Summary() const { << " having unknown shapes"; VLOG(1) << "Per-op execution time / compute time / memory time " - "(and memory usage at peak memory usage):"; + << " / intermediate memory time" + << " (and memory usage at peak memory usage):"; // Profile non-persistent op memory usage. for (const auto& node_port : state.mem_usage_snapshot_at_peak) { @@ -910,6 +918,8 @@ Costs VirtualScheduler::Summary() const { const auto& cost = op_cost_pair.second.execution_time.count(); const auto& compute_cost = op_cost_pair.second.compute_time.count(); const auto& memory_cost = op_cost_pair.second.memory_time.count(); + const auto& intermediate_memory_cost = + op_cost_pair.second.intermediate_memory_time.count(); total_compute_time_ns += op_cost_pair.second.execution_time; const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; if (!is_op_cost_accurate) { @@ -927,12 +937,13 @@ Costs VirtualScheduler::Summary() const { : 0.0; if (cost || mem_usage_percent > 1.0) { // Print out only non-zero cost ops or ops with > 1% memory usage. - VLOG(1) << strings::Printf(" + %30s : %c %10lld / %10lld / %10lld", - op.c_str(), - (is_op_cost_accurate ? ' ' : '~'), - static_cast<int64>(cost), - static_cast<int64>(compute_cost), - static_cast<int64>(memory_cost)) + VLOG(1) << strings::Printf( + " + %30s : %c %10lld / %10lld / %10lld / %10lld", + op.c_str(), (is_op_cost_accurate ? ' ' : '~'), + static_cast<int64>(cost), + static_cast<int64>(compute_cost), + static_cast<int64>(memory_cost), + static_cast<int64>(intermediate_memory_cost)) << " (" << strings::HumanReadableNumBytes(op_mem_usage) << " [" << mem_usage_percent << "%] " << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")"); |