about summary refs log tree commit diff homepage
path: root/tensorflow/core/graph/mkl_layout_pass.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-30 10:05:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-30 12:27:47 -0800
commit7149a2e2e2f549035f23e21224ee41afe8df3876 (patch)
tree4fab32a87362e9708d07f388154a10ccb0c7800b /tensorflow/core/graph/mkl_layout_pass.cc
parent88eb6c61ef7659c2b5bb1ec6586c7d3cca5e4e9c (diff)
Cleanup: Ran clang-format on files in tensorflow/core/.../*.{cc,h}.
PiperOrigin-RevId: 183848459
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc348
1 file changed, 170 insertions, 178 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 55bc401b9d..68c3136019 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -37,8 +37,8 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/graph/mkl_layout_pass.h"
namespace tensorflow {
@@ -281,7 +281,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
- "_MklConv2DWithBiasBackpropBias";
+ "_MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
csinfo_.relu_grad = "ReluGrad";
csinfo_.reshape = "Reshape";
@@ -297,10 +297,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// End - element-wise ops. See note above.
// NOTE: names are alphabetically sorted.
- rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN,
- AddNRewrite, nullptr});
- rinfo_.push_back({csinfo_.add,
- mkl_op_registry::GetMklOpName(csinfo_.add),
+ rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
+ CopyAttrsAddN, AddNRewrite, nullptr});
+ rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.avg_pool,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
@@ -337,14 +336,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.fused_batch_norm_grad,
- mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
- CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
+ rinfo_.push_back(
+ {csinfo_.fused_batch_norm_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
+ CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsIdentity, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.lrn,
- mkl_op_registry::GetMklOpName(csinfo_.lrn),
+ rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsLRN, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
@@ -358,11 +357,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.mul,
- mkl_op_registry::GetMklOpName(csinfo_.mul),
+ rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.relu,
- mkl_op_registry::GetMklOpName(csinfo_.relu),
+ rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.relu_grad,
mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
@@ -373,8 +370,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.squared_difference,
mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.sub,
- mkl_op_registry::GetMklOpName(csinfo_.sub),
+ rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
CopyAttrsDataType, AlwaysRewrite, nullptr});
// Add info about which ops to add workspace edge to and the slots.
@@ -388,9 +384,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
IsBiasAddGradInMatMulContext};
- biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
- csinfo_.mkl_conv2d_with_bias,
- IsBiasAddGradInConv2DWithBiasContext};
+ biasaddgrad_conv2dwithbias_context_ = {
+ csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
+ IsBiasAddGradInConv2DWithBiasContext};
cinfo_.push_back(&biasaddgrad_matmul_context_);
cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
@@ -410,9 +406,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Structure to specify the context information used in a node rewrite rule
typedef struct {
- string node; // Name of the node to be rewritten
- string fwd; // Name of the node in the forward pass that this node
- // corresponds to
+ string node; // Name of the node to be rewritten
+ string fwd; // Name of the node in the forward pass that this node
+ // corresponds to
std::function<bool(const Node*, const Node**, void* c)> context_match_fn;
} ContextInfo;
@@ -615,14 +611,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
std::vector<int32> ksize, strides;
CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
- CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
- true);
+ CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
// Condition that specifies non-batch-wise and non-depth-wise pooling.
- if (GetTensorDim(ksize, data_format, 'N') == 1 &&
+ if (GetTensorDim(ksize, data_format, 'N') == 1 &&
GetTensorDim(strides, data_format, 'N') == 1 &&
- GetTensorDim(ksize, data_format, 'C') == 1 &&
+ GetTensorDim(ksize, data_format, 'C') == 1 &&
GetTensorDim(strides, data_format, 'C') == 1) {
return true;
}
@@ -785,8 +780,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
for (const Edge* fe : first_inp_of_filter->out_edges()) {
if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
fe->dst_input() == 0) {
- VLOG(1) << "MklLayoutRewritePass: found "
- << fe->dst()->DebugString()
+ VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString()
<< " as the forward node for matching context, backward"
<< " node is: " << n->DebugString();
*fwd_node = fe->dst();
@@ -803,13 +797,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// @return - true (if BiasAddGrad is associated with MatMul);
// false otherwise.
- static bool IsBiasAddGradInMatMulContext(const Node* n,
- const Node** fwd_node,
+ static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node,
void* ci) {
return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
}
-
// Rewrite rule that uses context-information for matching,
// used in scenario 2.
//
@@ -880,10 +872,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output output_nodes - the list of new nodes creating Mkl tensors
//
// @return None
- void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
- Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- int* input_idx, int list_length,
- std::vector<NodeBuilder::NodeOut>* output_nodes);
+ void GetNodesProducingMklTensorList(
+ std::unique_ptr<Graph>* g, Node* orig_node,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+ int* input_idx, int list_length,
+ std::vector<NodeBuilder::NodeOut>* output_nodes);
// Get a node that will feed an Mkl tensor to the new
// node that we are constructing. The output node could be (1) 'n'
@@ -900,7 +893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// will feed the tensor
// @return None
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
- Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
+ Node* n, int n_output_slot, Node** mkl_node,
+ int* mkl_node_output_slot);
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@@ -970,9 +964,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
MklLayoutRewritePass::ContextInfo
- MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
+ MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
MklLayoutRewritePass::ContextInfo
- MklLayoutRewritePass::biasaddgrad_matmul_context_;
+ MklLayoutRewritePass::biasaddgrad_matmul_context_;
std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// We register Mkl rewrite pass for phase 1 in post partitioning group.
@@ -1041,13 +1035,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // the same device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // the same device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -1060,8 +1054,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// the same frame.
if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr;
- TF_CHECK_OK(orig_node->input_node(0,
- const_cast<const Node**>(&orig_input0)));
+ TF_CHECK_OK(
+ orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
}
@@ -1069,11 +1063,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
}
void MklLayoutRewritePass::GetNodesProducingMklTensorList(
- std::unique_ptr<Graph>* g,
- Node* orig_node,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- int* input_idx, int list_length,
- std::vector<NodeBuilder::NodeOut>* output_nodes) {
+ std::unique_ptr<Graph>* g, Node* orig_node,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+ int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@@ -1090,8 +1082,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
&mkl_node_output_slot);
- output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
- mkl_node_output_slot));
+ output_nodes->push_back(
+ NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
(*input_idx)++;
list_length--;
}
@@ -1101,9 +1093,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// node that we are constructing. An input node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
- Node* orig_node, Node* n,
- int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
+void MklLayoutRewritePass::GetNodeProducingMklTensor(
+ std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
+ Node** mkl_node, int* mkl_node_output_slot) {
CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot);
@@ -1234,8 +1226,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
- GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
- N, &new_node_inputs);
+ GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
+ &new_node_inputs);
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
@@ -1336,13 +1328,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // same the device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // same the device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -1355,8 +1347,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
// the same frame.
if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr;
- TF_CHECK_OK(orig_node->input_node(0,
- const_cast<const Node**>(&orig_input0)));
+ TF_CHECK_OK(
+ orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
}
@@ -1374,7 +1366,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
for (auto ws : wsinfo_) {
if (orig_node->type_string() == ws.fwd_op &&
- mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
+ mkl_op_registry::IsMklOp(
+ mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an
// edge from this node's fwd_slot to bwdop's bwd_slot. If there is
// an edge, then we just add an attribute on this node for setting
@@ -1400,8 +1393,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
nb->Attr("workspace_enabled", false);
}
} else if (orig_node->type_string() == ws.bwd_op &&
- mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()),
- T)) {
+ mkl_op_registry::IsMklOp(
+ mkl_op_registry::GetMklOpName(orig_node->type_string()),
+ T)) {
// If this op is a bwd op, then we need to add workspace edge and
// it's Mkl tensor edge between its corresponding fwd op and this
// op. Corresponding fwd op is specified in 'fwd_op' field of
@@ -1416,7 +1410,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
if (e->src_output() == ws.fwd_slot &&
// We would have rewritten the forward op, so we need to use
// GetMklOpName call to get its Mkl name.
- e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(ws.fwd_op) &&
e->dst_input() == ws.bwd_slot) {
nb->Attr("workspace_enabled", true);
CHECK_NOTNULL(ws_tensors);
@@ -1593,7 +1588,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
}
void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
- NodeBuilder* nb) {
+ NodeBuilder* nb) {
DataType T;
DataType Tshape;
@@ -1869,8 +1864,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
} else {
- CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(),
- e->dst_input()));
+ CHECK_NOTNULL(
+ (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()));
}
}
@@ -1941,9 +1936,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
// and leave BiasAddGrad as it is. But we check for this condition
// when we check for node rewrite rule. So we should not even come
// here for MatMul. So we will fail now.
- return Status(
- error::Code::INVALID_ARGUMENT,
- "No rewrite is required for BiasAddGrad for MatMul context.");
+ return Status(
+ error::Code::INVALID_ARGUMENT,
+ "No rewrite is required for BiasAddGrad for MatMul context.");
}
}
@@ -2012,9 +2007,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
} else {
- CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
- e->src()->num_outputs()),
- e->dst(), e->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(
+ new_node,
+ GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
+ e->dst(), e->dst_input()));
}
}
@@ -2070,7 +2066,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// BiasAddGrad is not an Mkl layer, so we make an exception for it.
if (n->type_string() != csinfo_.bias_add_grad) {
- if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+ if (!mkl_op_registry::IsMklOp(
+ mkl_op_registry::GetMklOpName(n->type_string()), T)) {
return nullptr;
}
}
@@ -2186,8 +2183,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
return MklLayoutRewritePass().RunPass(g);
}
-Status MklLayoutRewritePass::Run(
- const GraphOptimizationPassOptions& options) {
+Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
@@ -2215,7 +2211,7 @@ Status MklLayoutRewritePass::Run(
return Status::OK();
}
-#else // INTEL_MKL_DNN
+#else // INTEL_MKL_DNN
// This pass implements rewriting of graph to support following scenarios:
// (A) Merging nodes in the graph
@@ -2421,7 +2417,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.conv2d_grad_filter_with_bias =
- "__MklDummyConv2DBackpropFilterWithBias";
+ "__MklDummyConv2DBackpropFilterWithBias";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.identity = "Identity";
@@ -2435,11 +2431,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_grad_filter_with_bias =
- "_MklConv2DBackpropFilterWithBias";
+ "_MklConv2DBackpropFilterWithBias";
csinfo_.relu = "Relu";
csinfo_.relu_grad = "ReluGrad";
- csinfo_.tanh = "Tanh";
- csinfo_.tanh_grad = "TanhGrad";
+ csinfo_.tanh = "Tanh";
+ csinfo_.tanh_grad = "TanhGrad";
csinfo_.reshape = "Reshape";
csinfo_.softmax = "Softmax";
csinfo_.split = "Split";
@@ -2474,29 +2470,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
CopyAttrsConv2D, AlwaysRewrite});
- rinfo_.push_back({csinfo_.conv2d_with_bias,
- csinfo_.mkl_conv2d_with_bias,
+ rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
- csinfo_.mkl_conv2d_grad_filter_with_bias,
- CopyAttrsConv2D, AlwaysRewrite});
+ csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
+ AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite});
- rinfo_.push_back({csinfo_.fused_batch_norm_grad,
- mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
- CopyAttrsFusedBatchNorm, AlwaysRewrite});
+ rinfo_.push_back(
+ {csinfo_.fused_batch_norm_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
+ CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsDataType, AlwaysRewrite});
- rinfo_.push_back({csinfo_.lrn,
- mkl_op_registry::GetMklOpName(csinfo_.lrn),
+ rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
@@ -2515,8 +2510,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
mkl_op_registry::GetMklOpName(csinfo_.mul),
CopyAttrsDataType, AlwaysRewrite});
*/
- rinfo_.push_back({csinfo_.relu,
- mkl_op_registry::GetMklOpName(csinfo_.relu),
+ rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsDataType, AlwaysRewrite});
rinfo_.push_back({csinfo_.relu_grad,
mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
@@ -2550,8 +2544,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Add a rule for merging nodes
minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
- csinfo_.conv2d_with_bias,
- GetConv2DOrBiasAdd});
+ csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
csinfo_.conv2d_grad_filter_with_bias,
@@ -2846,9 +2839,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Default rewrite rule to be used in scenario 1 for rewrite.
// @return - true (since we want to always rewrite)
- static bool AlwaysRewrite(const Node* n) {
- return true;
- }
+ static bool AlwaysRewrite(const Node* n) { return true; }
// Check if we are performing pooling on depth or batch. If it is, then we
// do not rewrite MaxPool node to Mkl version.
@@ -2862,14 +2853,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
std::vector<int32> ksize, strides;
CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
- CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
- true);
+ CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
// Condition that specifies non-batch-wise and non-depth-wise pooling.
- if (GetTensorDim(ksize, data_format, 'N') == 1 &&
+ if (GetTensorDim(ksize, data_format, 'N') == 1 &&
GetTensorDim(strides, data_format, 'N') == 1 &&
- GetTensorDim(ksize, data_format, 'C') == 1 &&
+ GetTensorDim(ksize, data_format, 'C') == 1 &&
GetTensorDim(strides, data_format, 'C') == 1) {
return true;
}
@@ -2941,10 +2931,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output output_nodes - the list of new nodes creating Mkl tensors
//
// @return None
- void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
- Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- int* input_idx, int list_length,
- std::vector<NodeBuilder::NodeOut>* output_nodes);
+ void GetNodesProducingMklTensorList(
+ std::unique_ptr<Graph>* g, Node* orig_node,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+ int* input_idx, int list_length,
+ std::vector<NodeBuilder::NodeOut>* output_nodes);
// Get a node that will feed an Mkl tensor to the new
// node that we are constructing. The output node could be (1) 'n'
@@ -2961,7 +2952,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// will feed the tensor
// @return None
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
- Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
+ Node* n, int n_output_slot, Node** mkl_node,
+ int* mkl_node_output_slot);
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@@ -3096,13 +3088,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // the same device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // the same device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -3115,8 +3107,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// the same frame.
if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr;
- TF_CHECK_OK(orig_node->input_node(0,
- const_cast<const Node**>(&orig_input0)));
+ TF_CHECK_OK(
+ orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
// Allow duplicate while adding control edge as it would fail (return
// NULL) if we try to add duplicate edge.
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
@@ -3126,11 +3118,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
}
void MklLayoutRewritePass::GetNodesProducingMklTensorList(
- std::unique_ptr<Graph>* g,
- Node* orig_node,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- int* input_idx, int list_length,
- std::vector<NodeBuilder::NodeOut>* output_nodes) {
+ std::unique_ptr<Graph>* g, Node* orig_node,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+ int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@@ -3147,8 +3137,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
&mkl_node_output_slot);
- output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
- mkl_node_output_slot));
+ output_nodes->push_back(
+ NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
(*input_idx)++;
list_length--;
}
@@ -3158,9 +3148,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// node that we are constructing. An input node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
- Node* orig_node, Node* n,
- int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
+void MklLayoutRewritePass::GetNodeProducingMklTensor(
+ std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
+ Node** mkl_node, int* mkl_node_output_slot) {
CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot);
@@ -3292,8 +3282,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
- GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
- N, &new_node_inputs);
+ GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
+ &new_node_inputs);
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
@@ -3394,13 +3384,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // same the device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // same the device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -3413,8 +3403,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
// the same frame.
if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr;
- TF_CHECK_OK(orig_node->input_node(0,
- const_cast<const Node**>(&orig_input0)));
+ TF_CHECK_OK(
+ orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
// Allow duplicate while adding control edge as it would fail (return
// NULL) if we try to add duplicate edge.
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
@@ -3434,8 +3424,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
for (auto ws : wsinfo_) {
if (orig_node->type_string() == ws.fwd_op &&
- mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(
- orig_node->type_string()), T)) {
+ mkl_op_registry::IsMklOp(
+ mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an
// edge from this node's fwd_slot to bwdop's bwd_slot. If there is
// an edge, then we just add an attribute on this node for setting
@@ -3461,8 +3451,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
nb->Attr("workspace_enabled", false);
}
} else if (orig_node->type_string() == ws.bwd_op &&
- mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(
- orig_node->type_string()), T)) {
+ mkl_op_registry::IsMklOp(
+ mkl_op_registry::GetMklOpName(orig_node->type_string()),
+ T)) {
// If this op is a bwd op, then we need to add workspace edge and
// it's Mkl tensor edge between its corresponding fwd op and this
// op. Corresponding fwd op is specified in 'fwd_op' field of
@@ -3477,8 +3468,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
if (e->src_output() == ws.fwd_slot &&
// We would have rewritten the forward op, so we need to use
// GetMklOpName call to get its Mkl name.
- e->src()->type_string() == mkl_op_registry::GetMklOpName(
- ws.fwd_op) &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(ws.fwd_op) &&
e->dst_input() == ws.bwd_slot) {
nb->Attr("workspace_enabled", true);
CHECK_NOTNULL(ws_tensors);
@@ -3645,7 +3636,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
}
void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
- NodeBuilder* nb) {
+ NodeBuilder* nb) {
DataType T;
DataType Tshape;
@@ -3776,8 +3767,9 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
Node* m, Node* n) {
CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
n->type_string() == csinfo_.conv2d)) ||
- ((n->type_string() == csinfo_.bias_add &&
- m->type_string() == csinfo_.conv2d)), true);
+ ((n->type_string() == csinfo_.bias_add &&
+ m->type_string() == csinfo_.conv2d)),
+ true);
// If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
// BiasAdd is successor node, and Conv2D predecessor node.
@@ -3796,8 +3788,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
- TF_CHECK_OK(
- GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
+ TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
// We check to ensure that data formats of both succ and pred are same.
// We expect them to be same, so we can enforce this as assert.
// But assert can be too strict, so we enforce this as a check.
@@ -3900,8 +3891,8 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
// BiasAdd has only 1 output (at slot 0) and merged node also has only 1
// output (at slot 0).
const int kConv2DWithBiasOutputSlot = 0;
- CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
- e->dst(), e->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(),
+ e->dst_input()));
}
}
@@ -3924,8 +3915,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
std::unique_ptr<Graph>* g, Node* m, Node* n) {
CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
n->type_string() == csinfo_.conv2d_grad_filter)) ||
- ((n->type_string() == csinfo_.bias_add_grad &&
- m->type_string() == csinfo_.conv2d_grad_filter)), true);
+ ((n->type_string() == csinfo_.bias_add_grad &&
+ m->type_string() == csinfo_.conv2d_grad_filter)),
+ true);
// If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
@@ -4132,9 +4124,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
// NULL) if we try to add duplicate edge.
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
} else {
- CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
- e->src()->num_outputs()),
- e->dst(), e->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(
+ new_node,
+ GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
+ e->dst(), e->dst_input()));
}
}
@@ -4166,9 +4159,9 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// names.
if (n->type_string() != csinfo_.conv2d_with_bias &&
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
- !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(
- n->type_string()), T)) {
- return nullptr;
+ !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
+ T)) {
+ return nullptr;
}
// For elementwise node, we reuse the Eigen implementation and pass the MKL
@@ -4184,29 +4177,30 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// eigen code to reduce cross-library dependency.
VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
if (mkl_op_registry::IsMklElementWiseOp(
- mkl_op_registry::GetMklOpName(n->type_string()), T) ||
+ mkl_op_registry::GetMklOpName(n->type_string()), T) ||
n->type_string().find("Identity") != string::npos) {
VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
bool incoming_mkl_edge = false;
int num_parent = 0;
for (auto parent : n->in_edges()) {
if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
- VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is MKL op: "
- << parent->src()->type_string();
+ VLOG(1) << "ELEMENTWISE: parent " << num_parent++
+ << " is MKL op: " << parent->src()->type_string();
incoming_mkl_edge = true;
break;
} else {
- VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is NON-MKL op: "
- << parent->src()->type_string();
+ VLOG(1) << "ELEMENTWISE: parent " << num_parent++
+ << " is NON-MKL op: " << parent->src()->type_string();
}
}
if (incoming_mkl_edge == false) {
- VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which has no MKL "
+ VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
+ "has no MKL "
"parents.";
return nullptr;
} else {
- VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() <<
- " which has MKL parents";
+ VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
+ << " which has MKL parents";
}
}
@@ -4214,8 +4208,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// for this op, then we rewrite it to Mkl op.
// Find matching RewriteInfo and then check that rewrite rule applies.
for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
- if (n->type_string().compare(ri->name) == 0 &&
- ri->rewrite_rule(n)) {
+ if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
return &*ri;
}
}
@@ -4297,8 +4290,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
return MklLayoutRewritePass().RunPass(g);
}
-Status MklLayoutRewritePass::Run(
- const GraphOptimizationPassOptions& options) {
+Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}