about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 4
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.cc | 30
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.h | 2
3 files changed, 18 insertions(+), 18 deletions(-)
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 5298dc7565..e2a589cb47 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -29,6 +29,7 @@ constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kIdentity[] = "Identity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
+constexpr char kRecv[] = "_Recv";
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpInfo and
@@ -49,7 +50,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}};
+ {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)}};
}
Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const {
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 80318fe8ad..2e2662f71c 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -43,7 +43,7 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
result.max_per_op_streaming =
std::max(left.max_per_op_streaming, right.max_per_op_streaming);
}
- VLOG(2) << "costs execution_time=" << result.execution_time.count()
+ VLOG(3) << "costs execution_time=" << result.execution_time.count()
<< " max_memory=" << result.max_memory
<< " max_per_op_buffers=" << result.max_per_op_buffers
<< " max_per_op_streaming=" << result.max_per_op_streaming;
@@ -206,8 +206,9 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
const auto* to = node_state.outputs[0];
return ChannelDeviceName(from, to);
} else {
- const string& device =
- node->device().empty() ? kDefaultDevice : node->device();
+ const string& device = node->device().empty()
+ ? "/" + default_device_type_ + ":0"
+ : node->device();
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
LOG(WARNING) << "Device name parse failed: " << device;
@@ -220,9 +221,7 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
const NodeDef* to) const {
- // TODO(dyoon): once ChannelCostEstimator is ready, assign Channel device to
- // _Send ops.
- return kDefaultDevice;
+ return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to);
}
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
@@ -275,12 +274,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
return std::make_pair(send, recv);
}
-const NodeDef* VirtualScheduler::GetCurrNode() const {
- return ready_nodes_->GetCurrNode();
-}
-
NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
- const NodeDef* node = GetCurrNode();
+ const NodeDef* node = ready_nodes_->GetCurrNode();
std::vector<OpInfo::TensorProperties> inputs =
graph_properties_.GetInputProperties(node->name());
// Some ops created within VirtualScheduler may need further processing to
@@ -292,13 +287,13 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
DeviceProperties device;
if (placer_) {
device = placer_->get_device(*node);
- } else {
- device.set_type("UNKNOWN");
+ }
+ if (device.type() == "UNKNOWN") {
string device_type;
int device_id;
DeviceNameUtils::ParsedName parsed;
if (!node->device().empty() &&
- DeviceNameUtils::ParseFullName(node->device(), &parsed)) {
+ DeviceNameUtils::ParseFullName(DeviceName(node), &parsed)) {
device_type = parsed.type;
device_id = parsed.id;
} else {
@@ -312,6 +307,11 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
}
}
+ // Special case for _Send op.
+ if (IsSendOp(node)) {
+ device.set_type(kChannelDevice);
+ }
+
NodeInfo node_info;
node_info.name = node->name();
node_info.device_name = graph_properties_.GetDeviceName(node->name());
@@ -347,7 +347,7 @@ Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update graph_costs_ and per-op costs.
graph_costs_ = CombineCosts(graph_costs_, node_costs);
- const auto* node = GetCurrNode();
+ const auto* node = ready_nodes_->GetCurrNode();
const auto& op_name = node->op();
// Also keep track of op counts and times per op (with their shapes).
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 83878eea0a..23cb0bc53c 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -121,9 +121,7 @@ class VirtualScheduler {
const string kAttrSrcDevice = "src_device_";
const string kAttrDstDevice = "dst_device_";
const string kChannelDevice = "Channel";
- const string kDefaultDevice = "/CPU:0";
- const NodeDef* GetCurrNode() const;
void MaybeUpdateInputProperties(
const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const;
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);