aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2017-12-01 14:48:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 14:52:24 -0800
commitd0ae1064ed0bb4bd1aed00afd4235f4dd5c853f0 (patch)
treeb077af5e338ac90aa8d10046e188943c049b24b0 /tensorflow/core/grappler/costs/virtual_scheduler.cc
parent10f77231b005c76b5a771243e18384b4b66be325 (diff)
Prefix inaccurate costs with "~" in VirtualScheduler verbose log.
Fix some inaccurate estimates exposed by this approach: - propagate the inaccuracy flag when merging device stats; - estimate Const as no-op; - estimate RandomUniform, Relu and Softmax as element-wise; - consider estimates accurate for known element-wise ops in op_level_cost_estimator. PiperOrigin-RevId: 177643976
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc32
1 files changed, 24 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 6640de668d..1554aeb3c0 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -43,6 +43,9 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
Costs result = left;
result.execution_time += right.execution_time;
+ if (right.inaccurate) {
+ result.inaccurate = true;
+ }
if (right.max_memory != kMemoryUnknown) {
result.max_memory += right.max_memory;
}
@@ -538,7 +541,8 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
string node_description = GetOpDescription(op_context.op_info);
op_counts_[node_description] += 1;
op_costs_[node_description] =
- node_costs.execution_time.asMicroSeconds().count();
+ std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
+ !node_costs.inaccurate);
auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
op_cost = CombineCosts(op_cost, node_costs);
@@ -647,8 +651,10 @@ Costs VirtualScheduler::Summary() const {
for (const auto& op_cost_pair : op_to_cost_) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
+ const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (cost) { // Skip printing out zero-cost ops.
- VLOG(1) << " + " << op << " : " << cost;
+ VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
+ << cost;
}
}
@@ -699,10 +705,16 @@ Costs VirtualScheduler::Summary() const {
CalculateOutputSize(node_map_.at(node).output_properties, port);
}
Costs::NanoSeconds total_compute_time_ns;
+ bool is_total_cost_accurate = true;
for (const auto& op_cost_pair : state.op_to_cost) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
total_compute_time_ns += op_cost_pair.second.execution_time;
+ const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
+ if (!is_op_cost_accurate) {
+ is_total_cost_accurate = false;
+ }
+
int64 op_mem_usage = 0;
auto it = op_to_memory.find(op);
if (it != op_to_memory.end()) {
@@ -714,9 +726,9 @@ Costs VirtualScheduler::Summary() const {
: 0.0;
if (cost || mem_usage_percent > 1.0) {
// Print out only non-zero cost ops or ops with > 1% memory usage.
- VLOG(1) << " + " << op << " : " << cost << " ("
- << strings::HumanReadableNumBytes(op_mem_usage) << " ["
- << mem_usage_percent << "%] "
+ VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
+ << cost << " (" << strings::HumanReadableNumBytes(op_mem_usage)
+ << " [" << mem_usage_percent << "%] "
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
}
}
@@ -725,8 +737,9 @@ Costs VirtualScheduler::Summary() const {
if (wall_time_ns.count() > 0) {
utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
}
- VLOG(1) << "Device = " << name
- << ", total_compute_time_ns = " << total_compute_time_ns.count()
+ VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
+ << (is_total_cost_accurate ? "" : "~")
+ << total_compute_time_ns.count()
<< ", utilization = " << utilization << "%";
if (critical_path_costs.execution_time <= state.GetCurrTime()) {
@@ -738,8 +751,11 @@ Costs VirtualScheduler::Summary() const {
// Also log the op description and their corresponding counts.
VLOG(2) << "Node description, counts, cost:";
for (const auto& item : op_counts_) {
+ int cost;
+ bool is_cost_accurate;
+ std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
VLOG(2) << "Node: " << item.first << ", Count: " << item.second
- << ", Individual Cost: " << op_costs_.at(item.first);
+ << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost;
}
}