diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-26 22:55:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-26 22:58:44 -0700 |
commit | 40dee372e3ee844c4746baa914c07b9c582a2ce7 (patch) | |
tree | bd39a01c0aad8a6cfc8e5d4205674b5a8892133d /tensorflow/core/grappler/costs/virtual_scheduler.cc | |
parent | 680c2f5d988fb1f3b725fb8f0a67d1926be8169b (diff) |
Define OpContext and use it for OpLevelCostEstimator.
This CL does not add any functionality (except GraphDef's function library pointer is passed to
OpContext), but we can later add additional fields to OpContext struct for extending
VirtualCluster, Scheduler, Placer, and others.
PiperOrigin-RevId: 170157235
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 16c434b0ad..4294c9e954 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -377,7 +377,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( return std::make_pair(send, recv); } -NodeInfo VirtualScheduler::GetCurrNodeInfo() const { +OpContext VirtualScheduler::GetCurrNode() const { const NodeDef* node = ready_nodes_->GetCurrNode(); // Get the device from the placer. @@ -389,12 +389,12 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { device.set_type(kChannelDevice); } - // Construct NodeInfo. - NodeInfo node_info; + // Construct OpContext. + OpContext op_context; const auto& node_state = node_map_.at(node); - node_info.name = node->name(); - node_info.device_name = node_state.device_name; - auto& op_info = node_info.op_info; + op_context.name = node->name(); + op_context.device_name = node_state.device_name; + auto& op_info = op_context.op_info; op_info.set_op(node->op()); *op_info.mutable_attr() = node->attr(); for (auto& input : node_state.input_properties) { @@ -404,7 +404,11 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { *op_info.add_outputs() = output; } op_info.mutable_device()->Swap(&device); - return node_info; + + if (grappler_item_->graph.has_library()) { + op_context.function_library = &grappler_item_->graph.library(); + } + return op_context; } NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) { @@ -497,8 +501,8 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { const auto& op_name = node->op(); // Also keep track of op counts and times per op (with their shapes). - NodeInfo node_info = GetCurrNodeInfo(); - string node_description = GetOpDescription(node_info.op_info); + OpContext op_context = GetCurrNode(); + string node_description = GetOpDescription(op_context.op_info); op_counts_[node_description] += 1; op_costs_[node_description] = node_costs.execution_time.asMicroSeconds().count(); |