diff options
author | Max Galkin <maxgalkin@google.com> | 2017-12-01 14:48:56 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-01 14:52:24 -0800 |
commit | d0ae1064ed0bb4bd1aed00afd4235f4dd5c853f0 (patch) | |
tree | b077af5e338ac90aa8d10046e188943c049b24b0 /tensorflow/core/grappler/costs/virtual_scheduler.cc | |
parent | 10f77231b005c76b5a771243e18384b4b66be325 (diff) |
Prefix inaccurate costs with "~" in VirtualScheduler verbose log.
Fix some inaccurate estimates exposed by this approach:
- propagate the inaccuracy flag when merging device stats;
- estimate Const as no-op;
- estimate RandomUniform, Relu and Softmax as element-wise;
- consider estimates accurate for known element-wise ops in op_level_cost_estimator.
PiperOrigin-RevId: 177643976
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 32 |
1 file changed, 24 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 6640de668d..1554aeb3c0 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -43,6 +43,9 @@ Costs CombineCosts(const Costs& left, const Costs& right) { Costs result = left; result.execution_time += right.execution_time; + if (right.inaccurate) { + result.inaccurate = true; + } if (right.max_memory != kMemoryUnknown) { result.max_memory += right.max_memory; } @@ -538,7 +541,8 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { string node_description = GetOpDescription(op_context.op_info); op_counts_[node_description] += 1; op_costs_[node_description] = - node_costs.execution_time.asMicroSeconds().count(); + std::make_pair(node_costs.execution_time.asMicroSeconds().count(), + !node_costs.inaccurate); auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_); op_cost = CombineCosts(op_cost, node_costs); @@ -647,8 +651,10 @@ Costs VirtualScheduler::Summary() const { for (const auto& op_cost_pair : op_to_cost_) { const auto& op = op_cost_pair.first; const auto& cost = op_cost_pair.second.execution_time.count(); + const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; if (cost) { // Skip printing out zero-cost ops. - VLOG(1) << " + " << op << " : " << cost; + VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? 
"" : "~") + << cost; } } @@ -699,10 +705,16 @@ Costs VirtualScheduler::Summary() const { CalculateOutputSize(node_map_.at(node).output_properties, port); } Costs::NanoSeconds total_compute_time_ns; + bool is_total_cost_accurate = true; for (const auto& op_cost_pair : state.op_to_cost) { const auto& op = op_cost_pair.first; const auto& cost = op_cost_pair.second.execution_time.count(); total_compute_time_ns += op_cost_pair.second.execution_time; + const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; + if (!is_op_cost_accurate) { + is_total_cost_accurate = false; + } + int64 op_mem_usage = 0; auto it = op_to_memory.find(op); if (it != op_to_memory.end()) { @@ -714,9 +726,9 @@ Costs VirtualScheduler::Summary() const { : 0.0; if (cost || mem_usage_percent > 1.0) { // Print out only non-zero cost ops or ops with > 1% memory usage. - VLOG(1) << " + " << op << " : " << cost << " (" - << strings::HumanReadableNumBytes(op_mem_usage) << " [" - << mem_usage_percent << "%] " + VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~") + << cost << " (" << strings::HumanReadableNumBytes(op_mem_usage) + << " [" << mem_usage_percent << "%] " << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")"); } } @@ -725,8 +737,9 @@ Costs VirtualScheduler::Summary() const { if (wall_time_ns.count() > 0) { utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count(); } - VLOG(1) << "Device = " << name - << ", total_compute_time_ns = " << total_compute_time_ns.count() + VLOG(1) << "Device = " << name << ", total_compute_time_ns = " + << (is_total_cost_accurate ? "" : "~") + << total_compute_time_ns.count() << ", utilization = " << utilization << "%"; if (critical_path_costs.execution_time <= state.GetCurrTime()) { @@ -738,8 +751,11 @@ Costs VirtualScheduler::Summary() const { // Also log the op description and their corresponding counts. 
VLOG(2) << "Node description, counts, cost:"; for (const auto& item : op_counts_) { + int cost; + bool is_cost_accurate; + std::tie(cost, is_cost_accurate) = op_costs_.at(item.first); VLOG(2) << "Node: " << item.first << ", Count: " << item.second - << ", Individual Cost: " << op_costs_.at(item.first); + << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost; } } |