-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD                |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc  |   4
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc  | 490
-rw-r--r--  tensorflow/core/grappler/utils.cc                        |   8
-rw-r--r--  tensorflow/core/grappler/utils.h                         |   8
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py      |  52
6 files changed, 401 insertions, 162 deletions
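The heart of the change below is frame awareness: when a node sits inside a control-flow frame (a loop body produced by tf.while_loop or map_fn), the permutation, axis, and reduction constants the layout optimizer feeds it can no longer be shared graph-wide singletons; each constant is instead created per consumer and pinned into the consumer's frame via a control-dependency input. A minimal sketch of that pinning at the GraphDef level, in Python (the helper name and node names are illustrative, not part of this commit):

    # Sketch: pin a Const into the same frame as `anchor` by giving it a
    # control input ("^anchor"), mirroring what AsControlDependency() builds.
    from tensorflow.core.framework import graph_pb2, types_pb2
    from tensorflow.python.framework import tensor_util

    def add_perm_const_in_frame(graph_def, name, anchor):
      # Illustrative helper, not the commit's API: an NHWC->NCHW
      # permutation Const with a control edge on `anchor`.
      node = graph_def.node.add()
      node.name = name
      node.op = "Const"
      node.attr["dtype"].type = types_pb2.DT_INT32
      node.attr["value"].tensor.CopyFrom(
          tensor_util.make_tensor_proto([0, 3, 1, 2], dtype=types_pb2.DT_INT32))
      node.input.append("^" + anchor)  # control dependency, no data edge
      return node

    graph = graph_pb2.GraphDef()
    add_perm_const_in_frame(graph, "LayoutOptimizer-Perm-NHWCToNCHW", "conv1")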
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index a7515786a0..659451e991 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -275,6 +275,7 @@ cc_library(
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
         "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/core/grappler/utils:frame",
     ],
 )
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ada166703e..35b0b7c163 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -94,10 +94,6 @@ class DeviceSimple : public DeviceBase {
   std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
 };
 
-string AsControlDependency(const NodeDef& node) {
-  return strings::StrCat("^", node.name());
-}
-
 }  // namespace
 
 ConstantFolding::ConstantFolding() {
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index f469f9a9ac..a4b0a60e1f 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/frame.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 
@@ -95,10 +96,84 @@ bool IsNodeNCHWToNHWC(const string& node_name) {
   return false;
 }
 
-class NodeProcessor {
+class GraphProcessor {
  public:
-  NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : graph_(graph), node_(node), node_map_(node_map) {}
+  GraphProcessor(GraphDef* graph, NodeMap* node_map)
+      : graph_(graph), node_map_(node_map) {}
+
+ protected:
+  NodeDef* AddNodePermConst(const string& name, const string& device,
+                            const std::vector<int>& permutation) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(name, node);
+    node->set_name(name);
+    node->set_op("Const");
+    node->set_device(device);
+    AttrValue attr_data_type;
+    attr_data_type.set_type(DT_INT32);
+    node->mutable_attr()->insert({"dtype", attr_data_type});
+    AttrValue attr_tensor;
+    Tensor tensor(DT_INT32, TensorShape({4}));
+    for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
+      tensor.flat<int>()(i) = permutation[i];
+    }
+    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+    node->mutable_attr()->insert({"value", attr_tensor});
+    return node;
+  }
+
+  NodeDef* AddNodeConstScalar(const string& name, const string& device,
+                              DataType dtype, int value) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(name, node);
+    node->set_name(name);
+    node->set_op("Const");
+    node->set_device(device);
+    AttrValue attr_data_type;
+    attr_data_type.set_type(dtype);
+    node->mutable_attr()->insert({"dtype", attr_data_type});
+    AttrValue attr_tensor;
+    Tensor tensor(dtype, TensorShape({}));
+    tensor.scalar<int>()() = value;
+    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+    node->mutable_attr()->insert({"value", attr_tensor});
+    return node;
+  }
+
+  NodeDef* AddNodeReductionConst(const string& name, const string& device) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(name, node);
+    node->set_name(name);
+    node->set_op("Const");
+    node->set_device(device);
+    AttrValue attr_data_type;
+    attr_data_type.set_type(DT_INT32);
+    node->mutable_attr()->insert({"dtype", attr_data_type});
+
+    AttrValue attr_tensor;
+    Tensor tensor(DT_INT32, TensorShape({3}));
+    std::vector<int> axis = {0, 2, 3};
+    for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
+      tensor.flat<int>()(i) = axis[i];
+    }
+    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+    node->mutable_attr()->insert({"value", attr_tensor});
+    return node;
+  }
+
+  GraphDef* graph_;
+  NodeMap* node_map_;
+
+ private:
+};
+
+class NodeProcessor : public GraphProcessor {
+ public:
+  NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                bool is_in_frame)
+      : GraphProcessor(graph, node_map),
+        node_(node),
+        is_in_frame_(is_in_frame) {}
   virtual ~NodeProcessor() {}
   virtual Status ConvertNode() {
     if (ShouldProcess()) {
@@ -229,14 +304,14 @@ class NodeProcessor {
   }
 
   NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
-                            DataType data_type,
+                            const string& const_name, DataType data_type,
                             const TensorShapeProto& input_shape,
                             bool NHWCToNCHW) {
     NodeDef* node = graph_->add_node();
     node_map_->AddNode(node_name, node);
     node->set_name(node_name);
     *node->add_input() = input_name;
-    *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
+    *node->add_input() = const_name;
     node->set_op("Transpose");
     node->set_device(node_->device());
     AttrValue attr_data_type;
@@ -276,8 +351,10 @@ class NodeProcessor {
     auto input_node = node_map_->GetNode(node_->input(pos));
     TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
     TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
+    string const_name = GetOrAddNodePermNHWCToNCHW(pos);
     AddNodeTranspose(
-        node_name, node_->input(pos), node_->attr().at("T").type(),
+        node_name, node_->input(pos), const_name,
+        node_->attr().at("T").type(),
         input_node->attr().at("_output_shapes").list().shape(output_pos),
         true);
     node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
@@ -289,6 +366,7 @@ class NodeProcessor {
 
   virtual Status AddLayoutTransposeToOutputs() {
     auto outputs = node_map_->GetOutputs(node_->name());
+    string const_name = GetOrAddNodePermNCHWToNHWC();
     for (const auto& output : outputs) {
       string base_name = strings::StrCat(node_->name(), "-", output->name());
       string node_name =
@@ -315,9 +393,9 @@ class NodeProcessor {
       }
       TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
       TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
-      AddNodeTranspose(node_name, node_->name(), node_->attr().at("T").type(),
-                       node_->attr().at("_output_shapes").list().shape(0),
-                       false);
+      AddNodeTranspose(
+          node_name, node_->name(), const_name, node_->attr().at("T").type(),
+          node_->attr().at("_output_shapes").list().shape(0), false);
       *it = node_name;
       node_map_->UpdateOutput(node_->name(), output->name(), node_name);
       node_map_->AddOutput(node_name, output->name());
@@ -327,11 +405,56 @@ class NodeProcessor {
 
   virtual Status CustomizedProcessing() { return Status::OK(); }
 
-  GraphDef* graph_;
+  NodeDef* AddNodePermNHWCToNCHW(const string& suffix,
+                                 const string& depended_node,
+                                 const string& device) {
+    auto const_node = AddNodePermConst(
+        strings::StrCat(kPermNHWCToNCHW, "-", suffix), device, {0, 3, 1, 2});
+    // This is to ensure the transpose node and the const node are in the
+    // same frame.
+    *const_node->add_input() = AsControlDependency(depended_node);
+    return const_node;
+  }
+
+  NodeDef* AddNodePermNCHWToNHWC(const string& suffix,
+                                 const string& depended_node,
+                                 const string& device) {
+    auto const_node = AddNodePermConst(
+        strings::StrCat(kPermNCHWToNHWC, "-", suffix), device, {0, 2, 3, 1});
+    // This is to ensure the transpose node and the const node are in the same
+    // frame.
+    *const_node->add_input() = AsControlDependency(depended_node);
+    return const_node;
+  }
+
   NodeDef* node_;
-  NodeMap* node_map_;
+  bool is_in_frame_;
 
  private:
+  string GetOrAddNodePermNHWCToNCHW(int pos) {
+    string const_name;
+    if (is_in_frame_) {
+      auto const_node = AddNodePermNHWCToNCHW(
+          node_->input(pos), NodeName(node_->input(pos)), node_->device());
+      const_name = const_node->name();
+    } else {
+      const_name = kPermNHWCToNCHW;
+    }
+    return const_name;
+  }
+
+  string GetOrAddNodePermNCHWToNHWC() {
+    string const_name;
+    if (is_in_frame_) {
+      auto const_node =
+          AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
+      const_name = const_node->name();
+    } else {
+      const_name = kPermNCHWToNHWC;
+    }
+    return const_name;
+  }
+
   void UpdateTuple(AttrValue_ListValue* list) {
     int64 h = list->i(1);
     int64 w = list->i(2);
@@ -344,8 +467,9 @@ class NodeProcessor {
 
 class AvgPoolGradProcessor : public NodeProcessor {
  public:
-  AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : NodeProcessor(graph, node, node_map) {}
+  AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                       bool is_in_frame)
+      : NodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   std::vector<int> GetInputPos() const override {
@@ -357,8 +481,9 @@ class AvgPoolGradProcessor : public NodeProcessor {
 
 class BiasAddGradProcessor : public NodeProcessor {
  public:
-  BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : NodeProcessor(graph, node, node_map) {}
+  BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                       bool is_in_frame)
+      : NodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   bool ShouldProcess() const override {
@@ -377,8 +502,8 @@ class BiasAddGradProcessor : public NodeProcessor {
 class Conv2DProcessor : public NodeProcessor {
  public:
   Conv2DProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
-                  bool no_gemm)
-      : NodeProcessor(graph, node, node_map), no_gemm_(no_gemm) {}
+                  bool no_gemm, bool is_in_frame)
+      : NodeProcessor(graph, node, node_map, is_in_frame), no_gemm_(no_gemm) {}
 
  protected:
   bool ShouldProcess() const override {
@@ -447,8 +572,9 @@ class Conv2DProcessor : public NodeProcessor {
 class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
  public:
   Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node,
-                                NodeMap* node_map, bool no_gemm)
-      : Conv2DProcessor(graph, node, node_map, no_gemm) {}
+                                NodeMap* node_map, bool no_gemm,
+                                bool is_in_frame)
+      : Conv2DProcessor(graph, node, node_map, no_gemm, is_in_frame) {}
 
  protected:
   bool IsGemmUsed() const override {
@@ -472,8 +598,9 @@ class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
 class Conv2DBackpropInputProcessor : public Conv2DProcessor {
  public:
   Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node,
-                               NodeMap* node_map, bool no_gemm)
-      : Conv2DProcessor(graph, node, node_map, no_gemm) {}
+                               NodeMap* node_map, bool no_gemm,
+                               bool is_in_frame)
+      : Conv2DProcessor(graph, node, node_map, no_gemm, is_in_frame) {}
 
  protected:
   bool IsGemmUsed() const override {
@@ -492,8 +619,9 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
 
 class FusedBatchNormGradProcessor : public NodeProcessor {
  public:
-  FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : NodeProcessor(graph, node, node_map) {}
+  FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                              bool is_in_frame)
+      : NodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   std::vector<int> GetInputPos() const override {
@@ -504,8 +632,9 @@ class FusedBatchNormGradProcessor : public NodeProcessor {
 
 class MaxPoolGradProcessor : public NodeProcessor {
  public:
-  MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : NodeProcessor(graph, node, node_map) {}
+  MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                       bool is_in_frame)
+      : NodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   std::vector<int> GetInputPos() const override {
@@ -516,8 +645,9 @@ class MaxPoolGradProcessor : public NodeProcessor {
 
 class AgnosticNodeProcessor : public NodeProcessor {
  public:
-  AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : NodeProcessor(graph, node, node_map) {}
+  AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                        bool is_in_frame)
+      : NodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   bool ShouldProcess() const override {
@@ -548,8 +678,9 @@ class AgnosticNodeProcessor : public NodeProcessor {
 
 class AddNProcessor : public AgnosticNodeProcessor {
  public:
-  AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {}
+  AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   std::vector<int> GetInputPos() const override {
@@ -564,8 +695,9 @@ class AddNProcessor : public AgnosticNodeProcessor {
 
 class BinaryOpProcessor : public AgnosticNodeProcessor {
  public:
-  BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {
+  BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                    bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {
     is_4d_with_vector_ = Is4DOperateWithVector();
   }
 
@@ -672,8 +804,9 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
 
 class ConcatProcessor : public AgnosticNodeProcessor {
  public:
-  ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {
+  ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                  bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {
     // For Concat, the concat axis is the first input; for ConcatV2,
     // the last input.
     axis_node_pos_ =
@@ -698,8 +831,9 @@ class ConcatProcessor : public AgnosticNodeProcessor {
   }
 
   Status CustomizedProcessing() override {
-    node_map_->AddOutput(kConcatConst, node_->name());
-    *node_->mutable_input(axis_node_pos_) = kConcatConst;
+    string concat_const_name = GetOrAddNodeConcatConst();
+    node_map_->AddOutput(concat_const_name, node_->name());
+    *node_->mutable_input(axis_node_pos_) = concat_const_name;
     return Status::OK();
   }
 
@@ -712,12 +846,38 @@ class ConcatProcessor : public AgnosticNodeProcessor {
   }
 
   int axis_node_pos_;
+
+ private:
+  NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node,
+                              const string& device) {
+    auto const_node = AddNodeConstScalar(
+        strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1);
+    // This is to ensure the concat node and the const node are
+    // in the same frame.
+    *const_node->add_input() = AsControlDependency(depended_node);
+    return const_node;
+  }
+
+  string GetOrAddNodeConcatConst() {
+    string const_name;
+    if (is_in_frame_) {
+      int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0;
+      auto const_node = AddNodeConcatConst(
+          node_->name(), NodeName(node_->input(value_node_pos)),
+          node_->device());
+      const_name = const_node->name();
+    } else {
+      const_name = kConcatConst;
+    }
+    return const_name;
+  }
 };
 
 class ReluGradProcessor : public AgnosticNodeProcessor {
  public:
-  ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {}
+  ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                    bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   std::vector<int> GetInputPos() const override {
@@ -728,8 +888,9 @@ class ReluGradProcessor : public AgnosticNodeProcessor {
 
 class SliceProcessor : public AgnosticNodeProcessor {
  public:
-  SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {}
+  SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                 bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   Status CustomizedProcessing() override {
@@ -749,14 +910,62 @@ class SliceProcessor : public AgnosticNodeProcessor {
   }
 
  private:
+  NodeDef* AddNodeGatherAxisConst(const string& suffix,
+                                  const string& depended_node,
+                                  const string& device) {
+    auto const_node = AddNodeConstScalar(
+        strings::StrCat(kGatherAxisConst, "-", suffix), device, DT_INT32, 0);
+    // This is to ensure the Slice node and the const node are
+    // in the same frame.
+    *const_node->add_input() = AsControlDependency(depended_node);
+    return const_node;
+  }
+
+  string GetOrAddNodeGatherAxisConst() {
+    string const_name;
+    if (is_in_frame_) {
+      auto const_node = AddNodeGatherAxisConst(
+          node_->name(), NodeName(node_->input(0)), node_->device());
+      const_name = const_node->name();
+    } else {
+      const_name = kGatherAxisConst;
+    }
+    return const_name;
+  }
+
+  string GetOrAddNodePermNHWCToNCHW() {
+    string const_name;
+    if (is_in_frame_) {
+      auto const_node = AddNodePermNHWCToNCHW(
+          node_->name(), NodeName(node_->input(0)), node_->device());
+      const_name = const_node->name();
+    } else {
+      const_name = kPermNHWCToNCHW;
+    }
+    return const_name;
+  }
+
+  string GetOrAddNodePermNCHWToNHWC() {
+    string const_name;
+    if (is_in_frame_) {
+      auto const_node = AddNodePermNCHWToNHWC(
+          node_->name(), NodeName(node_->input(0)), node_->device());
+      const_name = const_node->name();
+    } else {
+      const_name = kPermNCHWToNHWC;
+    }
+    return const_name;
+  }
+
   void AddNodePermVec(const string& node_name, const string& input_name,
                       DataType data_type, bool NHWCToNCHW) {
     NodeDef* node = graph_->add_node();
     node_map_->AddNode(node_name, node);
     node->set_name(node_name);
     *node->add_input() = input_name;
-    *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
-    *node->add_input() = kGatherAxisConst;
+    *node->add_input() = NHWCToNCHW ? GetOrAddNodePermNHWCToNCHW()
+                                    : GetOrAddNodePermNCHWToNHWC();
+    *node->add_input() = GetOrAddNodeGatherAxisConst();
     node->set_op("GatherV2");
 
     AttrValue attr_type_indices;
@@ -782,8 +991,9 @@ class SliceProcessor : public AgnosticNodeProcessor {
 // before this optimization.
 class SliceProcessorConst : public AgnosticNodeProcessor {
  public:
-  SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {}
+  SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                      bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   Status CustomizedProcessing() override {
@@ -799,8 +1009,9 @@ class SliceProcessorConst : public AgnosticNodeProcessor {
 // example use case is in the gradient computation of Concat for InceptionV3.
 class SliceProcessorConcatOffset : public AgnosticNodeProcessor {
  public:
-  SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {}
+  SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                             bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   Status CustomizedProcessing() override {
@@ -849,8 +1060,9 @@ class SliceProcessorConcatOffset : public AgnosticNodeProcessor {
 
 class SqueezeProcessor : public AgnosticNodeProcessor {
  public:
-  SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {}
+  SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                   bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   bool ShouldProcess() const override {
@@ -898,8 +1110,9 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
 
 class SumProcessor : public AgnosticNodeProcessor {
  public:
-  SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
-      : AgnosticNodeProcessor(graph, node, node_map) {}
+  SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+               bool is_in_frame)
+      : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
 
  protected:
   bool ShouldProcess() const override {
@@ -913,7 +1126,7 @@ class SumProcessor : public AgnosticNodeProcessor {
 
   Status CustomizedProcessing() override {
     node_map_->AddOutput(kReductionConst, node_->name());
-    *node_->mutable_input(1) = kReductionConst;
+    *node_->mutable_input(1) = GetOrAddNodeReductionConst();
     return Status::OK();
   }
 
@@ -938,6 +1151,29 @@ class SumProcessor : public AgnosticNodeProcessor {
     }
     return false;
   }
+
+  NodeDef* AddNodeReductionConst(const string& suffix,
+                                 const string& depended_node,
+                                 const string& device) {
+    auto const_node = GraphProcessor::AddNodeReductionConst(
+        strings::StrCat(kReductionConst, "-", suffix), device);
+    // This is to ensure the Sum node and the const node are in the
+    // same frame.
+    *const_node->add_input() = AsControlDependency(depended_node);
+    return const_node;
+  }
+
+  string GetOrAddNodeReductionConst() {
+    string const_name;
+    if (is_in_frame_) {
+      auto const_node = AddNodeReductionConst(
+          node_->name(), NodeName(node_->input(0)), node_->device());
+      const_name = const_node->name();
+    } else {
+      const_name = kReductionConst;
+    }
+    return const_name;
+  }
 };
 
 struct TuningConfig {
@@ -951,13 +1187,12 @@ struct TuningConfig {
   bool no_gemm;
 };
 
-class DataLayoutOptimizer {
+class DataLayoutOptimizer : GraphProcessor {
  public:
   explicit DataLayoutOptimizer(const string& default_device, GraphDef* graph,
-                               TuningConfig config)
-      : default_device_(default_device),
-        graph_(graph),
-        node_map_(graph_),
+                               NodeMap* node_map, TuningConfig config)
+      : GraphProcessor(graph, node_map),
+        default_device_(default_device),
         config_(config) {}
 
   Status Optimize() {
@@ -970,105 +1205,65 @@ class DataLayoutOptimizer {
   }
 
  private:
-  NodeDef* AddNodePermConst(const string& name,
-                            const std::vector<int>& permutation) {
-    NodeDef* node = graph_->add_node();
-    node_map_.AddNode(name, node);
-    node->set_name(name);
-    node->set_op("Const");
-    node->set_device(default_device_);
-    AttrValue attr_data_type;
-    attr_data_type.set_type(DT_INT32);
-    node->mutable_attr()->insert({"dtype", attr_data_type});
-    AttrValue attr_tensor;
-    Tensor tensor(DT_INT32, TensorShape({4}));
-    for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
-      tensor.flat<int>()(i) = permutation[i];
-    }
-    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
-    node->mutable_attr()->insert({"value", attr_tensor});
-    return node;
+  NodeDef* AddNodePermNHWCToNCHW() {
+    return AddNodePermConst(kPermNHWCToNCHW, default_device_, {0, 3, 1, 2});
   }
 
-  NodeDef* AddConstScalar(const char* name, DataType dtype, int value) {
-    NodeDef* node = graph_->add_node();
-    node_map_.AddNode(name, node);
-    node->set_name(name);
-    node->set_op("Const");
-    node->set_device(default_device_);
-    AttrValue attr_data_type;
-    attr_data_type.set_type(dtype);
-    node->mutable_attr()->insert({"dtype", attr_data_type});
-    AttrValue attr_tensor;
-    Tensor tensor(dtype, TensorShape({}));
-    tensor.scalar<int>()() = value;
-    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
-    node->mutable_attr()->insert({"value", attr_tensor});
-    return node;
+  NodeDef* AddNodePermNCHWToNHWC() {
+    return AddNodePermConst(kPermNCHWToNHWC, default_device_, {0, 2, 3, 1});
   }
 
   NodeDef* AddNodeConcatConst() {
-    return AddConstScalar(kConcatConst, DT_INT32, 1);
+    return AddNodeConstScalar(kConcatConst, default_device_, DT_INT32, 1);
   }
 
-  NodeDef* AddGatherAxisConst() {
-    return AddConstScalar(kGatherAxisConst, DT_INT32, 0);
+  NodeDef* AddNodeGatherAxisConst() {
+    return AddNodeConstScalar(kGatherAxisConst, default_device_, DT_INT32, 0);
  }
 
   NodeDef* AddNodeReductionConst() {
-    NodeDef* node = graph_->add_node();
-    node_map_.AddNode(kReductionConst, node);
-    node->set_name(kReductionConst);
-    node->set_op("Const");
-    node->set_device(default_device_);
-    AttrValue attr_data_type;
-    attr_data_type.set_type(DT_INT32);
-    node->mutable_attr()->insert({"dtype", attr_data_type});
-
-    AttrValue attr_tensor;
-    Tensor tensor(DT_INT32, TensorShape({3}));
-    std::vector<int> axis = {0, 2, 3};
-    for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
-      tensor.flat<int>()(i) = axis[i];
-    }
-    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
-    node->mutable_attr()->insert({"value", attr_tensor});
-    return node;
+    return GraphProcessor::AddNodeReductionConst(kReductionConst,
+                                                 default_device_);
   }
 
   // Expand all nodes which is in NHWC, but supports NCHW or is layout
   // agnostic.
   Status Expand() {
     int node_size_original = graph_->node_size();
+    std::unordered_map<const NodeDef*, std::vector<int>> frames;
+    IdentifyFrames(*graph_, &frames);
+
     // This is the first pass where we expand the nodes which support NCHW.
     std::set<string> ops_format_supported = GetOpsFormatSupported();
-    for (int i = 0; i < graph_->node_size(); i++) {
+    for (int i = 0; i < node_size_original; i++) {
       if (ops_format_supported.find(graph_->node(i).op()) !=
           ops_format_supported.end()) {
         auto node = graph_->mutable_node(i);
+        bool is_in_frame = !frames[node].empty();
         std::unique_ptr<NodeProcessor> node_processor;
         if (node->op().compare("AvgPoolGrad") == 0) {
           node_processor.reset(
-              new AvgPoolGradProcessor(graph_, node, &node_map_));
+              new AvgPoolGradProcessor(graph_, node, node_map_, is_in_frame));
         } else if (node->op().compare("BiasAddGrad") == 0) {
           node_processor.reset(
-              new BiasAddGradProcessor(graph_, node, &node_map_));
+              new BiasAddGradProcessor(graph_, node, node_map_, is_in_frame));
         } else if (node->op().compare("Conv2D") == 0) {
-          node_processor.reset(
-              new Conv2DProcessor(graph_, node, &node_map_, config_.no_gemm));
+          node_processor.reset(new Conv2DProcessor(
+              graph_, node, node_map_, config_.no_gemm, is_in_frame));
         } else if (node->op().compare("Conv2DBackpropFilter") == 0) {
           node_processor.reset(new Conv2DBackpropFilterProcessor(
-              graph_, node, &node_map_, config_.no_gemm));
+              graph_, node, node_map_, config_.no_gemm, is_in_frame));
         } else if (node->op().compare("Conv2DBackpropInput") == 0) {
           node_processor.reset(new Conv2DBackpropInputProcessor(
-              graph_, node, &node_map_, config_.no_gemm));
+              graph_, node, node_map_, config_.no_gemm, is_in_frame));
         } else if (node->op().compare("FusedBatchNormGrad") == 0) {
-          node_processor.reset(
-              new FusedBatchNormGradProcessor(graph_, node, &node_map_));
+          node_processor.reset(new FusedBatchNormGradProcessor(
+              graph_, node, node_map_, is_in_frame));
        } else if (node->op().compare("MaxPoolGrad") == 0) {
           node_processor.reset(
-              new MaxPoolGradProcessor(graph_, node, &node_map_));
+              new MaxPoolGradProcessor(graph_, node, node_map_, is_in_frame));
         } else {
-          node_processor.reset(new NodeProcessor(graph_, node, &node_map_));
+          node_processor.reset(
+              new NodeProcessor(graph_, node, node_map_, is_in_frame));
         }
         TF_RETURN_IF_ERROR(node_processor->ConvertNode());
       }
@@ -1078,54 +1273,57 @@ class DataLayoutOptimizer {
     // only needs to be performed if at least one node in the previous pass is
     // expanded.
     if (graph_->node_size() > node_size_original) {
-      NodeDef* n = AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2});
-      n = AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1});
+      NodeDef* n = AddNodePermNHWCToNCHW();
+      n = AddNodePermNCHWToNHWC();
       n = AddNodeConcatConst();
-      n = AddGatherAxisConst();
+      n = AddNodeGatherAxisConst();
       n = AddNodeReductionConst();
       std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
       for (int i = 0; i < graph_->node_size(); i++) {
         if (ops_format_agnostic.find(graph_->node(i).op()) !=
             ops_format_agnostic.end()) {
           auto node = graph_->mutable_node(i);
+          bool is_in_frame = !frames[node].empty();
           std::unique_ptr<NodeProcessor> node_processor;
           if (node->op().compare("AddN") == 0) {
-            node_processor.reset(new AddNProcessor(graph_, node, &node_map_));
+            node_processor.reset(
+                new AddNProcessor(graph_, node, node_map_, is_in_frame));
           } else if (node->op().compare("Add") == 0 ||
                      node->op().compare("Mul") == 0 ||
                      node->op().compare("RealDiv") == 0 ||
                      node->op().compare("SquaredDifference") == 0 ||
                      node->op().compare("Sub") == 0) {
             node_processor.reset(
-                new BinaryOpProcessor(graph_, node, &node_map_));
+                new BinaryOpProcessor(graph_, node, node_map_, is_in_frame));
           } else if (node->op().compare("Concat") == 0 ||
                      node->op().compare("ConcatV2") == 0) {
-            node_processor.reset(new ConcatProcessor(graph_, node, &node_map_));
+            node_processor.reset(
+                new ConcatProcessor(graph_, node, node_map_, is_in_frame));
          } else if (node->op().compare("ReluGrad") == 0) {
             node_processor.reset(
-                new ReluGradProcessor(graph_, node, &node_map_));
+                new ReluGradProcessor(graph_, node, node_map_, is_in_frame));
           } else if (node->op().compare("Slice") == 0) {
-            auto input1 = node_map_.GetNode(NodeName(node->input(1)));
-            auto input2 = node_map_.GetNode(NodeName(node->input(2)));
+            auto input1 = node_map_->GetNode(NodeName(node->input(1)));
+            auto input2 = node_map_->GetNode(NodeName(node->input(2)));
             if (input1->op() == "ConcatOffset") {
-              node_processor.reset(
-                  new SliceProcessorConcatOffset(graph_, node, &node_map_));
+              node_processor.reset(new SliceProcessorConcatOffset(
+                  graph_, node, node_map_, is_in_frame));
             } else if (input1->op() == "Const" && input2->op() == "Const") {
-              node_processor.reset(
-                  new SliceProcessorConst(graph_, node, &node_map_));
+              node_processor.reset(new SliceProcessorConst(
+                  graph_, node, node_map_, is_in_frame));
             } else {
               node_processor.reset(
-                  new SliceProcessor(graph_, node, &node_map_));
+                  new SliceProcessor(graph_, node, node_map_, is_in_frame));
             }
-
           } else if (node->op().compare("Squeeze") == 0) {
             node_processor.reset(
-                new SqueezeProcessor(graph_, node, &node_map_));
+                new SqueezeProcessor(graph_, node, node_map_, is_in_frame));
           } else if (node->op().compare("Sum") == 0) {
-            node_processor.reset(new SumProcessor(graph_, node, &node_map_));
-          } else {
             node_processor.reset(
-                new AgnosticNodeProcessor(graph_, node, &node_map_));
+                new SumProcessor(graph_, node, node_map_, is_in_frame));
+          } else {
+            node_processor.reset(new AgnosticNodeProcessor(
+                graph_, node, node_map_, is_in_frame));
           }
           TF_RETURN_IF_ERROR(node_processor->ConvertNode());
         }
@@ -1145,12 +1343,12 @@ class DataLayoutOptimizer {
       if (IsNodeNCHWToNHWC(node->input(0))) {
         const string& trans_first = node->input(0);
        const string& trans_second = node->name();
-        auto outputs = node_map_.GetOutputs(trans_second);
+        auto outputs = node_map_->GetOutputs(trans_second);
         CHECK(outputs.size() == 1)
             << "There is always only a single output for a Transpose node, "
            << "due to the way it is added by NodeProcessor.";
NodeProcessor."; NodeDef* output = *outputs.begin(); - string input = node_map_.GetNode(trans_first)->input(0); + string input = node_map_->GetNode(trans_first)->input(0); for (int i = 0; i < output->input_size(); i++) { if (output->input(i).compare(trans_second) == 0) { *output->mutable_input(i) = input; @@ -1173,8 +1371,6 @@ class DataLayoutOptimizer { } string default_device_; - GraphDef* graph_; - NodeMap node_map_; TuningConfig config_; }; @@ -1231,8 +1427,9 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, default_device = cluster->GetDevices().begin()->first; } } + std::unique_ptr<NodeMap> node_map(new NodeMap(output)); std::unique_ptr<DataLayoutOptimizer> layout_optimizer( - new DataLayoutOptimizer(default_device, output, config)); + new DataLayoutOptimizer(default_device, output, node_map.get(), config)); status = layout_optimizer->Optimize(); // This is based on an empirical observation that if the introduced Transpose // nodes is more than 30, not using GEMM implementation would result in better @@ -1240,8 +1437,9 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (status.ok() && GetNumTranspose(*output) > 30) { *output = new_item.graph; config.no_gemm = true; - layout_optimizer.reset( - new DataLayoutOptimizer(default_device, output, config)); + node_map.reset(new NodeMap(output)); + layout_optimizer.reset(new DataLayoutOptimizer(default_device, output, + node_map.get(), config)); status = layout_optimizer->Optimize(); } diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 948df18879..9e15744fab 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -177,5 +177,13 @@ bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms, return notified; } +string AsControlDependency(const NodeDef& node) { + return strings::StrCat("^", node.name()); +} + +string AsControlDependency(const string& node) { + return strings::StrCat("^", node); +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 4a8cb573d8..a9eccd685b 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -85,6 +85,14 @@ string AddPrefixToNodeName(const string& name, const string& prefix); bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms, thread::ThreadPool* thread_pool); +// Returns the node name prefixed with conventional symbol '^' +// for control dependency, given a NodeDef. 
+string AsControlDependency(const NodeDef& node);
+//
+// Returns the node name prefixed with conventional symbol '^'
+// for control dependency, given a node name
+string AsControlDependency(const string& node);
+
 }  // end namespace grappler
 }  // end namespace tensorflow
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 5dbaf76edb..bda9502cd1 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -22,8 +22,10 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -51,9 +53,7 @@ def max_pool_2x2(x):
 
 # Taken from tensorflow/examples/tutorials/mnist/mnist_deep.py
-def two_layer_model():
-  random_seed.set_random_seed(0)
-  x = random_ops.truncated_normal([1, 784], seed=0)
+def two_layer_model(x):
   x_image = array_ops.reshape(x, [-1, 28, 28, 1])
   w_conv1 = weight([5, 5, 1, 32])
   b_conv1 = bias([32])
@@ -66,24 +66,39 @@ def two_layer_model():
   return h_pool2
 
 
+def loop():
+  random_seed.set_random_seed(0)
+  x1 = random_ops.truncated_normal([1, 784], seed=0)
+  x2 = random_ops.truncated_normal([1, 784], seed=0)
+  x3 = random_ops.truncated_normal([1, 784], seed=0)
+  x4 = random_ops.truncated_normal([1, 784], seed=0)
+  elems = (x1, x2, x3, x4)
+  outputs = functional_ops.map_fn(two_layer_model, elems, dtype=dtypes.float32)
+  return outputs
+
+
+def get_config():
+  rewrite_options = rewriter_config_pb2.RewriterConfig(
+      optimize_tensor_layout=True)
+  graph_options = config_pb2.GraphOptions(
+      rewrite_options=rewrite_options, build_cost_model=1)
+  config = config_pb2.ConfigProto(graph_options=graph_options)
+  return config
+
+
 class LayoutOptimizerTest(test.TestCase):
   """Tests the Grappler layout optimizer."""
 
   def testTwoConvLayers(self):
     if test.is_gpu_available(cuda_only=True):
-      output = two_layer_model()
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      output = two_layer_model(x)
 
       with session.Session() as sess:
         output_val_ref = sess.run(output)
 
-      rewrite_options = rewriter_config_pb2.RewriterConfig(
-          optimize_tensor_layout=True)
-      graph_options = config_pb2.GraphOptions(
-          rewrite_options=rewrite_options,
-          build_cost_model=1)
-      config = config_pb2.ConfigProto(graph_options=graph_options)
-
-      with session.Session(config=config) as sess:
+      with session.Session(config=get_config()) as sess:
         metadata = config_pb2.RunMetadata()
         output_val = sess.run(output, run_metadata=metadata)
@@ -105,6 +120,19 @@ class LayoutOptimizerTest(test.TestCase):
 
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testLoop(self):
+    if test.is_gpu_available(cuda_only=True):
+      output = loop()
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
 if __name__ == '__main__':
   test.main()
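For context on why the new testLoop exercises frames: map_fn lowers to a tf.while_loop, so every op that two_layer_model creates runs inside an Enter/Exit frame, and any constant the optimizer wires into those ops must live in the same frame. A standalone sketch of the looping pattern (TF 1.x-style session API; shapes and names here are illustrative, not taken from the test):

    # Sketch: map_fn builds a while-loop frame; the transpose and the
    # permutation Const created inside `body` both live in that frame.
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    def body(x):
      # A layout-style rewrite inside the loop body.
      return tf.transpose(tf.reshape(x, [1, 4, 4, 1]), [0, 3, 1, 2])

    elems = tf.random.normal([8, 16], seed=0)
    outputs = tf.map_fn(body, elems)

    with tf.Session() as sess:
      print(sess.run(outputs).shape)  # (8, 1, 1, 4, 4)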