aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-26 22:55:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 22:58:44 -0700
commit40dee372e3ee844c4746baa914c07b9c582a2ce7 (patch)
treebd39a01c0aad8a6cfc8e5d4205674b5a8892133d /tensorflow/core/grappler/costs/virtual_scheduler.cc
parent680c2f5d988fb1f3b725fb8f0a67d1926be8169b (diff)
Define OpContext and use it for OpLevelCostEstimator.
This CL does not add any functionality (except GraphDef's function library pointer is passed to OpContext), but we can later add additional fields to OpContext struct for extending VirtualCluster, Scheduler, Placer, and others. PiperOrigin-RevId: 170157235
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc22
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 16c434b0ad..4294c9e954 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -377,7 +377,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
return std::make_pair(send, recv);
}
-NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
+OpContext VirtualScheduler::GetCurrNode() const {
const NodeDef* node = ready_nodes_->GetCurrNode();
// Get the device from the placer.
@@ -389,12 +389,12 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
device.set_type(kChannelDevice);
}
- // Construct NodeInfo.
- NodeInfo node_info;
+ // Construct OpContext.
+ OpContext op_context;
const auto& node_state = node_map_.at(node);
- node_info.name = node->name();
- node_info.device_name = node_state.device_name;
- auto& op_info = node_info.op_info;
+ op_context.name = node->name();
+ op_context.device_name = node_state.device_name;
+ auto& op_info = op_context.op_info;
op_info.set_op(node->op());
*op_info.mutable_attr() = node->attr();
for (auto& input : node_state.input_properties) {
@@ -404,7 +404,11 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
*op_info.add_outputs() = output;
}
op_info.mutable_device()->Swap(&device);
- return node_info;
+
+ if (grappler_item_->graph.has_library()) {
+ op_context.function_library = &grappler_item_->graph.library();
+ }
+ return op_context;
}
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
@@ -497,8 +501,8 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
const auto& op_name = node->op();
// Also keep track of op counts and times per op (with their shapes).
- NodeInfo node_info = GetCurrNodeInfo();
- string node_description = GetOpDescription(node_info.op_info);
+ OpContext op_context = GetCurrNode();
+ string node_description = GetOpDescription(op_context.op_info);
op_counts_[node_description] += 1;
op_costs_[node_description] =
node_costs.execution_time.asMicroSeconds().count();