-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD                  1
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc    4
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc  490
-rw-r--r--  tensorflow/core/grappler/utils.cc                          8
-rw-r--r--  tensorflow/core/grappler/utils.h                           8
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py       52
6 files changed, 401 insertions(+), 162 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index a7515786a0..659451e991 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -275,6 +275,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
],
)
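
The only BUILD change is the new dependency on the frame-analysis utility. As a minimal sketch (hand-written here, not part of the commit) of what that target makes available, assuming IdentifyFrames has the two-argument signature used later in this diff:

    #include <unordered_map>
    #include <vector>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/grappler/utils/frame.h"

    namespace tensorflow {
    namespace grappler {

    // Sketch: decide, per node, whether it lives inside a while-loop frame.
    // Nodes mapped to a non-empty frame stack are inside at least one loop.
    void TagLoopNodes(const GraphDef& graph) {
      std::unordered_map<const NodeDef*, std::vector<int>> frames;
      IdentifyFrames(graph, &frames);
      for (const NodeDef& node : graph.node()) {
        const bool is_in_frame = !frames[&node].empty();
        (void)is_in_frame;  // passed to each NodeProcessor in the real diff
      }
    }

    }  // namespace grappler
    }  // namespace tensorflow
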
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ada166703e..35b0b7c163 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -94,10 +94,6 @@ class DeviceSimple : public DeviceBase {
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};
-string AsControlDependency(const NodeDef& node) {
- return strings::StrCat("^", node.name());
-}
-
} // namespace
ConstantFolding::ConstantFolding() {
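
constant_folding.cc loses its private copy of AsControlDependency; the commit promotes the helper to the shared grappler utils (see the utils.cc/utils.h hunks below) so the layout optimizer can use it too. A minimal usage sketch, assuming only what those hunks declare:

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/grappler/utils.h"

    namespace tensorflow {
    namespace grappler {

    void ControlDependencyDemo() {
      NodeDef anchor;
      anchor.set_name("conv1");
      // Both overloads produce the "^name" form that marks a control input.
      const string dep = AsControlDependency(anchor);    // "^conv1"
      const string dep2 = AsControlDependency("conv1");  // "^conv1"

      // Listing it as an input creates a control edge: the node is ordered
      // after (and placed in the same frame as) "conv1" without reading data.
      NodeDef const_node;
      *const_node.add_input() = dep;
    }

    }  // namespace grappler
    }  // namespace tensorflow
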
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index f469f9a9ac..a4b0a60e1f 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -95,10 +96,84 @@ bool IsNodeNCHWToNHWC(const string& node_name) {
return false;
}
-class NodeProcessor {
+class GraphProcessor {
public:
- NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : graph_(graph), node_(node), node_map_(node_map) {}
+ GraphProcessor(GraphDef* graph, NodeMap* node_map)
+ : graph_(graph), node_map_(node_map) {}
+
+ protected:
+ NodeDef* AddNodePermConst(const string& name, const string& device,
+ const std::vector<int>& permutation) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(name, node);
+ node->set_name(name);
+ node->set_op("Const");
+ node->set_device(device);
+ AttrValue attr_data_type;
+ attr_data_type.set_type(DT_INT32);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+ AttrValue attr_tensor;
+ Tensor tensor(DT_INT32, TensorShape({4}));
+ for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
+ tensor.flat<int>()(i) = permutation[i];
+ }
+ tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+ node->mutable_attr()->insert({"value", attr_tensor});
+ return node;
+ }
+
+ NodeDef* AddNodeConstScalar(const string& name, const string& device,
+ DataType dtype, int value) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(name, node);
+ node->set_name(name);
+ node->set_op("Const");
+ node->set_device(device);
+ AttrValue attr_data_type;
+ attr_data_type.set_type(dtype);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+ AttrValue attr_tensor;
+ Tensor tensor(dtype, TensorShape({}));
+ tensor.scalar<int>()() = value;
+ tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+ node->mutable_attr()->insert({"value", attr_tensor});
+ return node;
+ }
+
+ NodeDef* AddNodeReductionConst(const string& name, const string& device) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(name, node);
+ node->set_name(name);
+ node->set_op("Const");
+ node->set_device(device);
+ AttrValue attr_data_type;
+ attr_data_type.set_type(DT_INT32);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+
+ AttrValue attr_tensor;
+ Tensor tensor(DT_INT32, TensorShape({3}));
+ std::vector<int> axis = {0, 2, 3};
+ for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
+ tensor.flat<int>()(i) = axis[i];
+ }
+ tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+ node->mutable_attr()->insert({"value", attr_tensor});
+ return node;
+ }
+
+ GraphDef* graph_;
+ NodeMap* node_map_;
+
+ private:
+};
+
+class NodeProcessor : public GraphProcessor {
+ public:
+ NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : GraphProcessor(graph, node_map),
+ node_(node),
+ is_in_frame_(is_in_frame) {}
virtual ~NodeProcessor() {}
virtual Status ConvertNode() {
if (ShouldProcess()) {
@@ -229,14 +304,14 @@ class NodeProcessor {
}
NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
- DataType data_type,
+ const string& const_name, DataType data_type,
const TensorShapeProto& input_shape,
bool NHWCToNCHW) {
NodeDef* node = graph_->add_node();
node_map_->AddNode(node_name, node);
node->set_name(node_name);
*node->add_input() = input_name;
- *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
+ *node->add_input() = const_name;
node->set_op("Transpose");
node->set_device(node_->device());
AttrValue attr_data_type;
@@ -276,8 +351,10 @@ class NodeProcessor {
auto input_node = node_map_->GetNode(node_->input(pos));
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
+ string const_name = GetOrAddNodePermNHWCToNCHW(pos);
AddNodeTranspose(
- node_name, node_->input(pos), node_->attr().at("T").type(),
+ node_name, node_->input(pos), const_name,
+ node_->attr().at("T").type(),
input_node->attr().at("_output_shapes").list().shape(output_pos),
true);
node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
@@ -289,6 +366,7 @@ class NodeProcessor {
virtual Status AddLayoutTransposeToOutputs() {
auto outputs = node_map_->GetOutputs(node_->name());
+ string const_name = GetOrAddNodePermNCHWToNHWC();
for (const auto& output : outputs) {
string base_name = strings::StrCat(node_->name(), "-", output->name());
string node_name =
@@ -315,9 +393,9 @@ class NodeProcessor {
}
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
- AddNodeTranspose(node_name, node_->name(), node_->attr().at("T").type(),
- node_->attr().at("_output_shapes").list().shape(0),
- false);
+ AddNodeTranspose(
+ node_name, node_->name(), const_name, node_->attr().at("T").type(),
+ node_->attr().at("_output_shapes").list().shape(0), false);
*it = node_name;
node_map_->UpdateOutput(node_->name(), output->name(), node_name);
node_map_->AddOutput(node_name, output->name());
@@ -327,11 +405,56 @@ class NodeProcessor {
virtual Status CustomizedProcessing() { return Status::OK(); }
- GraphDef* graph_;
+ NodeDef* AddNodePermNHWCToNCHW(const string& suffix,
+ const string& depended_node,
+ const string& device) {
+ auto const_node = AddNodePermConst(
+ strings::StrCat(kPermNHWCToNCHW, "-", suffix), device, {0, 3, 1, 2});
+ // This is to ensure the transpose node and the const node are in the
+ // same frame.
+ *const_node->add_input() = AsControlDependency(depended_node);
+ return const_node;
+ }
+
+ NodeDef* AddNodePermNCHWToNHWC(const string& suffix,
+ const string& depended_node,
+ const string& device) {
+ auto const_node = AddNodePermConst(
+ strings::StrCat(kPermNCHWToNHWC, "-", suffix), device, {0, 2, 3, 1});
+ // This is to ensure the transpose node and the const node are in the same
+ // frame.
+ *const_node->add_input() = AsControlDependency(depended_node);
+ return const_node;
+ }
+
NodeDef* node_;
- NodeMap* node_map_;
+ bool is_in_frame_;
private:
+ string GetOrAddNodePermNHWCToNCHW(int pos) {
+ string const_name;
+ if (is_in_frame_) {
+ auto const_node = AddNodePermNHWCToNCHW(
+ node_->input(pos), NodeName(node_->input(pos)), node_->device());
+ const_name = const_node->name();
+ } else {
+ const_name = kPermNHWCToNCHW;
+ }
+ return const_name;
+ }
+
+ string GetOrAddNodePermNCHWToNHWC() {
+ string const_name;
+ if (is_in_frame_) {
+ auto const_node =
+ AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
+ const_name = const_node->name();
+ } else {
+ const_name = kPermNCHWToNHWC;
+ }
+ return const_name;
+ }
+
void UpdateTuple(AttrValue_ListValue* list) {
int64 h = list->i(1);
int64 w = list->i(2);
@@ -344,8 +467,9 @@ class NodeProcessor {
class AvgPoolGradProcessor : public NodeProcessor {
public:
- AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : NodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
std::vector<int> GetInputPos() const override {
@@ -357,8 +481,9 @@ class AvgPoolGradProcessor : public NodeProcessor {
class BiasAddGradProcessor : public NodeProcessor {
public:
- BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : NodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
bool ShouldProcess() const override {
@@ -377,8 +502,8 @@ class BiasAddGradProcessor : public NodeProcessor {
class Conv2DProcessor : public NodeProcessor {
public:
Conv2DProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
- bool no_gemm)
- : NodeProcessor(graph, node, node_map), no_gemm_(no_gemm) {}
+ bool no_gemm, bool is_in_frame)
+ : NodeProcessor(graph, node, node_map, is_in_frame), no_gemm_(no_gemm) {}
protected:
bool ShouldProcess() const override {
@@ -447,8 +572,9 @@ class Conv2DProcessor : public NodeProcessor {
class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
public:
Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node,
- NodeMap* node_map, bool no_gemm)
- : Conv2DProcessor(graph, node, node_map, no_gemm) {}
+ NodeMap* node_map, bool no_gemm,
+ bool is_in_frame)
+ : Conv2DProcessor(graph, node, node_map, no_gemm, is_in_frame) {}
protected:
bool IsGemmUsed() const override {
@@ -472,8 +598,9 @@ class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
class Conv2DBackpropInputProcessor : public Conv2DProcessor {
public:
Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node,
- NodeMap* node_map, bool no_gemm)
- : Conv2DProcessor(graph, node, node_map, no_gemm) {}
+ NodeMap* node_map, bool no_gemm,
+ bool is_in_frame)
+ : Conv2DProcessor(graph, node, node_map, no_gemm, is_in_frame) {}
protected:
bool IsGemmUsed() const override {
@@ -492,8 +619,9 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
class FusedBatchNormGradProcessor : public NodeProcessor {
public:
- FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : NodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
std::vector<int> GetInputPos() const override {
@@ -504,8 +632,9 @@ class FusedBatchNormGradProcessor : public NodeProcessor {
class MaxPoolGradProcessor : public NodeProcessor {
public:
- MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : NodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
std::vector<int> GetInputPos() const override {
@@ -516,8 +645,9 @@ class MaxPoolGradProcessor : public NodeProcessor {
class AgnosticNodeProcessor : public NodeProcessor {
public:
- AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : NodeProcessor(graph, node, node_map) {}
+ AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : NodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
bool ShouldProcess() const override {
@@ -548,8 +678,9 @@ class AgnosticNodeProcessor : public NodeProcessor {
class AddNProcessor : public AgnosticNodeProcessor {
public:
- AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {}
+ AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
std::vector<int> GetInputPos() const override {
@@ -564,8 +695,9 @@ class AddNProcessor : public AgnosticNodeProcessor {
class BinaryOpProcessor : public AgnosticNodeProcessor {
public:
- BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {
+ BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {
is_4d_with_vector_ = Is4DOperateWithVector();
}
@@ -672,8 +804,9 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
class ConcatProcessor : public AgnosticNodeProcessor {
public:
- ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {
+ ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {
// For Concat, the concat axis is the first input; for ConcatV2,
// the last input.
axis_node_pos_ =
@@ -698,8 +831,9 @@ class ConcatProcessor : public AgnosticNodeProcessor {
}
Status CustomizedProcessing() override {
- node_map_->AddOutput(kConcatConst, node_->name());
- *node_->mutable_input(axis_node_pos_) = kConcatConst;
+ string concat_const_name = GetOrAddNodeConcatConst();
+ node_map_->AddOutput(concat_const_name, node_->name());
+ *node_->mutable_input(axis_node_pos_) = concat_const_name;
return Status::OK();
}
@@ -712,12 +846,38 @@ class ConcatProcessor : public AgnosticNodeProcessor {
}
int axis_node_pos_;
+
+ private:
+ NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node,
+ const string& device) {
+ auto const_node = AddNodeConstScalar(
+ strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1);
+ // This is to ensure the concat node and the const node are
+ // in the same frame.
+ *const_node->add_input() = AsControlDependency(depended_node);
+ return const_node;
+ }
+
+ string GetOrAddNodeConcatConst() {
+ string const_name;
+ if (is_in_frame_) {
+ int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0;
+ auto const_node = AddNodeConcatConst(
+ node_->name(), NodeName(node_->input(value_node_pos)),
+ node_->device());
+ const_name = const_node->name();
+ } else {
+ const_name = kConcatConst;
+ }
+ return const_name;
+ }
};
class ReluGradProcessor : public AgnosticNodeProcessor {
public:
- ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {}
+ ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
std::vector<int> GetInputPos() const override {
@@ -728,8 +888,9 @@ class ReluGradProcessor : public AgnosticNodeProcessor {
class SliceProcessor : public AgnosticNodeProcessor {
public:
- SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {}
+ SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
Status CustomizedProcessing() override {
@@ -749,14 +910,62 @@ class SliceProcessor : public AgnosticNodeProcessor {
}
private:
+ NodeDef* AddNodeGatherAxisConst(const string& suffix,
+ const string& depended_node,
+ const string& device) {
+ auto const_node = AddNodeConstScalar(
+ strings::StrCat(kGatherAxisConst, "-", suffix), device, DT_INT32, 0);
+ // This is to ensure the Slice node and the const node are
+ // in the same frame.
+ *const_node->add_input() = AsControlDependency(depended_node);
+ return const_node;
+ }
+
+ string GetOrAddNodeGatherAxisConst() {
+ string const_name;
+ if (is_in_frame_) {
+ auto const_node = AddNodeGatherAxisConst(
+ node_->name(), NodeName(node_->input(0)), node_->device());
+ const_name = const_node->name();
+ } else {
+ const_name = kGatherAxisConst;
+ }
+ return const_name;
+ }
+
+ string GetOrAddNodePermNHWCToNCHW() {
+ string const_name;
+ if (is_in_frame_) {
+ auto const_node = AddNodePermNHWCToNCHW(
+ node_->name(), NodeName(node_->input(0)), node_->device());
+ const_name = const_node->name();
+ } else {
+ const_name = kPermNHWCToNCHW;
+ }
+ return const_name;
+ }
+
+ string GetOrAddNodePermNCHWToNHWC() {
+ string const_name;
+ if (is_in_frame_) {
+ auto const_node = AddNodePermNCHWToNHWC(
+ node_->name(), NodeName(node_->input(0)), node_->device());
+ const_name = const_node->name();
+ } else {
+ const_name = kPermNCHWToNHWC;
+ }
+ return const_name;
+ }
+
void AddNodePermVec(const string& node_name, const string& input_name,
DataType data_type, bool NHWCToNCHW) {
NodeDef* node = graph_->add_node();
node_map_->AddNode(node_name, node);
node->set_name(node_name);
*node->add_input() = input_name;
- *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
- *node->add_input() = kGatherAxisConst;
+ *node->add_input() = NHWCToNCHW ? GetOrAddNodePermNHWCToNCHW()
+ : GetOrAddNodePermNCHWToNHWC();
+ *node->add_input() = GetOrAddNodeGatherAxisConst();
node->set_op("GatherV2");
AttrValue attr_type_indices;
@@ -782,8 +991,9 @@ class SliceProcessor : public AgnosticNodeProcessor {
// before this optimization.
class SliceProcessorConst : public AgnosticNodeProcessor {
public:
- SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {}
+ SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
Status CustomizedProcessing() override {
@@ -799,8 +1009,9 @@ class SliceProcessorConst : public AgnosticNodeProcessor {
// example use case is in the gradient computation of Concat for InceptionV3.
class SliceProcessorConcatOffset : public AgnosticNodeProcessor {
public:
- SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {}
+ SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
Status CustomizedProcessing() override {
@@ -849,8 +1060,9 @@ class SliceProcessorConcatOffset : public AgnosticNodeProcessor {
class SqueezeProcessor : public AgnosticNodeProcessor {
public:
- SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {}
+ SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
bool ShouldProcess() const override {
@@ -898,8 +1110,9 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
class SumProcessor : public AgnosticNodeProcessor {
public:
- SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
- : AgnosticNodeProcessor(graph, node, node_map) {}
+ SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ bool is_in_frame)
+ : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {}
protected:
bool ShouldProcess() const override {
@@ -913,7 +1126,7 @@ class SumProcessor : public AgnosticNodeProcessor {
Status CustomizedProcessing() override {
node_map_->AddOutput(kReductionConst, node_->name());
- *node_->mutable_input(1) = kReductionConst;
+ *node_->mutable_input(1) = GetOrAddNodeReductionConst();
return Status::OK();
}
@@ -938,6 +1151,29 @@ class SumProcessor : public AgnosticNodeProcessor {
}
return false;
}
+
+ NodeDef* AddNodeReductionConst(const string& suffix,
+ const string& depended_node,
+ const string& device) {
+ auto const_node = GraphProcessor::AddNodeReductionConst(
+ strings::StrCat(kReductionConst, "-", suffix), device);
+ // This is to ensure the Sum node and the const node are in the
+ // same frame.
+ *const_node->add_input() = AsControlDependency(depended_node);
+ return const_node;
+ }
+
+ string GetOrAddNodeReductionConst() {
+ string const_name;
+ if (is_in_frame_) {
+ auto const_node = AddNodeReductionConst(
+ node_->name(), NodeName(node_->input(0)), node_->device());
+ const_name = const_node->name();
+ } else {
+ const_name = kReductionConst;
+ }
+ return const_name;
+ }
};
struct TuningConfig {
@@ -951,13 +1187,12 @@ struct TuningConfig {
bool no_gemm;
};
-class DataLayoutOptimizer {
+class DataLayoutOptimizer : GraphProcessor {
public:
explicit DataLayoutOptimizer(const string& default_device, GraphDef* graph,
- TuningConfig config)
- : default_device_(default_device),
- graph_(graph),
- node_map_(graph_),
+ NodeMap* node_map, TuningConfig config)
+ : GraphProcessor(graph, node_map),
+ default_device_(default_device),
config_(config) {}
Status Optimize() {
@@ -970,105 +1205,65 @@ class DataLayoutOptimizer {
}
private:
- NodeDef* AddNodePermConst(const string& name,
- const std::vector<int>& permutation) {
- NodeDef* node = graph_->add_node();
- node_map_.AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- node->set_device(default_device_);
- AttrValue attr_data_type;
- attr_data_type.set_type(DT_INT32);
- node->mutable_attr()->insert({"dtype", attr_data_type});
- AttrValue attr_tensor;
- Tensor tensor(DT_INT32, TensorShape({4}));
- for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
- tensor.flat<int>()(i) = permutation[i];
- }
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- return node;
+ NodeDef* AddNodePermNHWCToNCHW() {
+ return AddNodePermConst(kPermNHWCToNCHW, default_device_, {0, 3, 1, 2});
}
- NodeDef* AddConstScalar(const char* name, DataType dtype, int value) {
- NodeDef* node = graph_->add_node();
- node_map_.AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- node->set_device(default_device_);
- AttrValue attr_data_type;
- attr_data_type.set_type(dtype);
- node->mutable_attr()->insert({"dtype", attr_data_type});
- AttrValue attr_tensor;
- Tensor tensor(dtype, TensorShape({}));
- tensor.scalar<int>()() = value;
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- return node;
+ NodeDef* AddNodePermNCHWToNHWC() {
+ return AddNodePermConst(kPermNCHWToNHWC, default_device_, {0, 2, 3, 1});
}
NodeDef* AddNodeConcatConst() {
- return AddConstScalar(kConcatConst, DT_INT32, 1);
+ return AddNodeConstScalar(kConcatConst, default_device_, DT_INT32, 1);
}
- NodeDef* AddGatherAxisConst() {
- return AddConstScalar(kGatherAxisConst, DT_INT32, 0);
+ NodeDef* AddNodeGatherAxisConst() {
+ return AddNodeConstScalar(kGatherAxisConst, default_device_, DT_INT32, 0);
}
NodeDef* AddNodeReductionConst() {
- NodeDef* node = graph_->add_node();
- node_map_.AddNode(kReductionConst, node);
- node->set_name(kReductionConst);
- node->set_op("Const");
- node->set_device(default_device_);
- AttrValue attr_data_type;
- attr_data_type.set_type(DT_INT32);
- node->mutable_attr()->insert({"dtype", attr_data_type});
-
- AttrValue attr_tensor;
- Tensor tensor(DT_INT32, TensorShape({3}));
- std::vector<int> axis = {0, 2, 3};
- for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
- tensor.flat<int>()(i) = axis[i];
- }
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- return node;
+ return GraphProcessor::AddNodeReductionConst(kReductionConst,
+ default_device_);
}
// Expand all nodes that are in NHWC but support NCHW or are layout agnostic.
Status Expand() {
int node_size_original = graph_->node_size();
+ std::unordered_map<const NodeDef*, std::vector<int>> frames;
+ IdentifyFrames(*graph_, &frames);
+
// This is the first pass where we expand the nodes which support NCHW.
std::set<string> ops_format_supported = GetOpsFormatSupported();
- for (int i = 0; i < graph_->node_size(); i++) {
+ for (int i = 0; i < node_size_original; i++) {
if (ops_format_supported.find(graph_->node(i).op()) !=
ops_format_supported.end()) {
auto node = graph_->mutable_node(i);
+ bool is_in_frame = !frames[node].empty();
std::unique_ptr<NodeProcessor> node_processor;
if (node->op().compare("AvgPoolGrad") == 0) {
node_processor.reset(
- new AvgPoolGradProcessor(graph_, node, &node_map_));
+ new AvgPoolGradProcessor(graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("BiasAddGrad") == 0) {
node_processor.reset(
- new BiasAddGradProcessor(graph_, node, &node_map_));
+ new BiasAddGradProcessor(graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("Conv2D") == 0) {
- node_processor.reset(
- new Conv2DProcessor(graph_, node, &node_map_, config_.no_gemm));
+ node_processor.reset(new Conv2DProcessor(
+ graph_, node, node_map_, config_.no_gemm, is_in_frame));
} else if (node->op().compare("Conv2DBackpropFilter") == 0) {
node_processor.reset(new Conv2DBackpropFilterProcessor(
- graph_, node, &node_map_, config_.no_gemm));
+ graph_, node, node_map_, config_.no_gemm, is_in_frame));
} else if (node->op().compare("Conv2DBackpropInput") == 0) {
node_processor.reset(new Conv2DBackpropInputProcessor(
- graph_, node, &node_map_, config_.no_gemm));
+ graph_, node, node_map_, config_.no_gemm, is_in_frame));
} else if (node->op().compare("FusedBatchNormGrad") == 0) {
- node_processor.reset(
- new FusedBatchNormGradProcessor(graph_, node, &node_map_));
+ node_processor.reset(new FusedBatchNormGradProcessor(
+ graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("MaxPoolGrad") == 0) {
node_processor.reset(
- new MaxPoolGradProcessor(graph_, node, &node_map_));
+ new MaxPoolGradProcessor(graph_, node, node_map_, is_in_frame));
} else {
- node_processor.reset(new NodeProcessor(graph_, node, &node_map_));
+ node_processor.reset(
+ new NodeProcessor(graph_, node, node_map_, is_in_frame));
}
TF_RETURN_IF_ERROR(node_processor->ConvertNode());
}
@@ -1078,54 +1273,57 @@ class DataLayoutOptimizer {
// only needs to be performed if at least one node in the previous pass is
// expanded.
if (graph_->node_size() > node_size_original) {
- NodeDef* n = AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2});
- n = AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1});
+ NodeDef* n = AddNodePermNHWCToNCHW();
+ n = AddNodePermNCHWToNHWC();
n = AddNodeConcatConst();
- n = AddGatherAxisConst();
+ n = AddNodeGatherAxisConst();
n = AddNodeReductionConst();
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
for (int i = 0; i < graph_->node_size(); i++) {
if (ops_format_agnostic.find(graph_->node(i).op()) !=
ops_format_agnostic.end()) {
auto node = graph_->mutable_node(i);
+ bool is_in_frame = !frames[node].empty();
std::unique_ptr<NodeProcessor> node_processor;
if (node->op().compare("AddN") == 0) {
- node_processor.reset(new AddNProcessor(graph_, node, &node_map_));
+ node_processor.reset(
+ new AddNProcessor(graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("Add") == 0 ||
node->op().compare("Mul") == 0 ||
node->op().compare("RealDiv") == 0 ||
node->op().compare("SquaredDifference") == 0 ||
node->op().compare("Sub") == 0) {
node_processor.reset(
- new BinaryOpProcessor(graph_, node, &node_map_));
+ new BinaryOpProcessor(graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("Concat") == 0 ||
node->op().compare("ConcatV2") == 0) {
- node_processor.reset(new ConcatProcessor(graph_, node, &node_map_));
+ node_processor.reset(
+ new ConcatProcessor(graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("ReluGrad") == 0) {
node_processor.reset(
- new ReluGradProcessor(graph_, node, &node_map_));
+ new ReluGradProcessor(graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("Slice") == 0) {
- auto input1 = node_map_.GetNode(NodeName(node->input(1)));
- auto input2 = node_map_.GetNode(NodeName(node->input(2)));
+ auto input1 = node_map_->GetNode(NodeName(node->input(1)));
+ auto input2 = node_map_->GetNode(NodeName(node->input(2)));
if (input1->op() == "ConcatOffset") {
- node_processor.reset(
- new SliceProcessorConcatOffset(graph_, node, &node_map_));
+ node_processor.reset(new SliceProcessorConcatOffset(
+ graph_, node, node_map_, is_in_frame));
} else if (input1->op() == "Const" && input2->op() == "Const") {
- node_processor.reset(
- new SliceProcessorConst(graph_, node, &node_map_));
+ node_processor.reset(new SliceProcessorConst(
+ graph_, node, node_map_, is_in_frame));
} else {
node_processor.reset(
- new SliceProcessor(graph_, node, &node_map_));
+ new SliceProcessor(graph_, node, node_map_, is_in_frame));
}
-
} else if (node->op().compare("Squeeze") == 0) {
node_processor.reset(
- new SqueezeProcessor(graph_, node, &node_map_));
+ new SqueezeProcessor(graph_, node, node_map_, is_in_frame));
} else if (node->op().compare("Sum") == 0) {
- node_processor.reset(new SumProcessor(graph_, node, &node_map_));
- } else {
node_processor.reset(
- new AgnosticNodeProcessor(graph_, node, &node_map_));
+ new SumProcessor(graph_, node, node_map_, is_in_frame));
+ } else {
+ node_processor.reset(new AgnosticNodeProcessor(
+ graph_, node, node_map_, is_in_frame));
}
TF_RETURN_IF_ERROR(node_processor->ConvertNode());
}
@@ -1145,12 +1343,12 @@ class DataLayoutOptimizer {
if (IsNodeNCHWToNHWC(node->input(0))) {
const string& trans_first = node->input(0);
const string& trans_second = node->name();
- auto outputs = node_map_.GetOutputs(trans_second);
+ auto outputs = node_map_->GetOutputs(trans_second);
CHECK(outputs.size() == 1)
<< "There is always only a single output for a Transpose node, "
<< "due to the way it is added by NodeProcessor.";
NodeDef* output = *outputs.begin();
- string input = node_map_.GetNode(trans_first)->input(0);
+ string input = node_map_->GetNode(trans_first)->input(0);
for (int i = 0; i < output->input_size(); i++) {
if (output->input(i).compare(trans_second) == 0) {
*output->mutable_input(i) = input;
@@ -1173,8 +1371,6 @@ class DataLayoutOptimizer {
}
string default_device_;
- GraphDef* graph_;
- NodeMap node_map_;
TuningConfig config_;
};
@@ -1231,8 +1427,9 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
default_device = cluster->GetDevices().begin()->first;
}
}
+ std::unique_ptr<NodeMap> node_map(new NodeMap(output));
std::unique_ptr<DataLayoutOptimizer> layout_optimizer(
- new DataLayoutOptimizer(default_device, output, config));
+ new DataLayoutOptimizer(default_device, output, node_map.get(), config));
status = layout_optimizer->Optimize();
// This is based on an empirical observation that if the number of introduced
// Transpose nodes is more than 30, not using the GEMM implementation would result in better
@@ -1240,8 +1437,9 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (status.ok() && GetNumTranspose(*output) > 30) {
*output = new_item.graph;
config.no_gemm = true;
- layout_optimizer.reset(
- new DataLayoutOptimizer(default_device, output, config));
+ node_map.reset(new NodeMap(output));
+ layout_optimizer.reset(new DataLayoutOptimizer(default_device, output,
+ node_map.get(), config));
status = layout_optimizer->Optimize();
}
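
The pattern repeated across the processors above is worth isolating: outside any frame, all transposes share one globally added constant; inside a frame (e.g., the body of a tf.while_loop or map_fn), a constant defined outside the frame cannot be consumed, so a per-node copy is added and pinned to the frame with a control dependency. A condensed sketch of the get-or-add logic, with the attr population elided (the helper's shape is illustrative, not the commit's exact API):

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/grappler/utils.h"
    #include "tensorflow/core/lib/strings/strcat.h"

    namespace tensorflow {
    namespace grappler {

    // Illustrative: returns the name of the permutation const a Transpose
    // should read; shared_const_name is the const added once per graph
    // (kPermNHWCToNCHW in the real code).
    string GetOrAddPermConst(bool is_in_frame, const string& node_name,
                             const string& depended_node, const string& device,
                             const string& shared_const_name, GraphDef* graph,
                             NodeMap* node_map) {
      if (!is_in_frame) {
        return shared_const_name;  // reuse the single global const
      }
      NodeDef* node = graph->add_node();
      const string name = strings::StrCat(shared_const_name, "-", node_name);
      node_map->AddNode(name, node);
      node->set_name(name);
      node->set_op("Const");
      node->set_device(device);
      // dtype/value attrs omitted; see AddNodePermConst in the diff above.
      // The control dependency pins the const into the same frame as a node
      // that is already known to live there.
      *node->add_input() = AsControlDependency(depended_node);
      return node->name();
    }

    }  // namespace grappler
    }  // namespace tensorflow
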
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 948df18879..9e15744fab 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -177,5 +177,13 @@ bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
return notified;
}
+string AsControlDependency(const NodeDef& node) {
+ return strings::StrCat("^", node.name());
+}
+
+string AsControlDependency(const string& node) {
+ return strings::StrCat("^", node);
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 4a8cb573d8..a9eccd685b 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -85,6 +85,14 @@ string AddPrefixToNodeName(const string& name, const string& prefix);
bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
thread::ThreadPool* thread_pool);
+// Returns the node name prefixed with conventional symbol '^'
+// for control dependency, given a NodeDef.
+string AsControlDependency(const NodeDef& node);
+
+// Returns the node name prefixed with conventional symbol '^'
+// for control dependency, given a node name.
+string AsControlDependency(const string& node);
+
} // end namespace grappler
} // end namespace tensorflow
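
Putting the utils changes together with the layout pass: a Transpose inserted inside a frame now reads a frame-local permutation const (whose only extra input is the control dependency built above) instead of the shared one. A rough sketch of the NodeDefs the optimizer ends up emitting, with illustrative names:

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/grappler/utils.h"

    namespace tensorflow {
    namespace grappler {

    void EmitFrameLocalTranspose() {
      // Frame-local permutation const {0, 3, 1, 2}; dtype/value attrs elided.
      NodeDef perm;
      perm.set_op("Const");
      perm.set_name("PermConstNHWCToNCHW-loop_body/input");       // illustrative
      *perm.add_input() = AsControlDependency("loop_body/input"); // frame pin

      // The Transpose takes the const by name, so it stays in the same frame.
      NodeDef transpose;
      transpose.set_op("Transpose");
      transpose.set_name("conv-TransposeNHWCToNCHW");
      *transpose.add_input() = "loop_body/input";  // data input
      *transpose.add_input() = perm.name();        // permutation input
    }

    }  // namespace grappler
    }  // namespace tensorflow
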
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 5dbaf76edb..bda9502cd1 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -22,8 +22,10 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -51,9 +53,7 @@ def max_pool_2x2(x):
# Taken from tensorflow/examples/tutorials/mnist/mnist_deep.py
-def two_layer_model():
- random_seed.set_random_seed(0)
- x = random_ops.truncated_normal([1, 784], seed=0)
+def two_layer_model(x):
x_image = array_ops.reshape(x, [-1, 28, 28, 1])
w_conv1 = weight([5, 5, 1, 32])
b_conv1 = bias([32])
@@ -66,24 +66,39 @@ def two_layer_model():
return h_pool2
+def loop():
+ random_seed.set_random_seed(0)
+ x1 = random_ops.truncated_normal([1, 784], seed=0)
+ x2 = random_ops.truncated_normal([1, 784], seed=0)
+ x3 = random_ops.truncated_normal([1, 784], seed=0)
+ x4 = random_ops.truncated_normal([1, 784], seed=0)
+ elems = (x1, x2, x3, x4)
+ outputs = functional_ops.map_fn(two_layer_model, elems, dtype=dtypes.float32)
+ return outputs
+
+
+def get_config():
+ rewrite_options = rewriter_config_pb2.RewriterConfig(
+ optimize_tensor_layout=True)
+ graph_options = config_pb2.GraphOptions(
+ rewrite_options=rewrite_options, build_cost_model=1)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ return config
+
+
class LayoutOptimizerTest(test.TestCase):
"""Tests the Grappler layout optimizer."""
def testTwoConvLayers(self):
if test.is_gpu_available(cuda_only=True):
- output = two_layer_model()
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ output = two_layer_model(x)
with session.Session() as sess:
output_val_ref = sess.run(output)
- rewrite_options = rewriter_config_pb2.RewriterConfig(
- optimize_tensor_layout=True)
- graph_options = config_pb2.GraphOptions(
- rewrite_options=rewrite_options,
- build_cost_model=1)
- config = config_pb2.ConfigProto(graph_options=graph_options)
-
- with session.Session(config=config) as sess:
+ with session.Session(config=get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
@@ -105,6 +120,19 @@ class LayoutOptimizerTest(test.TestCase):
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+ def testLoop(self):
+ if test.is_gpu_available(cuda_only=True):
+ output = loop()
+
+ with session.Session() as sess:
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
if __name__ == '__main__':
test.main()