author      2017-07-26 15:13:16 -0700
committer   2017-07-26 15:16:42 -0700
commit      f1537588d4983dfdabfdd82442264a9d1a702d2f (patch)
tree        4fa57bb29ed5442ae89933f1ee59bfc91ab3adbb
parent      4eb749185981692e6ed327e241f3f8ca91a5a08f (diff)
Set device for nodes added by LayoutOptimizer.
PiperOrigin-RevId: 163263257
-rw-r--r--   tensorflow/core/grappler/optimizers/layout_optimizer.cc | 55
1 file changed, 32 insertions, 23 deletions
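In outline: before this commit, node-creation helpers such as AddNodeTranspose, AddNodeReshape, AddNodeShapeConst, and the shared-constant builders returned a NodeDef* and relied on each call site to call set_device() afterwards, with several constants pinned unconditionally to the local CPU. The commit moves device assignment into the helpers themselves: per-node rewrites inherit the device of the node being rewritten, and shared constants get a cluster-derived default_device_. A minimal self-contained sketch of that pattern, using toy stand-ins rather than the real tensorflow::NodeDef types:

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for NodeDef and the graph, only so the sketch compiles on
// its own; the real change edits tensorflow::NodeDef via its setters.
struct NodeDef {
  std::string op;
  std::string device;
};

struct Graph {
  std::vector<NodeDef> nodes;

  // The pattern the commit adopts: the helper that creates a node also
  // assigns its device, instead of leaving that to every call site.
  NodeDef* AddTranspose(const std::string& device) {
    nodes.push_back(NodeDef{"Transpose", device});
    return &nodes.back();
  }
};

int main() {
  Graph g;
  // Before the commit the call-site pattern was roughly:
  //   NodeDef* t = AddNodeTranspose(...); t->set_device(node_->device());
  // and some Const nodes were hard-coded to the local CPU device.
  NodeDef* t = g.AddTranspose("/job:worker/replica:0/task:0/gpu:0");
  std::cout << t->op << " on " << t->device << "\n";
  return 0;
}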
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 55da0f710d..5a4e1b76e2 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -238,6 +238,7 @@ class NodeProcessor {
     *node->add_input() = input_name;
     *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
     node->set_op("Transpose");
+    node->set_device(node_->device());
     AttrValue attr_data_type;
     attr_data_type.set_type(data_type);
     node->mutable_attr()->insert({"T", attr_data_type});
@@ -273,11 +274,10 @@ class NodeProcessor {
     int output_pos = NodePosition(node_->input(pos));
     TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
     TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
-    NodeDef* transpose = AddNodeTranspose(
+    AddNodeTranspose(
         node_name, node_->input(pos), node_->attr().at("T").type(),
         input_node->attr().at("_output_shapes").list().shape(output_pos),
         true);
-    transpose->set_device(node_->device());
     node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
     node_map_->AddOutput(node_name, node_->name());
     *node_->mutable_input(pos) = node_name;
@@ -313,10 +313,9 @@ class NodeProcessor {
     }
     TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
     TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
-    NodeDef* transpose = AddNodeTranspose(
-        node_name, node_->name(), node_->attr().at("T").type(),
-        node_->attr().at("_output_shapes").list().shape(0), false);
-    transpose->set_device(node_->device());
+    AddNodeTranspose(node_name, node_->name(), node_->attr().at("T").type(),
+                     node_->attr().at("_output_shapes").list().shape(0),
+                     false);
     *it = node_name;
     node_map_->UpdateOutput(node_->name(), output->name(), node_name);
     node_map_->AddOutput(node_name, output->name());
@@ -604,6 +603,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
     node_map_->AddNode(name, node);
     node->set_name(name);
     node->set_op("Const");
+    node->set_device(node_->device());
     AttrValue attr_data_type;
     attr_data_type.set_type(DT_INT32);
     node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -628,6 +628,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
     *node->add_input() = input_name;
     *node->add_input() = shape_const_node_name;
     node->set_op("Reshape");
+    node->set_device(node_->device());
     AttrValue attr_type_indices;
     attr_type_indices.set_type(DT_INT32);
 
@@ -650,13 +651,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
     TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
     int vector_size =
         input_node->attr().at("_output_shapes").list().shape(0).dim(0).size();
-    NodeDef* shp = AddNodeShapeConst(shape_const_node_name, vector_size);
-    shp->set_device("/job:localhost/replica:0/task:0/cpu:0");
+    AddNodeShapeConst(shape_const_node_name, vector_size);
     TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
-    NodeDef* reshape =
-        AddNodeReshape(reshape_node_name, node_->input(1),
-                       shape_const_node_name, node_->attr().at("T").type());
-    reshape->set_device(node_->device());
+    AddNodeReshape(reshape_node_name, node_->input(1), shape_const_node_name,
+                   node_->attr().at("T").type());
     node_map_->AddOutput(shape_const_node_name, reshape_node_name);
     node_map_->UpdateOutput(node_->input(1), node_->name(),
                             reshape_node_name);
@@ -953,8 +951,12 @@ struct TuningConfig {
 
 class DataLayoutOptimizer {
  public:
-  explicit DataLayoutOptimizer(GraphDef* graph, TuningConfig config)
-      : graph_(graph), node_map_(graph_), config_(config) {}
+  explicit DataLayoutOptimizer(const string& default_device, GraphDef* graph,
+                               TuningConfig config)
+      : default_device_(default_device),
+        graph_(graph),
+        node_map_(graph_),
+        config_(config) {}
 
   Status Optimize() {
     LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
@@ -972,6 +974,7 @@ class DataLayoutOptimizer {
     node_map_.AddNode(name, node);
     node->set_name(name);
     node->set_op("Const");
+    node->set_device(default_device_);
     AttrValue attr_data_type;
     attr_data_type.set_type(DT_INT32);
     node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -990,6 +993,7 @@ class DataLayoutOptimizer {
     node_map_.AddNode(name, node);
     node->set_name(name);
     node->set_op("Const");
+    node->set_device(default_device_);
     AttrValue attr_data_type;
     attr_data_type.set_type(dtype);
     node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -1014,6 +1018,7 @@ class DataLayoutOptimizer {
     node_map_.AddNode(kReductionConst, node);
     node->set_name(kReductionConst);
     node->set_op("Const");
+    node->set_device(default_device_);
     AttrValue attr_data_type;
     attr_data_type.set_type(DT_INT32);
     node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -1072,15 +1077,10 @@ class DataLayoutOptimizer {
     // expanded.
     if (graph_->node_size() > node_size_original) {
       NodeDef* n = AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2});
-      n->set_device("/job:localhost/replica:0/task:0/cpu:0");
       n = AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1});
-      n->set_device("/job:localhost/replica:0/task:0/cpu:0");
       n = AddNodeConcatConst();
-      n->set_device("/job:localhost/replica:0/task:0/cpu:0");
       n = AddGatherAxisConst();
-      n->set_device("/job:localhost/replica:0/task:0/cpu:0");
       n = AddNodeReductionConst();
-      n->set_device("/job:localhost/replica:0/task:0/cpu:0");
       std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
       for (int i = 0; i < graph_->node_size(); i++) {
         if (ops_format_agnostic.find(graph_->node(i).op()) !=
@@ -1169,6 +1169,7 @@ class DataLayoutOptimizer {
     return Status::OK();
   }
 
+  string default_device_;
   GraphDef* graph_;
   NodeMap node_map_;
   TuningConfig config_;
@@ -1221,16 +1222,24 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   *output = new_item.graph;
   TuningConfig config;
   config.no_gemm = false;
-  DataLayoutOptimizer layout_optimizer(output, config);
-  status = layout_optimizer.Optimize();
+  string default_device = "/job:localhost/replica:0/task:0/cpu:0";
+  if (cluster) {
+    if (!cluster->GetDevices().empty()) {
+      default_device = cluster->GetDevices().begin()->first;
+    }
+  }
+  std::unique_ptr<DataLayoutOptimizer> layout_optimizer(
+      new DataLayoutOptimizer(default_device, output, config));
+  status = layout_optimizer->Optimize();
   // This is based on an empirical observation that if the introduced Transpose
   // nodes is more than 30, not using GEMM implementation would result in better
   // performance.
   if (status.ok() && GetNumTranspose(*output) > 30) {
     *output = new_item.graph;
     config.no_gemm = true;
-    DataLayoutOptimizer layout_optimizer(output, config);
-    status = layout_optimizer.Optimize();
+    layout_optimizer.reset(
+        new DataLayoutOptimizer(default_device, output, config));
+    status = layout_optimizer->Optimize();
   }
 
   if (!status.ok()) {
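The new default_device fallback at the bottom of the diff prefers the first device the cluster reports and falls back to the local CPU when no cluster is available. A standalone sketch of the same selection logic, with std::map (and an int value) standing in for the DeviceProperties map that Cluster::GetDevices() returns:

#include <iostream>
#include <map>
#include <string>

int main() {
  // Stand-in for Cluster::GetDevices(): device name -> device properties
  // (an int here; the real map's value type is a properties proto).
  std::map<std::string, int> devices = {
      {"/job:localhost/replica:0/task:0/gpu:0", 0}};

  // Same fallback as the diff: default to the local CPU, but prefer the
  // first device registered in the cluster when one is available.
  std::string default_device = "/job:localhost/replica:0/task:0/cpu:0";
  if (!devices.empty()) {
    default_device = devices.begin()->first;
  }
  std::cout << "shared constants placed on: " << default_device << "\n";
  return 0;
}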