diff options
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 30 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.h | 2 |
3 files changed, 18 insertions, 18 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 5298dc7565..e2a589cb47 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -29,6 +29,7 @@ constexpr char kSparseMatMul[] = "SparseMatMul"; constexpr char kIdentity[] = "Identity"; constexpr char kNoOp[] = "NoOp"; constexpr char kReshape[] = "Reshape"; +constexpr char kRecv[] = "_Recv"; OpLevelCostEstimator::OpLevelCostEstimator() { // Syntactic sugar to build and return a lambda that takes an OpInfo and @@ -49,7 +50,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}}; + {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}, + {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)}}; } Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 80318fe8ad..2e2662f71c 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -43,7 +43,7 @@ Costs CombineCosts(const Costs& left, const Costs& right) { result.max_per_op_streaming = std::max(left.max_per_op_streaming, right.max_per_op_streaming); } - VLOG(2) << "costs execution_time=" << result.execution_time.count() + VLOG(3) << "costs execution_time=" << result.execution_time.count() << " max_memory=" << result.max_memory << " max_per_op_buffers=" << result.max_per_op_buffers << " max_per_op_streaming=" << result.max_per_op_streaming; @@ -206,8 +206,9 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const { const auto* to = node_state.outputs[0]; return ChannelDeviceName(from, to); } else { - const string& device = - node->device().empty() ? kDefaultDevice : node->device(); + const string& device = node->device().empty() + ? "/" + default_device_type_ + ":0" + : node->device(); DeviceNameUtils::ParsedName parsed; if (!DeviceNameUtils::ParseFullName(device, &parsed)) { LOG(WARNING) << "Device name parse failed: " << device; @@ -220,9 +221,7 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const { string VirtualScheduler::ChannelDeviceName(const NodeDef* from, const NodeDef* to) const { - // TODO(dyoon): once ChannelCostEstimator is ready, assign Channel device to - // _Send ops. - return kDefaultDevice; + return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to); } std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode( @@ -275,12 +274,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode( return std::make_pair(send, recv); } -const NodeDef* VirtualScheduler::GetCurrNode() const { - return ready_nodes_->GetCurrNode(); -} - NodeInfo VirtualScheduler::GetCurrNodeInfo() const { - const NodeDef* node = GetCurrNode(); + const NodeDef* node = ready_nodes_->GetCurrNode(); std::vector<OpInfo::TensorProperties> inputs = graph_properties_.GetInputProperties(node->name()); // Some ops created within VirtualScheduler may need further processing to @@ -292,13 +287,13 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { DeviceProperties device; if (placer_) { device = placer_->get_device(*node); - } else { - device.set_type("UNKNOWN"); + } + if (device.type() == "UNKNOWN") { string device_type; int device_id; DeviceNameUtils::ParsedName parsed; if (!node->device().empty() && - DeviceNameUtils::ParseFullName(node->device(), &parsed)) { + DeviceNameUtils::ParseFullName(DeviceName(node), &parsed)) { device_type = parsed.type; device_id = parsed.id; } else { @@ -312,6 +307,11 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { } } + // Special case for _Send op. + if (IsSendOp(node)) { + device.set_type(kChannelDevice); + } + NodeInfo node_info; node_info.name = node->name(); node_info.device_name = graph_properties_.GetDeviceName(node->name()); @@ -347,7 +347,7 @@ Costs& VirtualScheduler::FindOrCreateZero(const string& op_name, bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { // Update graph_costs_ and per-op costs. graph_costs_ = CombineCosts(graph_costs_, node_costs); - const auto* node = GetCurrNode(); + const auto* node = ready_nodes_->GetCurrNode(); const auto& op_name = node->op(); // Also keep track of op counts and times per op (with their shapes). diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 83878eea0a..23cb0bc53c 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -121,9 +121,7 @@ class VirtualScheduler { const string kAttrSrcDevice = "src_device_"; const string kAttrDstDevice = "dst_device_"; const string kChannelDevice = "Channel"; - const string kDefaultDevice = "/CPU:0"; - const NodeDef* GetCurrNode() const; void MaybeUpdateInputProperties( const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const; NodeState& GetNodeStateOrCreateIt(const NodeDef* node); |