author    Yao Zhang <yaozhang@google.com>           2017-07-26 15:13:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-07-26 15:16:42 -0700
commit f1537588d4983dfdabfdd82442264a9d1a702d2f
tree   4fa57bb29ed5442ae89933f1ee59bfc91ab3adbb
parent 4eb749185981692e6ed327e241f3f8ca91a5a08f
Set device for nodes added by LayoutOptimizer.
PiperOrigin-RevId: 163263257
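
Summary of the change: previously, helpers such as AddNodeTranspose, AddNodeReshape,
and the *Const helpers returned the new NodeDef and relied on each call site to set
its device afterwards, with several sites hard-coding
"/job:localhost/replica:0/task:0/cpu:0". This commit moves device assignment into
the helpers themselves (using the processed node's device, or a default_device_
picked from the cluster), so newly added nodes can no longer be left unplaced.
A minimal sketch of the pattern follows; MakeTranspose and the trimmed NodeDef are
hypothetical stand-ins, not the real TensorFlow types.

#include <string>

// Trimmed, hypothetical stand-in for tensorflow::NodeDef; the real class is
// a generated protobuf with set_op()/set_device() setters.
struct NodeDef {
  std::string op;
  std::string device;
};

// Before this commit: helpers returned the new node and every call site had
// to remember a separate set_device() call. After: the helper stamps the
// device itself, taken from the node the Transpose is inserted for, so a new
// call site cannot forget the placement.
NodeDef* MakeTranspose(const std::string& consumer_device, NodeDef* node) {
  node->op = "Transpose";
  node->device = consumer_device;  // assigned here, not at the call site
  return node;
}
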
-rw-r--r-- tensorflow/core/grappler/optimizers/layout_optimizer.cc | 55
1 file changed, 32 insertions(+), 23 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 55da0f710d..5a4e1b76e2 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -238,6 +238,7 @@ class NodeProcessor {
*node->add_input() = input_name;
*node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
node->set_op("Transpose");
+ node->set_device(node_->device());
AttrValue attr_data_type;
attr_data_type.set_type(data_type);
node->mutable_attr()->insert({"T", attr_data_type});
@@ -273,11 +274,10 @@ class NodeProcessor {
int output_pos = NodePosition(node_->input(pos));
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
- NodeDef* transpose = AddNodeTranspose(
+ AddNodeTranspose(
node_name, node_->input(pos), node_->attr().at("T").type(),
input_node->attr().at("_output_shapes").list().shape(output_pos),
true);
- transpose->set_device(node_->device());
node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
node_map_->AddOutput(node_name, node_->name());
*node_->mutable_input(pos) = node_name;
@@ -313,10 +313,9 @@ class NodeProcessor {
}
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
- NodeDef* transpose = AddNodeTranspose(
- node_name, node_->name(), node_->attr().at("T").type(),
- node_->attr().at("_output_shapes").list().shape(0), false);
- transpose->set_device(node_->device());
+ AddNodeTranspose(node_name, node_->name(), node_->attr().at("T").type(),
+ node_->attr().at("_output_shapes").list().shape(0),
+ false);
*it = node_name;
node_map_->UpdateOutput(node_->name(), output->name(), node_name);
node_map_->AddOutput(node_name, output->name());
@@ -604,6 +603,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
node_map_->AddNode(name, node);
node->set_name(name);
node->set_op("Const");
+ node->set_device(node_->device());
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -628,6 +628,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
*node->add_input() = input_name;
*node->add_input() = shape_const_node_name;
node->set_op("Reshape");
+ node->set_device(node_->device());
AttrValue attr_type_indices;
attr_type_indices.set_type(DT_INT32);
@@ -650,13 +651,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
int vector_size =
input_node->attr().at("_output_shapes").list().shape(0).dim(0).size();
- NodeDef* shp = AddNodeShapeConst(shape_const_node_name, vector_size);
- shp->set_device("/job:localhost/replica:0/task:0/cpu:0");
+ AddNodeShapeConst(shape_const_node_name, vector_size);
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
- NodeDef* reshape =
- AddNodeReshape(reshape_node_name, node_->input(1),
- shape_const_node_name, node_->attr().at("T").type());
- reshape->set_device(node_->device());
+ AddNodeReshape(reshape_node_name, node_->input(1), shape_const_node_name,
+ node_->attr().at("T").type());
node_map_->AddOutput(shape_const_node_name, reshape_node_name);
node_map_->UpdateOutput(node_->input(1), node_->name(),
reshape_node_name);
@@ -953,8 +951,12 @@ struct TuningConfig {
class DataLayoutOptimizer {
public:
- explicit DataLayoutOptimizer(GraphDef* graph, TuningConfig config)
- : graph_(graph), node_map_(graph_), config_(config) {}
+ explicit DataLayoutOptimizer(const string& default_device, GraphDef* graph,
+ TuningConfig config)
+ : default_device_(default_device),
+ graph_(graph),
+ node_map_(graph_),
+ config_(config) {}
Status Optimize() {
LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
@@ -972,6 +974,7 @@ class DataLayoutOptimizer {
node_map_.AddNode(name, node);
node->set_name(name);
node->set_op("Const");
+ node->set_device(default_device_);
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -990,6 +993,7 @@ class DataLayoutOptimizer {
node_map_.AddNode(name, node);
node->set_name(name);
node->set_op("Const");
+ node->set_device(default_device_);
AttrValue attr_data_type;
attr_data_type.set_type(dtype);
node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -1014,6 +1018,7 @@ class DataLayoutOptimizer {
node_map_.AddNode(kReductionConst, node);
node->set_name(kReductionConst);
node->set_op("Const");
+ node->set_device(default_device_);
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -1072,15 +1077,10 @@ class DataLayoutOptimizer {
// expanded.
if (graph_->node_size() > node_size_original) {
NodeDef* n = AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2});
- n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1});
- n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddNodeConcatConst();
- n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddGatherAxisConst();
- n->set_device("/job:localhost/replica:0/task:0/cpu:0");
n = AddNodeReductionConst();
- n->set_device("/job:localhost/replica:0/task:0/cpu:0");
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
for (int i = 0; i < graph_->node_size(); i++) {
if (ops_format_agnostic.find(graph_->node(i).op()) !=
@@ -1169,6 +1169,7 @@ class DataLayoutOptimizer {
return Status::OK();
}
+ string default_device_;
GraphDef* graph_;
NodeMap node_map_;
TuningConfig config_;
@@ -1221,16 +1222,24 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*output = new_item.graph;
TuningConfig config;
config.no_gemm = false;
- DataLayoutOptimizer layout_optimizer(output, config);
- status = layout_optimizer.Optimize();
+ string default_device = "/job:localhost/replica:0/task:0/cpu:0";
+ if (cluster) {
+ if (!cluster->GetDevices().empty()) {
+ default_device = cluster->GetDevices().begin()->first;
+ }
+ }
+ std::unique_ptr<DataLayoutOptimizer> layout_optimizer(
+ new DataLayoutOptimizer(default_device, output, config));
+ status = layout_optimizer->Optimize();
// This is based on the empirical observation that if more than 30 Transpose
// nodes are introduced, not using the GEMM implementation results in better
// performance.
if (status.ok() && GetNumTranspose(*output) > 30) {
*output = new_item.graph;
config.no_gemm = true;
- DataLayoutOptimizer layout_optimizer(output, config);
- status = layout_optimizer.Optimize();
+ layout_optimizer.reset(
+ new DataLayoutOptimizer(default_device, output, config));
+ status = layout_optimizer->Optimize();
}
if (!status.ok()) {
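
For context, the default device used for the graph-wide constants (kPermNHWCToNCHW,
kPermNCHWToNHWC, the concat/gather/reduction constants) is now taken from the
cluster when one is available, falling back to the local CPU otherwise. A minimal
sketch of that selection, assuming a simplified name-keyed device map (the real
Cluster::GetDevices() in grappler returns a richer structure):

#include <map>
#include <string>

// Returns the device on which the layout optimizer places its helper
// constants: the first device known to the cluster, else the local CPU.
// Mirrors the `if (cluster)` / `!GetDevices().empty()` checks in the diff.
std::string DefaultDevice(const std::map<std::string, std::string>* devices) {
  std::string default_device = "/job:localhost/replica:0/task:0/cpu:0";
  if (devices != nullptr && !devices->empty()) {
    default_device = devices->begin()->first;  // device name is the map key
  }
  return default_device;
}
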