aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/costs/BUILD1
-rw-r--r--tensorflow/core/grappler/costs/utils.cc11
-rw-r--r--tensorflow/core/grappler/costs/utils.h3
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc15
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h2
5 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 1ead2d5baa..792260ce5a 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -173,6 +173,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_properties",
+ ":utils",
":virtual_placer",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 727aeb7ee6..7b7d79fc7e 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -219,5 +219,16 @@ OpInfo BuildOpInfo(
return op_info;
}
+string GetOpDescription(const OpInfo& op_info) {
+ string description = "[";
+ description += "Op=" + op_info.op() + ", ";
+ description += "input_shapes=[";
+ for (auto const& input : op_info.inputs()) {
+ description += PartialTensorShape::DebugString(input.shape());
+ }
+ description += "]";
+ return description;
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index bdba4e4b15..cb23ac8355 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -45,6 +45,9 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
DeviceProperties GetDeviceInfo(const string& device_str);
+// Return a string describing a node given a nodeinfo.
+string GetOpDescription(const OpInfo& op_info);
+
// Builds the OpInfo proto for node, given all nodes in the graph, the node's
// device and its input properties which are typically built by shape inference
// or calling FindInputFeatures.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 86cf498538..80318fe8ad 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"
+#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -349,6 +350,13 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
const auto* node = GetCurrNode();
const auto& op_name = node->op();
+ // Also keep track of op counts and times per op (with their shapes).
+ NodeInfo node_info = GetCurrNodeInfo();
+ string node_description = GetOpDescription(node_info.op_info);
+ op_counts_[node_description] += 1;
+ op_costs_[node_description] =
+ node_costs.execution_time.asMicroSeconds().count();
+
auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
op_cost = CombineCosts(op_cost, node_costs);
@@ -445,6 +453,13 @@ Costs VirtualScheduler::Summary() const {
}
}
+ // Also log the op description and their corresponding counts.
+ VLOG(1) << "Node description, counts, cost:";
+ for (const auto& item : op_counts_) {
+ VLOG(1) << "Node: " << item.first << ", Count: " << item.second
+ << ", Individual Cost: " << op_costs_.at(item.first);
+ }
+
VLOG(1) << "Critical path execution time: "
<< critical_path_costs.execution_time.count();
return critical_path_costs;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 0855071432..83878eea0a 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -137,6 +137,8 @@ class VirtualScheduler {
bool IsRecvOp(const NodeDef* node) const;
GraphProperties graph_properties_;
+ std::map<string, int> op_counts_; // Op counts with key with input shape.
+ std::map<string, int> op_costs_; // Individual op costs (with input shapes).
Costs graph_costs_; // Graph cost.
std::map<string, Costs> op_to_cost_; // Per-op cost.
std::unique_ptr<ReadyNodeManager> ready_nodes_;