author     Vivek Rane <vivek.v.rane@intel.com>    2017-03-23 13:13:49 -0700
committer  Martin Wicke <martin.wicke@gmail.com>  2017-03-23 13:13:49 -0700
commit     fe97705b706c9dcd36586b6158e30758346c6afd (patch)
tree       6dd25ad6e4f5c7288c02bfdad2aa8725a4762d64
parent     5b4a597b088344ff55c917a505eacefe605737aa (diff)
MKL support for max/avg pooling and relu (#8296)
* Adding MKL support for Max/Avg Pooling and ReLU
* Missed the mkl layer registry files
* Fixed sanity check errors with buildifier
* Adding Intel Conv2D kernel implementation along with required graph passes.
  This commit contains 4 main components:
  1) Intel-optimized kernel implementation for the Conv2D op
     (kernels/mkl_conv_ops.*)
  2) Graph passes required to enable the optimized Conv2D implementation
     (graph/mkl_*). We also need a new op, MklToTf; its implementation is in
     kernels/mkl_tfconv_op.cc.
  3) Utility functions used in the kernel implementation
     (common_runtime/mkl_layer_registry* and util/mkl_util.h)
  4) BUILD changes for Conv2D, graph passes, and utility functions
* Refactor MKL convolution forward pass computation into smaller functions.
  Changed configure to point to the newer MKLML library
* Moved Mkl helper data structures and routines to private class members
* MKL op registration changed to use the existing op registry (nhasabni)
* Fixed buildifier error
* Removed the mkl layer registry (should not have been added) and made fixes
  according to the code review comments
* Fixed rebase mess-ups
* Added documentation for mkl pooling op parameters
* Removed layer registry reference from mkl relu op
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc            354
-rw-r--r--  tensorflow/core/kernels/BUILD                        67
-rw-r--r--  tensorflow/core/kernels/mkl_avgpooling_op.cc         486
-rw-r--r--  tensorflow/core/kernels/mkl_maxpooling_op.cc         591
-rw-r--r--  tensorflow/core/kernels/mkl_pooling_ops_common.cc    166
-rw-r--r--  tensorflow/core/kernels/mkl_pooling_ops_common.h     93
-rw-r--r--  tensorflow/core/kernels/mkl_relu_op.cc               387
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                        175
8 files changed, 2242 insertions, 77 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 9e3af279ea..94e54ba6d2 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -147,11 +147,75 @@ namespace tensorflow {
// it is, then we rewrite that node after constructing new inputs to
// the node. If it is not Mkl layer, then we do not rewrite the node.
//
+// Handling workspace propagation for certain ops:
+//
+// Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
+// passing of workspace from their corresponding forward ops. But
+// TensorFlow does not have a notion of workspace and as a result
+// does not allow producing additional outputs from these forward ops.
+// For these ops, we need to add additional edges between the forward
+// ops and their corresponding backward ops: one edge carries the
+// workspace tensor value and another edge carries the Mkl tensor for
+// the workspace tensor.
+//
+// Example:
+//
+// Typical graph for MaxPool and its gradient looks like:
+//
+// A = MaxPool(T)
+// B = MaxPoolGrad(X, A, Y)
+//
+// We will transform this graph to propagate workspace as:
+//
+// A, A_m, W, W_m = MklMaxPool(T, T_m)
+// B, B_m = MklMaxPoolGrad(X, X_m, A, A_m, Y, Y_m, W, W_m)
+//
+// Here W is the workspace tensor. Transformed tensors with the name
+// suffix _m are Mkl tensors, and this transformation has been done
+// using the algorithm discussed earlier. The workspace transformation
+// only adds extra outputs (W, W_m) to the forward op and connects
+// them to the corresponding backward ops.
+//
+// Terms:
+//
+// Forward op name = name of the op in the forward pass
+// where workspace originates (MaxPool in this example)
+// Backward op name = name of the op in the backward pass that receives
+// workspace from forward op (MaxPoolGrad in the example)
+// Slot = Number of the output or input slot that will be
+// used by the workspace (2 for MklMaxPool, as W is the 3rd
+// output of MklMaxPool (0 is the 1st); 6 for MklMaxPoolGrad)
+//
+// Question:
+//
+// How do we associate a backward op with its forward op? There can
+// be more than one op with exactly the same name.
+//
+// In this example we associate MaxPoolGrad with MaxPool. But there
+// could be more than one MaxPool op. To solve this problem, we look
+// for a _direct_ edge between the forward op and the backward op
+// (tensor A flows along this edge in the example).
+//
+// How do we transform the forward and backward op when there is no
+// direct edge between them? In such a case, we generate dummy tensors
+// as workspace tensors. For the example, the transformation of MaxPool
+// will be exactly the same --- it is just that MaxPool won't generate
+// any workspace tensor. For MaxPoolGrad, the transformation will also
+// be the same, but instead of connecting W and W_m with outputs of
+// MaxPool, we will produce dummy tensors for them, and we will set the
+// workspace_enabled attribute to false.
+//
class MklLayoutRewritePass : public GraphOptimizationPass {
public:
MklLayoutRewritePass() {
csinfo_.conv2d = "Conv2D";
+ csinfo_.relu = "Relu";
+ csinfo_.relugrad = "ReluGrad";
csinfo_.conv2dgradfilter = "Conv2DBackpropFilter";
+ csinfo_.maxpool = "MaxPool";
+ csinfo_.maxpoolgrad = "MaxPoolGrad";
+ csinfo_.avgpool = "AvgPool";
+ csinfo_.avgpoolgrad = "AvgPoolGrad";
csinfo_.conv2dgradinput = "Conv2DBackpropInput";
ninfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d),
@@ -162,6 +226,21 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
ninfo_.push_back({csinfo_.conv2dgradinput,
GetMklOpName(csinfo_.conv2dgradinput),
3, CopyAttrsConv2D});
+ ninfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu),
+ 1, CopyAttrsRelu});
+ ninfo_.push_back({csinfo_.relugrad, GetMklOpName(csinfo_.relugrad),
+ 2, CopyAttrsRelu });
+ ninfo_.push_back({csinfo_.maxpool, GetMklOpName(csinfo_.maxpool),
+ 1, CopyAttrsPooling});
+ ninfo_.push_back({csinfo_.maxpoolgrad, GetMklOpName(csinfo_.maxpoolgrad),
+ 3, CopyAttrsPooling});
+ ninfo_.push_back({csinfo_.avgpool, GetMklOpName(csinfo_.avgpool),
+ 1, CopyAttrsPooling});
+ ninfo_.push_back({csinfo_.avgpoolgrad, GetMklOpName(csinfo_.avgpoolgrad),
+ 3, CopyAttrsPooling});
+
+ // Add info about which ops to add workspace edge to and the slots.
+ wsinfo_.push_back({csinfo_.maxpool, csinfo_.maxpoolgrad, 0, 1, 2, 6});
}
// Standard interface to run pass
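
The hunk above registers the new pooling and ReLU rewrites in ninfo_ and records in wsinfo_ which slots carry the MaxPool workspace. A minimal standalone C++ sketch of that bookkeeping, with hypothetical RewriteInfo/WorkspaceInfo structs standing in for NodesInfo/WorkSpaceInfo and a simplified GetMklOpName, could look like this:

    #include <iostream>
    #include <string>
    #include <vector>

    // Stand-in for NodesInfo: TF op name, its Mkl counterpart, number of inputs.
    struct RewriteInfo {
      std::string name;
      std::string new_name;
      int num_inputs;
    };

    // Stand-in for WorkSpaceInfo: which slots connect the fwd/bwd ops and where
    // the workspace tensor (with its Mkl tensor at slot+1) is attached.
    struct WorkspaceInfo {
      std::string fwd_op, bwd_op;
      int fwd_slot, bwd_slot, ws_fwd_slot, ws_bwd_slot;
    };

    // Mkl ops are named by prefixing the TF op name, e.g. MaxPool -> MklMaxPool.
    std::string GetMklOpName(const std::string& tf_op) { return "Mkl" + tf_op; }

    int main() {
      std::vector<RewriteInfo> rewrites = {
          {"MaxPool", GetMklOpName("MaxPool"), 1},
          {"MaxPoolGrad", GetMklOpName("MaxPoolGrad"), 3},
          {"AvgPool", GetMklOpName("AvgPool"), 1},
          {"AvgPoolGrad", GetMklOpName("AvgPoolGrad"), 3},
          {"Relu", GetMklOpName("Relu"), 1},
          {"ReluGrad", GetMklOpName("ReluGrad"), 2},
      };
      // Mirrors wsinfo_.push_back({maxpool, maxpoolgrad, 0, 1, 2, 6}).
      std::vector<WorkspaceInfo> workspace = {{"MaxPool", "MaxPoolGrad", 0, 1, 2, 6}};

      for (const auto& r : rewrites)
        std::cout << r.name << " -> " << r.new_name
                  << " (" << r.num_inputs << " TF inputs)\n";
      for (const auto& w : workspace)
        std::cout << w.fwd_op << " output " << w.ws_fwd_slot << " feeds "
                  << w.bwd_op << " input " << w.ws_bwd_slot << "\n";
      return 0;
    }
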
@@ -188,16 +267,39 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// to copy attributes from old node to new node.
} NodesInfo;
+ /// Structure to specify forward op, backward op, and the slot numbers
+ /// in forward and backward op where we will add workspace edge.
+ typedef struct {
+ string fwdop; // Name of the forward op in the graph
+ string bwdop; // Name of the backward op in the graph
+ int fwdslot; // Output slot in the forward op node where actual
+ // output tensor resides
+ int bwdslot; // Input slot in the backward op node where actual
+ // input tensor resides
+ int wsfwdslot; // Output slot in the forward op node where workspace
+ // edge is added
+ int wsbwdslot; // Input slot in the backward op node where workspace
+ // edge is added
+ } WorkSpaceInfo;
+
/// Structure to store all constant strings
struct {
string conv2d;
string conv2dgradfilter;
+ string maxpool;
+ string maxpoolgrad;
+ string avgpool;
+ string avgpoolgrad;
string conv2dgradinput;
+ string conv2dgradbias;
} csinfo_;
/// Maintain info about nodes to rewrite
std::vector<NodesInfo> ninfo_;
+ /// Maintain info about nodes to add workspace edge
+ std::vector<WorkSpaceInfo> wsinfo_;
+
/// Hash table to maintain nodes visited in the graph.
std::unordered_set<const Node*> visited_nodes_;
@@ -239,6 +341,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
NodeBuilder* nb, Node* orign);
+ // Add workspace edge on the input or output side of Node 'orign' by using
+ // NodeBuilder 'nb' for the new node provided. If 'orign' does not require
+ // adding a workspace edge, then do not add it.
+ void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orign,
+ NodeBuilder* nb);
+
// Rewrite Node 'n' in graph 'g' with rewrite information specified in 'ni'
// Returns Status::OK() if node rewrite is successful, otherwise returns
// appropriate error status
@@ -248,12 +356,17 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// We need operator-specific function to copy attributes because the framework
// does not provide any generic function for it.
static void CopyAttrsConv2D(Node* orign, NodeBuilder* nb);
+ static void CopyAttrsConv2DBias(Node* orign, NodeBuilder* nb);
+ static void CopyAttrsPooling(Node* orign, NodeBuilder* nb);
+ static void CopyAttrsRelu(Node* orign, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
// using node for original node 'orign' and return it in '*out'.
// TODO(nhasabni) We should move this to mkl_util.h
void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
Node* orign);
+ void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
+ Node* orign);
};
@@ -286,49 +399,6 @@ static void FillInputs(const Node* n,
//////////////////////////////////////////////////////////////////////////
-// Macros to build new node with different number of inputs.
-// We need this way because we need to specify all the inputs when
-// building a node. Comment at core/graph/node_builder.h, line 85-86.
-
-#define SETUP_INPUTS1(nb, op1) do { \
- nb->Input(op1.node, op1.index); \
-}while(0)
-
-#define SETUP_INPUTS2(nb, op1, op2) do { \
- nb->Input(op1.node, op1.index); \
- nb->Input(op2.node, op2.index); \
-}while(0)
-
-#define SETUP_INPUTS3(nb, op1, op2, op3) do { \
- nb->Input(op1.node, op1.index); \
- nb->Input(op2.node, op2.index); \
- nb->Input(op3.node, op3.index); \
-}while(0)
-
-#define SETUP_INPUTS4(nb, op1, op2, op3, op4) do { \
- nb->Input(op1.node, op1.index); \
- nb->Input(op2.node, op2.index); \
- nb->Input(op3.node, op3.index); \
- nb->Input(op4.node, op4.index); \
-}while(0)
-
-#define SETUP_INPUTS5(nb, op1, op2, op3, op4, op5) do {\
- nb->Input(op1.node, op1.index); \
- nb->Input(op2.node, op2.index); \
- nb->Input(op3.node, op3.index); \
- nb->Input(op4.node, op4.index); \
- nb->Input(op5.node, op5.index); \
-}while(0)
-
-#define SETUP_INPUTS6(nb, op1, op2, op3, op4, op5, op6) do {\
- nb->Input(op1.node, op1.index); \
- nb->Input(op2.node, op2.index); \
- nb->Input(op3.node, op3.index); \
- nb->Input(op4.node, op4.index); \
- nb->Input(op5.node, op5.index); \
- nb->Input(op6.node, op6.index); \
-}while(0)
-
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyMklTensorNode(
std::unique_ptr<Graph>* g, Node** out, Node* orign) {
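
The hunk above deletes the fixed-arity SETUP_INPUTS1..SETUP_INPUTS6 macros; a later hunk in this file replaces them with a single loop over (node, index) pairs in SetUpInputs. A standalone sketch of that simplification, with a hypothetical Builder type standing in for NodeBuilder, might look like:

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical stand-ins: a (node name, output index) pair and a builder
    // that records inputs, the way NodeBuilder::Input(node, index) does.
    using InputPair = std::pair<std::string, int>;

    struct Builder {
      std::vector<InputPair> inputs;
      void Input(const std::string& node, int index) {
        inputs.emplace_back(node, index);
      }
    };

    // Instead of one macro per arity, a range-based loop handles any count,
    // which is what SetUpInputs now does with "for (auto ni : new_inputs)".
    void AddInputs(Builder* nb, const std::vector<InputPair>& new_inputs) {
      for (const auto& ni : new_inputs) nb->Input(ni.first, ni.second);
    }

    int main() {
      Builder nb;
      AddInputs(&nb, {{"A", 0}, {"A_mkl", 0}, {"B", 1}, {"B_mkl", 0}});
      std::cout << "added " << nb.inputs.size() << " inputs\n";
      return 0;
    }
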
@@ -351,6 +421,29 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(
.Finalize(&**g, out));
}
+// TODO(nhasabni) We should move this to mkl_util.h.
+void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
+ std::unique_ptr<Graph>* g, Node** out, Node* orign) {
+ // We use a tensor of shape {1} and value 0 to represent the
+ // dummy float tensor. We need this as a dummy workspace tensor.
+ // The workspace tensor has type float.
+ const DataType dt = DataTypeToEnum<float>::v();
+ TensorProto proto;
+ proto.set_dtype(dt);
+ float zero[1] = {0};
+ proto.set_tensor_content(const_cast<const void*>(
+ static_cast<void*>(&zero)), 4);
+ TensorShape dummy_shape({1});
+ dummy_shape.AsProto(proto.mutable_tensor_shape());
+ TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orign->def().device()) // We place this node on same
+ // device as device of original
+ // node.
+ .Finalize(&**g, out));
+}
+
Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
NodeBuilder* nb, Node* orign) {
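
GetDummyWorkspaceTensorNode in the hunk above packs a single float zero (shape {1}, 4 bytes) into the Const node's tensor_content. A small standalone sketch of that byte packing, independent of the TensorProto API and assuming only that tensor_content holds the raw bytes of the float, is:

    #include <iostream>
    #include <string>

    int main() {
      // One float with value 0, shape {1}: this is all the dummy workspace holds.
      float zero[1] = {0.0f};

      // Rough equivalent of proto.set_tensor_content(&zero, 4): copy the raw bytes.
      std::string tensor_content(reinterpret_cast<const char*>(zero), sizeof(zero));

      std::cout << "dummy workspace content is " << tensor_content.size()
                << " bytes (sizeof(float) == " << sizeof(float) << ")\n";
      return 0;
    }
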
@@ -394,41 +487,102 @@ Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr<Graph>* g,
// N for Mkl tensors corresponding to each Tensorflow tensors.
CHECK_EQ(new_inputs.size(), inputs.size() * 2);
- // 2. Let's build the node with new inputs.
- switch (new_inputs.size()) {
- case 0: // We don't need to do anything for no input as we have
- // already built node.
- break;
- case 1: SETUP_INPUTS1(nb, new_inputs[0]); break;
- case 2: SETUP_INPUTS2(nb, new_inputs[0],
- new_inputs[1]); break;
- case 3: SETUP_INPUTS3(nb, new_inputs[0],
- new_inputs[1],
- new_inputs[2]); break;
- case 4: SETUP_INPUTS4(nb, new_inputs[0],
- new_inputs[1],
- new_inputs[2],
- new_inputs[3]); break;
- case 5: SETUP_INPUTS5(nb, new_inputs[0],
- new_inputs[1],
- new_inputs[2],
- new_inputs[3],
- new_inputs[4]); break;
- case 6: SETUP_INPUTS6(nb, new_inputs[0],
- new_inputs[1],
- new_inputs[2],
- new_inputs[3],
- new_inputs[4],
- new_inputs[5]); break;
- default: {
- return Status(error::Code::UNIMPLEMENTED,
- "Could not create node with given number of inputs");
- }
+ // 2. Let's add the new inputs.
+ for (auto ni : new_inputs) {
+ nb->Input(ni.node, ni.index);
}
return Status::OK();
}
+void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
+ Node* orign, NodeBuilder* nb) {
+ bool workspace_edge_added = false;
+ for (auto ws : wsinfo_) {
+ if (orign->type_string() == ws.fwdop &&
+ mkl_layer_registry::IsMklLayer(GetMklOpName(orign->type_string()))) {
+ // If this op is a fwd op, then we need to check if there is an
+ // edge from this node's fwdslot to the bwdop's bwdslot. If there is
+ // such an edge, then we just add an attribute on this node, setting
+ // workspace_enabled to true. We don't add the actual workspace edge
+ // to this node. The actual workspace edge gets added in the backward
+ // op for this node.
+ for (const Edge* e : orign->out_edges()) {
+ if (e->src_output() == ws.fwdslot &&
+ e->dst()->type_string() == ws.bwdop &&
+ e->dst_input() == ws.bwdslot) {
+ nb->Attr("workspace_enabled", true);
+ VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
+ << orign->type_string();
+ workspace_edge_added = true;
+ // We found the edge that we were looking for, so break.
+ break;
+ }
+ }
+
+ if (!workspace_edge_added) {
+ // If we are here, then we did not find a backward operator for
+ // this node.
+ nb->Attr("workspace_enabled", false);
+ }
+ } else if (orign->type_string() == ws.bwdop &&
+ mkl_layer_registry::IsMklLayer(GetMklOpName(orign->type_string()))) {
+ // If this op is a bwd op, then we need to add a workspace edge and
+ // its Mkl tensor edge between its corresponding fwd op and this
+ // op. The corresponding fwd op is specified in the 'fwdop' field of
+ // the workspace info. fwdslot and bwdslot in the workspace info
+ // specify which slots of the forward and backward op the edge
+ // connects. Once all these criteria match, we add a workspace edge
+ // between wsfwdslot and wsbwdslot. Its corresponding Mkl tensor is
+ // added at wsfwdslot+1 and wsbwdslot+1.
+ for (const Edge* e : orign->in_edges()) {
+ if (e->src_output() == ws.fwdslot &&
+ // We would have rewritten the forward op, so we need to use
+ // GetMklOpName call to get its Mkl name.
+ e->src()->type_string() == GetMklOpName(ws.fwdop) &&
+ e->dst_input() == ws.bwdslot) {
+ nb->Attr("workspace_enabled", true);
+ // Add workspace edge between fwd op and bwd op.
+ nb->Input(e->src(), ws.wsfwdslot);
+ // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
+ nb->Input(e->src(), ws.wsfwdslot+1);
+ // In terms of input ordering, we add these Input calls here
+ // because the workspace edge (and its Mkl tensor) is the last
+ // input of the fwdop and bwdop. All inputs before the workspace
+ // tensor have already been added by the SetUpInputs function.
+ VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
+ << orign->type_string();
+ workspace_edge_added = true;
+ // We found the edge that we were looking for, so break.
+ break;
+ }
+ }
+
+ // If we are here, it means we did not find a fwd op that feeds
+ // this bwd op. In this case, we need to generate dummy tensors for
+ // the workspace input and the Mkl tensor for the workspace, and set
+ // workspace_enabled to false.
+ if (!workspace_edge_added) {
+ nb->Attr("workspace_enabled", false);
+ Node* dmt_ws = nullptr; // Dummy tensor for workspace
+ Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace
+ GetDummyWorkspaceTensorNode(g, &dmt_ws, orign);
+ GetDummyMklTensorNode(g, &dmt_mkl_ws, orign);
+ CHECK_NOTNULL(dmt_ws);
+ CHECK_NOTNULL(dmt_mkl_ws);
+ nb->Input(dmt_ws, 0); // We add dummy tensor as workspace tensor.
+ nb->Input(dmt_mkl_ws, 0); // We add dummy tensor as Mkl
+ // tensor for workspace tensor.
+ VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
+ << orign->type_string();
+ }
+ } else {
+ // If this node does not match any workspace info, then we do not
+ // do anything special for workspace propagation for it.
+ }
+ }
+}
+
void MklLayoutRewritePass::CopyAttrsConv2D(Node* orign, NodeBuilder* nb) {
DataType T;
string data_format;
@@ -451,6 +605,53 @@ void MklLayoutRewritePass::CopyAttrsConv2D(Node* orign, NodeBuilder* nb) {
nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
+void MklLayoutRewritePass::CopyAttrsConv2DBias(Node* orign, NodeBuilder* nb) {
+ DataType T;
+ string data_format;
+ std::vector<int32> strides;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("strides", strides);
+ nb->Attr("data_format", data_format);
+}
+
+void MklLayoutRewritePass::CopyAttrsPooling(Node* orign, NodeBuilder* nb) {
+ DataType T;
+ string data_format;
+ string padding;
+ std::vector<int32> ksize, strides;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "ksize", &ksize));
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding));
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("ksize", ksize);
+ nb->Attr("strides", strides);
+ nb->Attr("padding", padding);
+ nb->Attr("data_format", data_format);
+}
+
+void MklLayoutRewritePass::CopyAttrsRelu(Node* orign, NodeBuilder* nb) {
+ DataType T;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+}
+
Status MklLayoutRewritePass::RewriteNode(
std::unique_ptr<Graph>* g, Node* orign, const NodesInfo& ni) {
VLOG(1) << "MklLayoutRewritePass: Original node:" << orign->DebugString();
@@ -471,13 +672,18 @@ Status MklLayoutRewritePass::RewriteNode(
if (s != Status::OK()) {
return s;
}
+
// Copy attributes from original node to new node.
ni.copyattrs(orign, &nb);
// Set the Mkl layer label for this op.
nb.Attr("_kernel", mkl_layer_registry::kMklLayerLabel);
- Node* newn = nullptr;
+
+ // Add workspace edge to this node if needed.
+ // We add workspace edge only for MaxPool, LRN and BatchNorm.
+ AddWorkSpaceEdgeIfNeeded(g, orign, &nb);
// Finalize graph and get new node.
+ Node* newn = nullptr;
TF_CHECK_OK(nb.Finalize(&**g, &newn));
CHECK_NOTNULL(newn);
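
AddWorkSpaceEdgeIfNeeded above decides whether a real workspace edge can be wired from the rewritten forward op to its backward op or whether dummy tensors must be generated. A simplified, standalone model of the forward-op half of that decision (plain structs instead of Graph/Node/Edge; not the TensorFlow API) might read:

    #include <iostream>
    #include <string>
    #include <vector>

    // Minimal stand-in for a graph edge: producer op, producer output slot,
    // consumer op, consumer input slot.
    struct Edge {
      std::string src_op;
      int src_output;
      std::string dst_op;
      int dst_input;
    };

    // Returns true if the forward op should advertise workspace_enabled = true,
    // i.e. a direct edge from fwd_slot of the forward op reaches bwd_slot of the
    // matching backward op (MaxPool output 0 -> MaxPoolGrad input 1 here).
    bool ForwardHasMatchingBackward(const std::vector<Edge>& out_edges,
                                    const std::string& bwd_op,
                                    int fwd_slot, int bwd_slot) {
      for (const Edge& e : out_edges) {
        if (e.src_output == fwd_slot && e.dst_op == bwd_op &&
            e.dst_input == bwd_slot) {
          return true;
        }
      }
      return false;
    }

    int main() {
      std::vector<Edge> maxpool_out_edges = {
          {"MaxPool", 0, "Identity", 0},
          {"MaxPool", 0, "MaxPoolGrad", 1},  // the direct edge the pass looks for
      };
      bool workspace_enabled =
          ForwardHasMatchingBackward(maxpool_out_edges, "MaxPoolGrad", 0, 1);
      std::cout << "workspace_enabled = " << std::boolalpha
                << workspace_enabled << "\n";
      return 0;
    }
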
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 886713bad1..9363cc46a8 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2694,14 +2694,20 @@ tf_kernel_library(
"maxpooling_op.cc",
"pooling_ops_3d.cc",
"pooling_ops_common.cc",
- ],
+ ] + if_mkl([
+ "mkl_avgpooling_op.cc",
+ "mkl_maxpooling_op.cc",
+ "mkl_pooling_ops_common.cc",
+ ]),
hdrs = [
"avgpooling_op.h",
"cudnn_pooling_gpu.h",
"fractional_pool_common.h",
"maxpooling_op.h",
"pooling_ops_common.h",
- ],
+ ] + if_mkl([
+ "mkl_pooling_ops_common.h",
+ ]),
gpu_srcs = [
"avgpooling_op.h",
"avgpooling_op_gpu.cu.cc",
@@ -2722,7 +2728,9 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
- ],
+ ] + if_mkl([
+ "//third_party/mkl:intel_binary_blob",
+ ]),
)
tf_kernel_library(
@@ -4493,6 +4501,59 @@ if_mkl(
),
)
+if_mkl(
+ tf_kernel_library(
+ name = "mkl_maxpooling_op",
+ hdrs = ["mkl_pooling_ops_common.h"],
+ prefix = "mkl_maxpooling",
+ deps = [
+ ":bounds_check",
+ ":ops_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:nn_ops_op_lib",
+ "//third_party/mkl:intel_binary_blob",
+ ],
+ ),
+)
+
+if_mkl(
+ tf_kernel_library(
+ name = "mkl_avgpooling_op",
+ hdrs = ["mkl_pooling_ops_common.h"],
+ prefix = "mkl_avgpooling",
+ deps = [
+ ":bounds_check",
+ ":ops_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:nn_ops_op_lib",
+ "//third_party/mkl:intel_binary_blob",
+ ],
+ ),
+)
+
+if_mkl(
+ tf_kernel_library(
+ name = "mkl_relu_op",
+ prefix = "mkl_relu",
+ deps = [
+ ":bounds_check",
+ ":ops_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:nn_ops_op_lib",
+ "//third_party/mkl:intel_binary_blob",
+ ],
+ ),
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
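
The BUILD changes above compile the MKL kernels only under if_mkl(); the new source files below additionally guard all of their code behind #ifdef INTEL_MKL. A trivial standalone illustration of the C++ side of that double guard:

    #include <iostream>

    // Toy stand-in for a kernel registration that must vanish entirely when the
    // build does not define INTEL_MKL (compare the #ifdef INTEL_MKL guards at the
    // top and bottom of mkl_avgpooling_op.cc and mkl_maxpooling_op.cc below).
    #ifdef INTEL_MKL
    static const char* const kRegisteredKernel = "MklAvgPool (CPU, float)";
    #else
    static const char* const kRegisteredKernel = nullptr;
    #endif

    int main() {
      if (kRegisteredKernel) {
        std::cout << "registered: " << kRegisteredKernel << "\n";
      } else {
        std::cout << "MKL kernels not compiled in\n";
      }
      return 0;
    }
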
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
new file mode 100644
index 0000000000..161b6e98ff
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -0,0 +1,486 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+#ifdef INTEL_MKL
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class MklAvgPoolingOp : public UnaryOp<T> {
+ public:
+ explicit MklAvgPoolingOp(OpKernelConstruction* context)
+ : UnaryOp<T>(context) {
+ pooling_fwd_ = nullptr;
+ lt_user_input_fwd_ = nullptr;
+ lt_input_prim_ = nullptr;
+ convert_input_ = nullptr;
+
+ input_buf_ = nullptr;
+ workspace_ = nullptr;
+
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented("Pooling is not yet supported on the "
+ "batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = MklGetInput(context, 0);
+
+ GetMklShape(context, 0, &mkl_input_shape_);
+ bool input_in_mkl_format = mkl_input_shape_.IsMklTensor();
+
+ if (!input_in_mkl_format)
+ mkl_params_.in_dim = tensor_in.dims();
+ else
+ mkl_params_.in_dim = mkl_input_shape_.GetDimension();
+
+ MklPoolParameters params;
+ if (!input_in_mkl_format) {
+ params.Init(context, ksize_, stride_, padding_, data_format_,
+ tensor_in.shape());
+ } else {
+ params.Init(context, ksize_, stride_, padding_, data_format_,
+ &mkl_input_shape_);
+ }
+
+ // Extract the parameters for the op from the pooling specs
+ ExtractMklOpParams(context, data_format_, params, &mkl_params_);
+
+ MklCreateLayoutsAndPrimitives(context);
+
+ AllocTmpBuffer(context, &workspace_tensor_, lt_workspace_, &workspace_);
+
+ if (convert_input_ != nullptr) {
+ if (input_in_mkl_format == false) {
+ CHECK_EQ(dnnConversionExecute_F32(convert_input_,
+ static_cast<void*>(
+ const_cast<T*>(
+ tensor_in.flat<T>().data())),
+ input_buf_),
+ E_SUCCESS);
+ CHECK_EQ(dnnDelete_F32(convert_input_), E_SUCCESS);
+ } else {
+ mkl_input_shape_.GetConvertedFlatData(lt_input_prim_,
+ static_cast<void*>(
+ const_cast<T*>(
+ tensor_in.flat<T>().data())),
+ input_buf_);
+ }
+ pooling_res_[dnnResourceSrc] = input_buf_;
+ } else {
+ pooling_res_[dnnResourceSrc] =
+ static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data()));
+ }
+
+ // Declare output tensor and allocate memory
+ Tensor* output = nullptr;
+ TensorShape tensor_out_shape;
+ MklShape mkl_out_shape;
+ mkl_out_shape.SetMklTensor(true);
+ mkl_out_shape.SetMklLayout(pooling_fwd_, dnnResourceDst);
+ mkl_out_shape.SetTfLayout(mkl_params_.in_dim,
+ mkl_params_.out_sizes,
+ mkl_params_.out_strides);
+
+ tensor_out_shape.AddDim(
+ dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(mkl_out_shape.GetMklLayout())) / sizeof(T));
+
+ AllocateOutputSetMklshape(context,
+ 0,
+ &output,
+ tensor_out_shape,
+ mkl_out_shape);
+ pooling_res_[dnnResourceDst] =
+ static_cast<void*>(output->flat<T>().data());
+
+ pooling_res_[dnnResourceWorkspace] = workspace_;
+
+ CHECK_EQ(dnnExecute_F32(pooling_fwd_, pooling_res_), E_SUCCESS);
+
+ MklCleanup();
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ MklShape mkl_input_shape_;
+
+ dnnPrimitive_t pooling_fwd_;
+ dnnPrimitive_t convert_input_;
+ dnnLayout_t lt_user_input_fwd_;
+ dnnLayout_t lt_input_prim_;
+ dnnLayout_t lt_workspace_;
+
+ void* workspace_;
+ void* input_buf_;
+ void* pooling_res_[dnnResourceNumber];
+
+ // Tensors needed to create temporary buffers
+ Tensor input_buf_tensor_;
+ Tensor workspace_tensor_;
+
+ MklPoolingOpParams mkl_params_;
+
+ void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
+ bool input_in_mkl_format = mkl_input_shape_.IsMklTensor();
+
+ if (!input_in_mkl_format) {
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input_fwd_,
+ mkl_params_.in_dim,
+ mkl_params_.in_sizes,
+ mkl_params_.in_strides),
+ E_SUCCESS);
+ } else {
+ lt_user_input_fwd_ = (dnnLayout_t) mkl_input_shape_.GetCurLayout();
+ }
+
+ dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
+ dnnPrimitiveAttributes_t primAttr = nullptr;
+
+ // Create DNN primitives
+ CHECK_EQ(dnnPoolingCreateForward_F32(&pooling_fwd_,
+ primAttr,
+ algorithm,
+ lt_user_input_fwd_,
+ mkl_params_.kernel_size,
+ mkl_params_.kernel_stride,
+ mkl_params_.in_offset,
+ dnnBorderZerosAsymm),
+ E_SUCCESS);
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_input_prim_,
+ pooling_fwd_,
+ dnnResourceSrc),
+ E_SUCCESS);
+ if (!dnnLayoutCompare_F32(lt_user_input_fwd_, lt_input_prim_)) {
+ CHECK_EQ(dnnConversionCreate_F32(&convert_input_,
+ lt_user_input_fwd_,
+ lt_input_prim_),
+ E_SUCCESS);
+
+ AllocTmpBuffer(context,
+ &input_buf_tensor_,
+ lt_input_prim_,
+ &input_buf_);
+ }
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace_,
+ pooling_fwd_,
+ dnnResourceWorkspace),
+ E_SUCCESS);
+ }
+
+ void MklCleanup() {
+ bool input_in_mkl_format = mkl_input_shape_.IsMklTensor();
+ if (!input_in_mkl_format) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_user_input_fwd_), E_SUCCESS);
+ lt_user_input_fwd_ = nullptr;
+ }
+
+ CHECK_EQ(dnnDelete_F32(pooling_fwd_), E_SUCCESS);
+ pooling_fwd_ = nullptr;
+
+ CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim_), E_SUCCESS);
+ lt_input_prim_ = nullptr;
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+template <class Device, class T>
+class MklAvgPoolingGradOp : public OpKernel {
+ public:
+ explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+
+ pooling_bwd_ = nullptr;
+ convert_outbackprop_ = nullptr;
+ lt_user_input_bwd_ = nullptr;
+ lt_outbackprop_user_ = nullptr;
+ lt_outbackprop_prim_ = nullptr;
+ lt_workspace_prim_ = nullptr;
+
+ outbackprop_buf_ = nullptr;
+ workspace_ = nullptr;
+
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented("Pooling is not yet supported on the "
+ "batch dimension."));
+ mkl_params_.in_dim = 4;
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor &out_backprop = MklGetInput(context, 1);
+ GetMklShape(context, 2, &mkl_out_backprop_shape);
+ outbackprop_in_mkl_format_ = mkl_out_backprop_shape.IsMklTensor();
+
+ MklCreateLayoutsAndPrimitives(context);
+
+ // Check if outbackprop layout requires conversion.
+ if (!dnnLayoutCompare_F32(lt_outbackprop_user_, lt_outbackprop_prim_)) {
+ CHECK_EQ(dnnConversionCreate_F32(&convert_outbackprop_,
+ lt_outbackprop_user_,
+ lt_outbackprop_prim_),
+ E_SUCCESS);
+
+ AllocTmpBuffer(context,
+ &outbackprop_buf_tensor,
+ lt_outbackprop_prim_,
+ &outbackprop_buf_);
+
+ if (!outbackprop_in_mkl_format_) {
+ CHECK_EQ(dnnConversionExecute_F32(convert_outbackprop_,
+ static_cast<void*>(const_cast<T*>(
+ out_backprop.flat<T>().data())),
+ outbackprop_buf_),
+ E_SUCCESS);
+ CHECK_EQ(dnnDelete_F32(convert_outbackprop_), E_SUCCESS);
+ } else {
+ mkl_out_backprop_shape.
+ GetConvertedFlatData(lt_outbackprop_prim_,
+ static_cast<void*>(const_cast<T*>(
+ out_backprop.flat<T>().data())),
+ outbackprop_buf_);
+ }
+ pooling_res_[dnnResourceDiffDst] = outbackprop_buf_;
+ } else {
+ pooling_res_[dnnResourceDiffDst] =
+ static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data()));
+ }
+
+ // Handle workspace requirements.
+ AllocTmpBuffer(context,
+ &workspace_buf_tensor,
+ lt_workspace_prim_,
+ &workspace_);
+ pooling_res_[dnnResourceWorkspace] = workspace_;
+
+ // Handle MKL output tensor setup.
+ Tensor* output = nullptr;
+ TensorShape tensor_out_shape;
+ MklShape mkl_out_shape;
+ mkl_out_shape.SetMklTensor(true);
+ mkl_out_shape.SetMklLayout(pooling_bwd_, dnnResourceDiffSrc);
+ mkl_out_shape.SetTfLayout(mkl_params_.in_dim,
+ mkl_params_.in_sizes,
+ mkl_params_.in_strides);
+
+ tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(
+ mkl_out_shape.GetMklLayout())) / sizeof(T));
+
+ AllocateOutputSetMklshape(context,
+ 0,
+ &output,
+ tensor_out_shape,
+ mkl_out_shape);
+
+ // Set output tensor.
+ pooling_res_[dnnResourceDiffSrc] =
+ static_cast<void*>(output->flat<T>().data());
+
+ // Execute primitive.
+ CHECK_EQ(dnnExecute_F32(pooling_bwd_, pooling_res_), E_SUCCESS);
+
+ MklCleanup();
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ bool outbackprop_in_mkl_format_;
+
+ MklShape mkl_out_backprop_shape;
+
+ // Tensors needed to create temporary buffers
+ Tensor outbackprop_buf_tensor;
+ Tensor workspace_buf_tensor;
+
+ dnnPrimitive_t pooling_bwd_;
+ dnnPrimitive_t convert_outbackprop_;
+ dnnLayout_t lt_user_input_bwd_;
+ dnnLayout_t lt_outbackprop_user_;
+ dnnLayout_t lt_outbackprop_prim_;
+ dnnLayout_t lt_workspace_prim_;
+
+ void* workspace_;
+ void* outbackprop_buf_;
+ void* pooling_res_[dnnResourceNumber]; // Pooling resource array
+
+ MklPoolingOpParams mkl_params_;
+
+ void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
+ const Tensor& tensor_in_shape = MklGetInput(context, 0);
+ const Tensor &out_backprop = MklGetInput(context, 1);
+
+ if (!outbackprop_in_mkl_format_) {
+ // For avgpooling, tensor_in_shape should have 1 dimension, and 4
+ // elements.
+ OP_REQUIRES(context,
+ tensor_in_shape.dims() == 1 &&
+ tensor_in_shape.NumElements() == 4,
+ errors::InvalidArgument("original input shape must be "
+ "1-dimensional and 4 elements"));
+
+ // For avgpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(context, out_backprop.dims() == 4,
+ errors::InvalidArgument("out_backprop must be "
+ "4-dimensional"));
+ } else {
+ // Input in MKL format.
+ OP_REQUIRES(context, out_backprop.dims() == 2,
+ errors::InvalidArgument("out_backprop in MKL format must be "
+ "2-dimensional"));
+
+ // For avgpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(context, mkl_out_backprop_shape.GetDimension() == 4,
+ errors::InvalidArgument("out_backprop must be "
+ "4-dimensional"));
+ }
+
+ TensorShape output_shape;
+ auto shape_vec = tensor_in_shape.vec<int32>();
+ for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
+ output_shape.AddDim(shape_vec(i));
+ }
+
+ MklPoolParameters params;
+ params.Init(context, ksize_, stride_, padding_, data_format_, output_shape);
+
+ // Extract the parameters for the op from the pooling specs
+ ExtractMklOpParams(context, data_format_, params, &mkl_params_);
+
+ // TODO(inteltf): Get outbackprop layout.
+ // Do we need to create layout in every invocation?
+ if (!outbackprop_in_mkl_format_) {
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop_user_,
+ mkl_params_.in_dim,
+ mkl_params_.out_sizes,
+ mkl_params_.out_strides),
+ E_SUCCESS);
+ } else {
+ lt_outbackprop_user_ =
+ (dnnLayout_t) mkl_out_backprop_shape.GetCurLayout();
+ }
+
+ // Create the backward primitive
+ // Create DNN user layout
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input_bwd_,
+ mkl_params_.in_dim,
+ mkl_params_.in_sizes,
+ mkl_params_.in_strides),
+ E_SUCCESS);
+
+ // Create PoolingBackward primitive
+ dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
+ dnnPrimitiveAttributes_t primAttr = nullptr;
+ CHECK_EQ(dnnPoolingCreateBackward_F32(&pooling_bwd_,
+ primAttr,
+ algorithm,
+ lt_user_input_bwd_,
+ mkl_params_.kernel_size,
+ mkl_params_.kernel_stride,
+ mkl_params_.in_offset,
+ dnnBorderZerosAsymm),
+ E_SUCCESS);
+
+ // Create expected outbackprop layout from the primitive.
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_outbackprop_prim_,
+ pooling_bwd_,
+ dnnResourceDiffDst),
+ E_SUCCESS);
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace_prim_,
+ pooling_bwd_,
+ dnnResourceWorkspace),
+ E_SUCCESS);
+ }
+
+ void MklCleanup() {
+ CHECK_EQ(dnnDelete_F32(pooling_bwd_), E_SUCCESS);
+ pooling_bwd_ = nullptr;
+
+ CHECK_EQ(dnnLayoutDelete_F32(lt_user_input_bwd_), E_SUCCESS);
+ lt_user_input_bwd_ = nullptr;
+
+ if (!outbackprop_in_mkl_format_) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user_), E_SUCCESS);
+ lt_outbackprop_user_ = nullptr;
+ }
+
+ CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim_), E_SUCCESS);
+ lt_outbackprop_prim_ = nullptr;
+
+ CHECK_EQ(dnnLayoutDelete_F32(lt_workspace_prim_), E_SUCCESS);
+ lt_workspace_prim_ = nullptr;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MklAvgPool").Device(DEVICE_CPU).TypeConstraint<float>("T")
+ .Label(mkl_layer_registry::kMklLayerLabel),
+ MklAvgPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("MklAvgPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T")
+ .Label(mkl_layer_registry::kMklLayerLabel),
+ MklAvgPoolingGradOp<CPUDevice, float>);
+
+} // namespace tensorflow
+#endif // INTEL_MKL
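
The AvgPool kernels above (like the MaxPool kernels that follow) repeat one MKL pattern: build the user layout, build the layout the primitive expects, and only when dnnLayoutCompare_F32 reports a mismatch allocate a temporary buffer and convert into it. The sketch below models just that control flow with plain C++ stand-ins; Layout, LayoutsMatch and PrepareInput are hypothetical names, not the MKL API:

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for an MKL layout: here just a tag string.
    struct Layout {
      std::string tag;
    };

    bool LayoutsMatch(const Layout& a, const Layout& b) { return a.tag == b.tag; }

    // Mirrors the kernel logic: if the user layout differs from the layout the
    // pooling primitive wants, "convert" into a temporary buffer; otherwise the
    // user buffer is handed to the primitive directly.
    const std::vector<float>* PrepareInput(const Layout& user, const Layout& prim,
                                           const std::vector<float>& user_buf,
                                           std::vector<float>* tmp_buf) {
      if (!LayoutsMatch(user, prim)) {
        *tmp_buf = user_buf;  // stand-in for dnnConversionExecute_F32
        std::cout << "converted " << user.tag << " -> " << prim.tag << "\n";
        return tmp_buf;
      }
      return &user_buf;
    }

    int main() {
      Layout user{"nhwc_plain"}, prim{"mkl_blocked"};
      std::vector<float> input = {1, 2, 3, 4}, scratch;
      const std::vector<float>* src = PrepareInput(user, prim, input, &scratch);
      std::cout << "primitive reads " << src->size() << " floats\n";
      return 0;
    }
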
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
new file mode 100644
index 0000000000..4342efd764
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -0,0 +1,591 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+#ifdef INTEL_MKL
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// An implementation of MaxPooling (forward).
+template <typename Device, typename T>
+class MklMaxPoolingOp : public OpKernel {
+ public:
+ explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
+ string data_format;
+
+ pooling_fwd_ = nullptr;
+ lt_user_input_fwd_ = nullptr;
+ lt_workspace_ = nullptr;
+
+ workspace_ = nullptr;
+
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented("Pooling is not yet supported on the "
+ "batch dimension."));
+
+ workspace_enabled_ = false;
+ // We may not get this attribute for this node if it does not go through
+ // the graph rewrite pass. So we do not check for errors while retrieving
+ // this attribute value.
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+
+ mkl_params_.in_dim = 4;
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Get the input tensor
+ const Tensor& tensor_in = MklGetInput(context, 0);
+ GetMklShape(context, 0, &mkl_input_shape);
+ input_in_mkl_format_ = mkl_input_shape.IsMklTensor();
+
+ MklPoolParameters params;
+ if (input_in_mkl_format_ == false) {
+ params.Init(context, ksize_, stride_, padding_, data_format_,
+ tensor_in.shape());
+ OP_REQUIRES(context, (params.depth_window == 1),
+ errors::Unimplemented(
+ "Depthwise max pooling not supported by MKL"));
+
+ } else {
+ params.Init(context, ksize_, stride_, padding_, data_format_,
+ &mkl_input_shape);
+ }
+
+ // Extract the parameters for the op from the pooling specs
+ ExtractMklOpParams(context, data_format_, params, &mkl_params_);
+
+ MklCreateLayoutsAndPrimitives(context);
+
+ // Declare output tensor
+ TensorShape tensor_out_shape;
+ MklShape mkl_out_shape;
+ mkl_out_shape.SetMklTensor(true);
+ mkl_out_shape.SetMklLayout(pooling_fwd_, dnnResourceDst);
+ mkl_out_shape.SetTfLayout(mkl_params_.in_dim,
+ mkl_params_.out_sizes,
+ mkl_params_.out_strides);
+
+ Tensor* output_tensor = nullptr;
+ tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(mkl_out_shape.GetMklLayout())) / sizeof(T));
+ AllocateOutputSetMklshape(context,
+ 0,
+ &output_tensor,
+ tensor_out_shape,
+ mkl_out_shape);
+
+ // For allocating temporary buffer
+ Tensor workspace_tensor;
+
+ if (workspace_enabled_) {
+ Tensor *workspace_tensor;
+ TensorShape workspace_shape;
+ workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(lt_workspace_))/ sizeof(T));
+ AllocateOutputSetMklshape(context, 1, &workspace_tensor,
+ workspace_shape, mkl_out_shape);
+ pooling_res_[dnnResourceWorkspace] = const_cast<void*>(
+ static_cast<const void*>(workspace_tensor->flat<T>().data()));
+ } else {
+ AllocTmpBuffer(context, &workspace_tensor, lt_workspace_, &workspace_);
+ pooling_res_[dnnResourceWorkspace] = workspace_;
+ }
+
+ pooling_res_[dnnResourceSrc] =
+ const_cast<void*>(
+ static_cast<const void*>(tensor_in.flat<T>().data()));
+ pooling_res_[dnnResourceDst] =
+ const_cast<void*>(
+ static_cast<const void*>(output_tensor->flat<T>().data()));
+
+ CHECK_EQ(dnnExecute_F32(pooling_fwd_, pooling_res_),
+ E_SUCCESS);
+
+ if (workspace_enabled_ == false) {
+ workspace_ = nullptr;
+ }
+
+ MklCleanup();
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ MklShape mkl_input_shape;
+
+ bool workspace_enabled_;
+ bool input_in_mkl_format_;
+
+ void* workspace_;
+ void* pooling_res_[dnnResourceNumber];
+
+ dnnPrimitive_t pooling_fwd_;
+ dnnLayout_t lt_user_input_fwd_;
+ dnnLayout_t lt_workspace_;
+
+ MklPoolingOpParams mkl_params_;
+
+ void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
+ // Create or use existing DNN user layout
+ if (input_in_mkl_format_ == false) {
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input_fwd_,
+ mkl_params_.in_dim,
+ mkl_params_.in_sizes,
+ mkl_params_.in_strides),
+ E_SUCCESS);
+ } else {
+ lt_user_input_fwd_ = (dnnLayout_t)mkl_input_shape.GetCurLayout();
+ }
+
+ dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
+ dnnPrimitiveAttributes_t primAttr = nullptr;
+
+ // Create DNN primitives
+ CHECK_EQ(dnnPoolingCreateForward_F32(&pooling_fwd_,
+ primAttr,
+ algorithm,
+ lt_user_input_fwd_,
+ mkl_params_.kernel_size,
+ mkl_params_.kernel_stride,
+ mkl_params_.in_offset,
+ dnnBorderZerosAsymm),
+ E_SUCCESS);
+
+ // Creates layout for the workspace
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace_,
+ pooling_fwd_,
+ dnnResourceWorkspace),
+ E_SUCCESS);
+ }
+
+ void MklCleanup() {
+ CHECK_EQ(dnnDelete_F32(pooling_fwd_), E_SUCCESS);
+ pooling_fwd_ = nullptr;
+
+ if (input_in_mkl_format_) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_user_input_fwd_), E_SUCCESS);
+ lt_user_input_fwd_ = nullptr;
+ }
+
+ CHECK_EQ(dnnLayoutDelete_F32(lt_workspace_), E_SUCCESS);
+ lt_workspace_ = nullptr;
+ }
+};
+
+// The operation to compute MaxPool gradients.
+// It takes three inputs:
+// - The original input tensor
+// - The original output tensor
+// - Backprop tensor for output
+// It produces one output: backprop tensor for input.
+template <class Device, class T>
+class MklMaxPoolingGradOp : public OpKernel {
+ public:
+ explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+
+ pooling_fwd_ = nullptr;
+ pooling_bwd_ = nullptr;
+
+ lt_outbackprop_user_ = nullptr;
+ lt_outbackprop_prim_ = nullptr;
+ lt_input_user_ = nullptr;
+ lt_input_prim_ = nullptr;
+
+ convert_outbackprop_ = nullptr;
+ convert_input_ = nullptr;
+
+ input_buf_ = nullptr;
+ outbackprop_buf_ = nullptr;
+
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ workspace_enabled_ = false;
+ // We may not get this attribute for this node if it does not go through
+ // the graph rewrite pass. So we do not check for errors while retrieving
+ // this attribute value.
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input - The original input tensor
+ const Tensor& tensor_in = MklGetInput(context, 0);
+
+ // Output - Backprop tensor for input.
+ Tensor* output_tensor = nullptr;
+
+ GetMklShape(context, 0, &mkl_input_shape);
+ input_in_mkl_format_ = mkl_input_shape.IsMklTensor();
+
+ MklShape mkl_output_backprop_shape;
+ GetMklShape(context, 2, &mkl_output_backprop_shape);
+ outbackprop_in_mkl_format_ = mkl_output_backprop_shape.IsMklTensor();
+
+ if (input_in_mkl_format_ == false)
+ mkl_params_.in_dim = tensor_in.dims();
+ else
+ mkl_params_.in_dim = mkl_input_shape.GetDimension();
+
+ MklPoolParameters params;
+ if (input_in_mkl_format_ == false) {
+ params.Init(context, ksize_, stride_, padding_, data_format_,
+ tensor_in.shape());
+ OP_REQUIRES(context, (params.depth_window == 1),
+ errors::Unimplemented(
+ "Depthwise max pooling not supported by MKL"));
+
+ } else {
+ params.Init(context, ksize_, stride_, padding_, data_format_,
+ &mkl_input_shape);
+ }
+
+ // Extract the parameters for the op from the pooling specs
+ ExtractMklOpParams(context, data_format_, params, &mkl_params_);
+
+ // MKL DNN: create layouts and primitives, and prepare the inputs.
+ MklCreateLayouts(context);
+ MklCreatePrimitives(context);
+ MklPrepareInputs(context);
+
+ // Create shape for the input back prop output
+ TensorShape mkl_input_backprop;
+ MklShape mklOutputShape;
+ mklOutputShape.SetMklTensor(true);
+ mklOutputShape.SetMklLayout(pooling_bwd_, dnnResourceDiffSrc);
+ mklOutputShape.SetTfLayout(mkl_params_.in_dim,
+ mkl_params_.in_sizes,
+ mkl_params_.in_strides);
+
+ mkl_input_backprop.AddDim(
+ dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(mklOutputShape.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklshape(context,
+ 0,
+ &output_tensor,
+ mkl_input_backprop,
+ mklOutputShape);
+ pooling_res_[dnnResourceDiffSrc] =
+ static_cast<void*>(const_cast<float*>(output_tensor->flat<T>().data()));
+
+ int64 output_size = output_tensor->NumElements();
+ for (int64 i = 0; i < output_size; ++i) {
+ (static_cast<float*>(pooling_res_[dnnResourceDiffSrc]))[i] = 0;
+ }
+
+ CHECK_EQ(dnnExecute_F32(pooling_bwd_, pooling_res_), E_SUCCESS);
+
+ MklCleanup();
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ MklShape mkl_input_shape;
+
+ bool workspace_enabled_;
+ bool input_in_mkl_format_;
+ bool outbackprop_in_mkl_format_;
+
+ void* input_buf_;
+ void* outbackprop_buf_;
+ void* pooling_res_fwd_[dnnResourceNumber]; // Pooling resource array for fwd
+ void* pooling_res_[dnnResourceNumber]; // Pooling resource array
+
+ dnnPrimitive_t pooling_fwd_;
+ dnnPrimitive_t pooling_bwd_;
+ dnnPrimitive_t convert_input_;
+ dnnPrimitive_t convert_outbackprop_;
+
+ dnnLayout_t lt_outbackprop_user_;
+ dnnLayout_t lt_outbackprop_prim_;
+ dnnLayout_t lt_input_user_;
+ dnnLayout_t lt_input_prim_;
+
+ MklPoolingOpParams mkl_params_;
+
+ void MklCreateLayouts(OpKernelContext* context) {
+ // Create DNN user layout for input and outbackprop or get existing layout
+ if (input_in_mkl_format_ == false) {
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_input_user_,
+ mkl_params_.in_dim,
+ mkl_params_.in_sizes,
+ mkl_params_.in_strides),
+ E_SUCCESS);
+ } else {
+ lt_input_user_ = (dnnLayout_t)mkl_input_shape.GetCurLayout();
+ }
+
+ MklShape mkl_output_backprop_shape;
+ GetMklShape(context, 2, &mkl_output_backprop_shape);
+
+ // We don't care about the output layout for now, as we can create it
+ // from the primitives for the max pooling fwd prop.
+ if (outbackprop_in_mkl_format_ == false) {
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop_user_,
+ mkl_params_.in_dim,
+ mkl_params_.out_sizes,
+ mkl_params_.out_strides),
+ E_SUCCESS);
+ } else {
+ lt_outbackprop_user_ =
+ (dnnLayout_t)mkl_output_backprop_shape.GetCurLayout();
+ }
+ }
+
+ // Create DNN primitives
+ void MklCreatePrimitives(OpKernelContext* context) {
+ dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
+ dnnPrimitiveAttributes_t primAttr = nullptr;
+
+ if (workspace_enabled_ == false) {
+ CHECK_EQ(dnnPoolingCreateForward_F32(&pooling_fwd_,
+ primAttr,
+ algorithm,
+ lt_input_user_,
+ mkl_params_.kernel_size,
+ mkl_params_.kernel_stride,
+ mkl_params_.in_offset,
+ dnnBorderZerosAsymm),
+ E_SUCCESS);
+ }
+
+ CHECK_EQ(dnnPoolingCreateBackward_F32(&pooling_bwd_,
+ primAttr,
+ algorithm,
+ lt_input_user_,
+ mkl_params_.kernel_size,
+ mkl_params_.kernel_stride,
+ mkl_params_.in_offset,
+ dnnBorderZerosAsymm),
+ E_SUCCESS);
+
+ // Creates conversions
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_outbackprop_prim_,
+ pooling_bwd_,
+ dnnResourceDiffDst),
+ E_SUCCESS);
+
+ // Tensors needed to create temporary buffers
+ Tensor input_buf_tensor, outbackprop_buf_tensor;
+
+ if (workspace_enabled_ == false) {
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_input_prim_,
+ pooling_fwd_,
+ dnnResourceSrc),
+ E_SUCCESS);
+ if (!dnnLayoutCompare_F32(lt_input_user_, lt_input_prim_)) {
+ CHECK_EQ(dnnConversionCreate_F32(&convert_input_,
+ lt_input_user_,
+ lt_input_prim_),
+ E_SUCCESS);
+ AllocTmpBuffer(context,
+ &input_buf_tensor,
+ lt_input_prim_,
+ &input_buf_);
+ }
+ }
+
+ if (!dnnLayoutCompare_F32(lt_outbackprop_user_, lt_outbackprop_prim_)) {
+ CHECK_EQ(dnnConversionCreate_F32(&convert_outbackprop_,
+ lt_outbackprop_user_,
+ lt_outbackprop_prim_),
+ E_SUCCESS);
+ AllocTmpBuffer(context,
+ &outbackprop_buf_tensor,
+ lt_outbackprop_prim_,
+ &outbackprop_buf_);
+ }
+ }
+
+ // Compare incoming tensor layouts with MKL preferred layouts and convert
+ // data to the preferred layout if necessary
+ void MklPrepareInputs(OpKernelContext* context) {
+ // Input - The original input tensor
+ const Tensor& tensor_in = MklGetInput(context, 0);
+ // Backprop tensor for output
+ const Tensor& out_backprop = MklGetInput(context, 2);
+
+ MklShape mkl_input_shape;
+ GetMklShape(context, 0, &mkl_input_shape);
+
+ void* tmp_output_buf;
+ Tensor tmp_output_buf_tensor;
+
+ void* workspace_buf;
+ Tensor workspace_buf_tensor;
+
+ if (workspace_enabled_ == false) {
+ if (convert_input_ != nullptr) {
+ if (input_in_mkl_format_ == false) {
+ CHECK_EQ(
+ dnnConversionExecute_F32(
+ convert_input_,
+ const_cast<void*>(
+ static_cast<const void*>(tensor_in.flat<T>().data())),
+ input_buf_),
+ E_SUCCESS);
+ CHECK_EQ(dnnDelete_F32(convert_input_), E_SUCCESS);
+ convert_input_ = nullptr;
+ } else {
+ mkl_input_shape.GetConvertedFlatData(
+ lt_input_prim_,
+ const_cast<void*>(
+ static_cast<const void*>(tensor_in.flat<T>().data())),
+ input_buf_);
+ }
+ pooling_res_fwd_[dnnResourceSrc] = input_buf_;
+ input_buf_ = nullptr;
+ } else {
+ pooling_res_fwd_[dnnResourceSrc] =
+ const_cast<void*>(
+ static_cast<const void*>(tensor_in.flat<T>().data()));
+ }
+
+ dnnLayout_t lt_workspace;
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace,
+ pooling_fwd_,
+ dnnResourceWorkspace),
+ E_SUCCESS);
+ AllocTmpBuffer(context,
+ &workspace_buf_tensor,
+ lt_workspace, &workspace_buf);
+ pooling_res_fwd_[dnnResourceWorkspace] = workspace_buf;
+
+ dnnLayoutDelete_F32(lt_workspace);
+
+ // We create the layout for max pooling fwd prop tmp output here
+ AllocTmpBuffer(context, &tmp_output_buf_tensor,
+ lt_outbackprop_prim_, &tmp_output_buf);
+ pooling_res_fwd_[dnnResourceDst] = tmp_output_buf;
+
+ CHECK_EQ(dnnExecute_F32(pooling_fwd_, pooling_res_fwd_), E_SUCCESS);
+ pooling_res_[dnnResourceWorkspace] =
+ pooling_res_fwd_[dnnResourceWorkspace];
+ } else {
+ const Tensor& workspace = MklGetInput(context, 3);
+ pooling_res_[dnnResourceWorkspace] = const_cast<void*>(
+ static_cast<const void*>(workspace.flat<T>().data()));
+ }
+
+ // Out backprop conversions if needed
+ if (convert_outbackprop_ != nullptr) {
+ if (outbackprop_in_mkl_format_ == false) {
+ CHECK_EQ(dnnConversionExecute_F32(
+ convert_outbackprop_,
+ const_cast<void*>(
+ static_cast<const void*>(out_backprop.flat<T>().data())),
+ outbackprop_buf_),
+ E_SUCCESS);
+ CHECK_EQ(dnnDelete_F32(convert_outbackprop_), E_SUCCESS);
+ convert_outbackprop_ = nullptr;
+ } else {
+ MklShape mkl_output_backprop_shape;
+ GetMklShape(context, 2, &mkl_output_backprop_shape);
+ mkl_output_backprop_shape.GetConvertedFlatData(
+ lt_outbackprop_prim_,
+ const_cast<void*>(
+ static_cast<const void*>(out_backprop.flat<T>().data())),
+ outbackprop_buf_);
+ }
+ pooling_res_[dnnResourceDiffDst] = outbackprop_buf_;
+ outbackprop_buf_ = nullptr;
+ } else {
+ pooling_res_[dnnResourceDiffDst] =
+ const_cast<void*>(
+ static_cast<const void*>(out_backprop.flat<T>().data()));
+ }
+ }
+
+ void MklCleanup() {
+ if (workspace_enabled_ == false) {
+ CHECK_EQ(dnnDelete_F32(pooling_fwd_), E_SUCCESS);
+ pooling_fwd_ = nullptr;
+ }
+
+ CHECK_EQ(dnnDelete_F32(pooling_bwd_), E_SUCCESS);
+ pooling_bwd_ = nullptr;
+
+ if (outbackprop_in_mkl_format_ == false) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user_), E_SUCCESS);
+ lt_outbackprop_user_ = nullptr;
+ }
+
+ CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim_), E_SUCCESS);
+ lt_outbackprop_prim_ = nullptr;
+
+ if (input_in_mkl_format_ == false) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_input_user_), E_SUCCESS);
+ lt_input_user_ = nullptr;
+ }
+
+ if (workspace_enabled_ == false) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim_), E_SUCCESS);
+ lt_input_prim_ = nullptr;
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MklMaxPool").Device(DEVICE_CPU).TypeConstraint<float>("T")
+ .Label(mkl_layer_registry::kMklLayerLabel),
+ MklMaxPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("MklMaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T")
+ .Label(mkl_layer_registry::kMklLayerLabel),
+ MklMaxPoolingGradOp<CPUDevice, float>);
+
+} // namespace tensorflow
+#endif // INTEL_MKL
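
MklMaxPoolingGradOp above takes two paths: with workspace_enabled it reads the workspace produced by the forward op from input 3, and without it it re-runs the forward pooling just to regenerate that workspace. A toy, standalone illustration of why carrying the argmax "workspace" forward saves that recomputation (indices stand in for the MKL workspace; none of this is the MKL API):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Forward 1-D max pooling (window 2, stride 2) that also records, per output,
    // the input index of the maximum -- the "workspace" a MaxPoolGrad needs.
    std::vector<float> MaxPoolForward(const std::vector<float>& in,
                                      std::vector<std::size_t>* workspace) {
      std::vector<float> out;
      for (std::size_t i = 0; i + 1 < in.size(); i += 2) {
        std::size_t arg = (in[i] >= in[i + 1]) ? i : i + 1;
        workspace->push_back(arg);
        out.push_back(in[arg]);
      }
      return out;
    }

    // Backward pass: with the workspace, gradients are scattered directly to the
    // winning inputs; without it, the forward pass would have to be recomputed.
    std::vector<float> MaxPoolBackward(std::size_t input_size,
                                       const std::vector<float>& grad_out,
                                       const std::vector<std::size_t>& workspace) {
      std::vector<float> grad_in(input_size, 0.0f);
      for (std::size_t o = 0; o < grad_out.size(); ++o) {
        grad_in[workspace[o]] += grad_out[o];
      }
      return grad_in;
    }

    int main() {
      std::vector<float> x = {1, 5, 2, 3, 9, 4};
      std::vector<std::size_t> ws;
      std::vector<float> y = MaxPoolForward(x, &ws);
      std::vector<float> dx = MaxPoolBackward(x.size(), {1, 1, 1}, ws);
      for (float g : dx) std::cout << g << " ";  // prints: 0 1 0 1 1 0
      std::cout << "\n";
      return 0;
    }
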
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
new file mode 100644
index 0000000000..3eb472d7e3
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -0,0 +1,166 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+#include <vector>
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
+#include "tensorflow/core/common_runtime/device.h"
+
+namespace tensorflow {
+
+ // Initialization for TensorFlow format
+ void MklPoolParameters::Init(OpKernelContext* context,
+ const std::vector<int32>& ksize,
+ const std::vector<int32>& stride,
+ Padding padding,
+ TensorFormat data_format,
+ const TensorShape& tensor_in_shape) {
+ // For maxpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(context, tensor_in_shape.dims() == 4,
+ errors::InvalidArgument("tensor_in must be 4-dimensional"));
+
+ depth = GetTensorDim(tensor_in_shape, data_format, 'C');
+ tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
+ tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
+ tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
+
+ Init(context, ksize, stride, padding, data_format);
+ }
+
+ // Initialization for MKL format
+ void MklPoolParameters::Init(OpKernelContext* context,
+ const std::vector<int32>& ksize,
+ const std::vector<int32>& stride,
+ Padding padding,
+ TensorFormat data_format,
+ const MklShape* mklInputShape) {
+ // Get the input sizes
+ depth = mklInputShape->GetSizes()[2];
+ tensor_in_cols = mklInputShape->GetSizes()[0];
+ tensor_in_rows = mklInputShape->GetSizes()[1];
+ tensor_in_batch = mklInputShape->GetSizes()[3];
+
+ Init(context, ksize, stride, padding, data_format);
+ }
+
+ // Common Initialization for TensorFlow and MKL formats
+ void MklPoolParameters::Init(OpKernelContext* context,
+ const std::vector<int32>& ksize,
+ const std::vector<int32>& stride,
+ Padding padding,
+ TensorFormat data_format) {
+ // Get the data format
+ this->data_format = data_format;
+
+ // Get the output sizes
+ window_rows = GetTensorDim(ksize, data_format, 'H');
+ window_cols = GetTensorDim(ksize, data_format, 'W');
+ depth_window = GetTensorDim(ksize, data_format, 'C');
+
+ // Get the strides
+ row_stride = GetTensorDim(stride, data_format, 'H');
+ col_stride = GetTensorDim(stride, data_format, 'W');
+ depth_stride = GetTensorDim(stride, data_format, 'C');
+
+ // We only support 2D pooling across width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
+ errors::Unimplemented(
+ "MaxPooling supports exactly one of pooling across depth "
+ "or pooling across width/height."));
+
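+    // For reference: GetWindowedOutputSizeVerbose computes, per spatial
+    // dimension,
+    //   VALID: out = (in - window + stride) / stride, with zero padding;
+    //   SAME:  out = (in + stride - 1) / stride, with the required padding
+    //          split between the top/left and bottom/right sides.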
+ if (depth_window == 1) {
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSizeVerbose(tensor_in_rows,
+ window_rows,
+ row_stride,
+ padding,
+ &out_height,
+ &pad_top,
+ &pad_bottom));
+
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSizeVerbose(tensor_in_cols,
+ window_cols,
+ col_stride,
+ padding,
+ &out_width,
+ &pad_left,
+ &pad_right));
+ } else {
+ // Our current version of depthwise max pooling does not support
+ // any padding, and expects the depth_window to equal the depth
+ // stride (no overlapping).
+ OP_REQUIRES(context, depth % depth_window == 0,
+ errors::Unimplemented("Depthwise max pooling requires the"
+ " depth window to evenly divide the"
+ " input depth"));
+ OP_REQUIRES(context, depth_stride == depth_window,
+ errors::Unimplemented("Depthwise max pooling requires the"
+ " depth window to equal the depth"
+ " stride"));
+
+ // The current version of depthwise max is only implemented on CPU.
+ OP_REQUIRES(context,
+ (DeviceType(static_cast<Device*>(context->device())
+ ->attributes()
+ .device_type()) == DeviceType(DEVICE_CPU)),
+ errors::Unimplemented("Depthwise max pooling is currently "
+ "only implemented for CPU devices."));
+
+ pad_depth = 0;
+ out_depth = depth / depth_window;
+ }
+ }
+
+  // Copies the pooling parameters computed in MklPoolParameters into the
+  // MKL op parameter struct (sizes, strides, padding offsets, and kernel
+  // geometry).
+ void ExtractMklOpParams(OpKernelContext* context,
+ TensorFormat data_format,
+ const MklPoolParameters &params,
+ MklPoolingOpParams *mkl_params) {
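+  // The MKL sizes arrays filled below are ordered {W, H, C, N}, i.e. the
+  // innermost (width) dimension comes first.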
+ mkl_params->in_sizes[0] = params.tensor_in_cols;
+ mkl_params->in_sizes[1] = params.tensor_in_rows;
+ mkl_params->in_sizes[2] = params.depth;
+ mkl_params->in_sizes[3] = params.tensor_in_batch;
+
+ GetStridesFromSizes(data_format,
+ mkl_params->in_strides,
+ mkl_params->in_sizes);
+
+ mkl_params->out_sizes[0] = params.out_width;
+ mkl_params->out_sizes[1] = params.out_height;
+ mkl_params->out_sizes[2] = params.depth;
+ mkl_params->out_sizes[3] = params.tensor_in_batch;
+
+ GetStridesFromSizes(data_format,
+ mkl_params->out_strides,
+ mkl_params->out_sizes);
+
+ mkl_params->in_offset[0] = -params.pad_left;
+ mkl_params->in_offset[1] = -params.pad_top;
+ mkl_params->in_offset[2] = -params.pad_right;
+ mkl_params->in_offset[3] = -params.pad_bottom;
+
+ mkl_params->kernel_stride[0] = params.col_stride;
+ mkl_params->kernel_stride[1] = params.row_stride;
+
+ mkl_params->kernel_size[0] = params.window_cols;
+ mkl_params->kernel_size[1] = params.window_rows;
+ }
+} // namespace tensorflow
+#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
new file mode 100644
index 0000000000..0a7c4dd15e
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
+
+#ifdef INTEL_MKL
+#include <vector>
+#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/util/padding.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+struct MklPoolParameters {
+ int depth;
+
+ int tensor_in_cols;
+ int tensor_in_rows;
+ int tensor_in_batch;
+
+ int window_rows;
+ int window_cols;
+ int depth_window;
+
+ int row_stride;
+ int col_stride;
+ int depth_stride;
+
+ int64 out_height;
+ int64 out_width;
+ int out_depth;
+
+ int64 pad_left;
+ int64 pad_right;
+ int64 pad_top;
+ int64 pad_bottom;
+ int pad_depth;
+
+ TensorFormat data_format;
+
+ // Updates context->status if there is an invalid input.
+ void Init(OpKernelContext* context, const std::vector<int32>& ksize,
+ const std::vector<int32>& stride, Padding padding,
+ TensorFormat data_format, const TensorShape& tensor_in_shape);
+ void Init(OpKernelContext* context, const std::vector<int32>& ksize,
+ const std::vector<int32>& stride, Padding padding,
+ TensorFormat data_format, const MklShape* mkl_in_shape);
+
+ private:
+ // Common initialization for TensorFlow and MKL formats
+ void Init(OpKernelContext* context, const std::vector<int32>& ksize,
+ const std::vector<int32>& stride, Padding padding,
+ TensorFormat data_format);
+};
+
+//-------------------------------------------------------------------
+// Utility functions
+
+typedef struct {
+ size_t in_dim;
+ size_t in_sizes[4];
+ size_t in_strides[4];
+ size_t out_sizes[4];
+ size_t out_strides[4];
+ int in_offset[4];
+ size_t kernel_stride[2];
+ size_t kernel_size[2];
+} MklPoolingOpParams;
+
+// Copies the pooling parameters computed in MklPoolParameters into the MKL op
+// parameter struct (sizes, strides, padding offsets, and kernel geometry).
+void ExtractMklOpParams(OpKernelContext* context,
+ TensorFormat data_format,
+ const MklPoolParameters &params,
+ MklPoolingOpParams *mkl_params);
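+
+// Typical usage (sketch): an MKL pooling kernel fills MklPoolParameters via
+// one of its Init() overloads and then calls ExtractMklOpParams() to obtain
+// the sizes, strides, offsets and kernel geometry handed to the MKL DNN
+// pooling primitives (e.g. dnnPoolingCreateForward_F32).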
+} // namespace tensorflow
+
+#endif // INTEL_MKL
+#endif // TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
new file mode 100644
index 0000000000..b70064a24a
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -0,0 +1,387 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+#ifdef INTEL_MKL
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "tensorflow/core/platform/default/logging.h"
+#include "tensorflow/core/util/mkl_util.h"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+struct MklReluHelpers {
+ static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
+ const Tensor& a) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ }
+ static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
+ const Tensor& a) {
+ ValidateSameSizeHelper(context, g, a);
+ return context->status().ok();
+ }
+};
+
+template <typename Device, typename T>
+class MklReluOp : public OpKernel {
+ public:
+ ~MklReluOp() {}
+
+ explicit MklReluOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = MklGetInput(context, 0);
+ GetMklShape(context, 0, &mkl_params.input_shape);
+ void* user_i = static_cast<void*>(const_cast<T*>(input.flat<T>().data()));
+ bool input_in_mkl_format = mkl_params.input_shape.IsMklTensor();
+ if (!input_in_mkl_format && !input.dims()) { // handle the case of a scalar
+ const TensorShape& o_shape = input.shape();
+ Tensor* out_tensor = nullptr;
+ mkl_params.output_shape.SetMklTensor(false);
+ AllocateOutputSetMklshape(context, 0, &out_tensor, o_shape,
+ mkl_params.output_shape);
+ void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
+ (static_cast<T*>(out_o))[0] =
+ std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
+ return;
+ }
+
+ // Generate size, stride for input if input is in MKL format.
+ if (input_in_mkl_format) {
+ mkl_params.in_dims = mkl_params.input_shape.GetDimension();
+ mkl_params.in_sizes = new size_t[mkl_params.in_dims];
+ mkl_params.in_strides = new size_t[mkl_params.in_dims];
+ for (int i = 0; i < mkl_params.in_dims; i++) {
+ mkl_params.in_sizes[i] = mkl_params.input_shape.GetSizes()[i];
+ mkl_params.in_strides[i] = mkl_params.input_shape.GetStrides()[i];
+ }
+ } else {
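+      // TF shapes list the outermost dimension first while MKL expects the
+      // innermost dimension first, so the dimensions are reversed here and
+      // contiguous strides are derived from them.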
+ mkl_params.in_dims = input.dims();
+ mkl_params.in_sizes = new size_t[mkl_params.in_dims];
+ mkl_params.in_strides = new size_t[mkl_params.in_dims];
+ for (int i = 0; i < mkl_params.in_dims; i++) {
+ mkl_params.in_sizes[i] = input.dim_size((mkl_params.in_dims - 1) - i);
+ }
+ mkl_params.in_strides[0] = 1;
+ for (int i = 1; i < mkl_params.in_dims; i++) {
+ mkl_params.in_strides[i] =
+ mkl_params.in_strides[i - 1] * mkl_params.in_sizes[i - 1];
+ }
+ }
+
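+    // A negative slope of 0 makes the MKL (leaky) ReLU primitive compute the
+    // standard ReLU, max(x, 0).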
+ float negative_slope = 0.0;
+ MklCreateInputLayouts(context);
+ CHECK_EQ(dnnReLUCreateForward_F32(&mkl_prim_relu_fwd_, NULL, mkl_lt_input_,
+ negative_slope),
+ E_SUCCESS);
+
+ Tensor* output = nullptr;
+
+ if (input_in_mkl_format) {
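+      // Keep the output in MKL layout: allocate a flat TF tensor large enough
+      // to hold the MKL buffer and record the layout in output_shape so that
+      // downstream MKL-aware ops can interpret it.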
+ TensorShape tf_shape;
+ mkl_params.output_shape.SetMklTensor(true);
+ mkl_params.output_shape.SetMklLayout(mkl_prim_relu_fwd_, dnnResourceDst);
+ mkl_params.output_shape.SetTfLayout(
+ mkl_params.in_dims, mkl_params.in_sizes, mkl_params.in_strides);
+ tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_params.output_shape.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklshape(context, 0, &output, tf_shape,
+ mkl_params.output_shape);
+ } else {
+ const TensorShape& o_shape = input.shape();
+ mkl_params.output_shape.SetMklTensor(false);
+ AllocateOutputSetMklshape(context, 0, &output, o_shape,
+ mkl_params.output_shape);
+ }
+
+ void* user_o = static_cast<void*>(const_cast<T*>(output->flat<T>().data()));
+
+ relu_res[dnnResourceDst] = user_o;
+ relu_res[dnnResourceSrc] = user_i;
+ CHECK_EQ(dnnExecute_F32(mkl_prim_relu_fwd_, relu_res), E_SUCCESS);
+    MklCleanup();
+ }
+
+ private:
+ typedef struct {
+ int in_dims;
+ size_t* in_sizes;
+ size_t* in_strides;
+ MklShape input_shape, output_shape;
+ } MklReluOpParams_;
+
+  void MklCleanup() {
+ bool input_in_mkl_format = mkl_params.input_shape.IsMklTensor();
+ if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input_);
+ dnnDelete_F32(mkl_prim_relu_fwd_);
+ }
+ void MklCreateInputLayouts(OpKernelContext* context) {
+ bool input_in_mkl_format = mkl_params.input_shape.IsMklTensor();
+ if (!input_in_mkl_format) {
+ CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_input_, mkl_params.in_dims,
+ mkl_params.in_sizes, mkl_params.in_strides),
+ E_SUCCESS);
+ } else {
+ mkl_lt_input_ =
+ static_cast<dnnLayout_t>(mkl_params.input_shape.GetCurLayout());
+ }
+ }
+
+ dnnPrimitive_t mkl_prim_relu_fwd_ = nullptr;
+ MklReluOpParams_ mkl_params;
+ void* relu_res[dnnResourceNumber];
+ dnnLayout_t mkl_lt_input_ = nullptr;
+};
+
+template <typename Device, typename T>
+class MklReluGradOp : public OpKernel {
+ public:
+ ~MklReluGradOp() {}
+
+ explicit MklReluGradOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override;
+
+ private:
+ typedef struct {
+ int in_dims;
+ size_t* in_sizes;
+ size_t* in_strides;
+ MklShape input_shape, grad_shape, output_shape;
+ } MklReluGradOpParams_;
+ MklReluGradOpParams_ mkl_params;
+
+ void MklPrepareReluGradInputs(OpKernelContext* context,
+ Tensor* mkl_tmp_grad_buf_tensor,
+ Tensor* mkl_tmp_input_buf_tensor) {
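+    // The backward ReLU primitive may expect its operands in an MKL-internal
+    // layout. When the user-provided layouts differ, temporary buffers are
+    // allocated and conversion primitives are created and executed below;
+    // otherwise the user buffers are passed through unchanged.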
+ dnnPrimitive_t cv_user_to_reluB_input = nullptr,
+ cv_user_to_reluB_grad = nullptr;
+ dnnLayout_t mkl_lt_internal_input = nullptr, mkl_lt_internal_grad = nullptr;
+
+ const Tensor& g = MklGetInput(context, 0);
+ const Tensor& a = MklGetInput(context, 1);
+
+ void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
+ void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
+
+ CHECK_EQ(
+ dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_internal_grad, mkl_prim_relu_back_, dnnResourceDiffDst),
+ E_SUCCESS);
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_internal_input, mkl_prim_relu_back_, dnnResourceSrc),
+ E_SUCCESS);
+
+ if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, mkl_lt_grad_)) {
+ AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad,
+ &relu_res[dnnResourceDiffDst]);
+ CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, mkl_lt_grad_,
+ mkl_lt_internal_grad),
+ E_SUCCESS);
+ }
+
+ if (!dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input_)) {
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
+ &relu_res[dnnResourceSrc]);
+ CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, mkl_lt_input_,
+ mkl_lt_internal_input),
+ E_SUCCESS);
+ }
+ if (cv_user_to_reluB_input) {
+ CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i,
+ relu_res[dnnResourceSrc]),
+ E_SUCCESS);
+ } else {
+ relu_res[dnnResourceSrc] = user_i;
+ }
+ if (cv_user_to_reluB_input) dnnDelete_F32(cv_user_to_reluB_input);
+
+ dnnLayoutDelete_F32(mkl_lt_internal_input);
+ if (cv_user_to_reluB_grad) {
+ CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g,
+ relu_res[dnnResourceDiffDst]),
+ E_SUCCESS);
+ } else {
+ relu_res[dnnResourceDiffDst] = user_g;
+ }
+
+ if (cv_user_to_reluB_grad) dnnDelete_F32(cv_user_to_reluB_grad);
+ dnnLayoutDelete_F32(mkl_lt_internal_grad);
+ }
+
+ void MklCreateInputLayouts(OpKernelContext* context) {
+ bool grad_is_mkl = mkl_params.grad_shape.IsMklTensor();
+ bool input_is_mkl = mkl_params.input_shape.IsMklTensor();
+ if (!input_is_mkl) {
+ CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_input_, mkl_params.in_dims,
+ mkl_params.in_sizes, mkl_params.in_strides),
+ E_SUCCESS);
+ } else {
+ mkl_lt_input_ =
+ static_cast<dnnLayout_t>(mkl_params.input_shape.GetCurLayout());
+ }
+
+ if (!grad_is_mkl) {
+ CHECK_EQ(dnnLayoutCreate_F32(&mkl_lt_grad_, mkl_params.in_dims,
+ mkl_params.in_sizes, mkl_params.in_strides),
+ E_SUCCESS);
+ } else {
+ mkl_lt_grad_ =
+ static_cast<dnnLayout_t>(mkl_params.grad_shape.GetCurLayout());
+ }
+ }
+
+ void MklCleanup() {
+ bool grad_is_mkl = mkl_params.grad_shape.IsMklTensor();
+ bool input_is_mkl = mkl_params.input_shape.IsMklTensor();
+ dnnDelete_F32(mkl_prim_relu_back_);
+ if (!input_is_mkl) {
+ dnnLayoutDelete_F32(mkl_lt_input_);
+ }
+ if (!grad_is_mkl) {
+ dnnLayoutDelete_F32(mkl_lt_grad_);
+ }
+ }
+ void* relu_res[dnnResourceNumber];
+ dnnPrimitive_t mkl_prim_relu_back_ = nullptr;
+  dnnLayout_t mkl_lt_input_ = nullptr, mkl_lt_grad_ = nullptr;
+};
+
+template <typename Device, typename T>
+void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
+ const Tensor& g = MklGetInput(context, 0);
+ const Tensor& a = MklGetInput(context, 1);
+
+ void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
+ void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
+
+ GetMklShape(context, 0, &mkl_params.grad_shape);
+ GetMklShape(context, 1, &mkl_params.input_shape);
+
+ bool grad_is_mkl = mkl_params.grad_shape.IsMklTensor();
+ bool input_is_mkl = mkl_params.input_shape.IsMklTensor();
+ if (!input_is_mkl && !grad_is_mkl &&
+ !MklReluHelpers::ValidateSameSize(context, g, a))
+ return;
+ Tensor* output = nullptr;
+ if (!input_is_mkl && !grad_is_mkl &&
+ !a.dims()) { // handle the case of a scalar
+    // Allocate an output with the same shape as g and compute the scalar
+    // gradient directly.
+ const TensorShape& g_shape = g.shape();
+ mkl_params.output_shape.SetMklTensor(false);
+ AllocateOutputSetMklshape(context, 0, &output, g_shape,
+ mkl_params.output_shape);
+ void* out_o = static_cast<void*>(output->flat<T>().data());
+ (static_cast<T*>(out_o))[0] =
+ (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
+ return;
+ }
+
+ // Generate size, stride for input if input/grad is in MKL format.
+ if (grad_is_mkl || input_is_mkl) {
+ const MklShape* tmp_mkl_shape =
+ (grad_is_mkl) ? &mkl_params.grad_shape : &mkl_params.input_shape;
+
+ mkl_params.in_dims = tmp_mkl_shape->GetDimension();
+ mkl_params.in_strides = new size_t[mkl_params.in_dims];
+ mkl_params.in_sizes = new size_t[mkl_params.in_dims];
+ for (int i = 0; i < mkl_params.in_dims; i++) {
+ mkl_params.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
+ mkl_params.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
+ }
+ } else {
+ mkl_params.in_dims = g.dims();
+ mkl_params.in_strides = new size_t[mkl_params.in_dims];
+ mkl_params.in_sizes = new size_t[mkl_params.in_dims];
+
+ for (int i = 0; i < mkl_params.in_dims; i++) {
+ mkl_params.in_sizes[i] = g.dim_size((mkl_params.in_dims - 1) - i);
+ }
+ mkl_params.in_strides[0] = 1;
+ for (int i = 1; i < mkl_params.in_dims; i++) {
+ mkl_params.in_strides[i] =
+ mkl_params.in_strides[i - 1] * mkl_params.in_sizes[i - 1];
+ }
+ }
+
+ MklCreateInputLayouts(context);
+ float negative_slope = 0.0;
+ CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_prim_relu_back_, NULL, mkl_lt_grad_,
+ mkl_lt_input_, negative_slope),
+ E_SUCCESS);
+ Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
+ MklPrepareReluGradInputs(context, &mkl_tmp_grad_buf_tensor,
+ &mkl_tmp_input_buf_tensor);
+
+ if (input_is_mkl ||
+ grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/
+ TensorShape tf_shape;
+ mkl_params.output_shape.SetMklTensor(true);
+ mkl_params.output_shape.SetMklLayout(mkl_prim_relu_back_,
+ dnnResourceDiffSrc);
+ mkl_params.output_shape.SetTfLayout(mkl_params.in_dims, mkl_params.in_sizes,
+ mkl_params.in_strides);
+ tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_params.output_shape.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklshape(context, 0, &output, tf_shape,
+ mkl_params.output_shape);
+
+ } else {
+ const TensorShape& o_shape = g.shape();
+ mkl_params.output_shape.SetMklTensor(false);
+ AllocateOutputSetMklshape(context, 0, &output, o_shape,
+ mkl_params.output_shape);
+ }
+
+ relu_res[dnnResourceDiffSrc] = static_cast<void*>(output->flat<T>().data());
+
+ CHECK_EQ(dnnExecute_F32(mkl_prim_relu_back_, relu_res), E_SUCCESS);
+ MklCleanup();
+}
+
+/* Register DNN kernels for the supported operations and types. Currently only
+ * Relu with float32 is supported. */
+#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
+ REGISTER_KERNEL_BUILDER(Name("MklRelu") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .Label(mkl_layer_registry::kMklLayerLabel), \
+ MklReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("MklReluGrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .Label(mkl_layer_registry::kMklLayerLabel), \
+ MklReluGradOp<CPUDevice, type>);
+TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
+
+} // namespace tensorflow
+
+#endif // INTEL_MKL
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 242b6b789e..3f87fb4ab4 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -144,6 +144,74 @@ output: 4-D. Gradients w.r.t. the input of `avg_pool`.
// --------------------------------------------------------------------------
+#ifdef INTEL_MKL
+
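+// Note: each MKL op below pairs every data tensor with an extra uint8 tensor
+// (mkl_input, mkl_output, ...) carrying the serialized MklShape metadata that
+// describes the MKL layout of its companion tensor.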
+REGISTER_OP("MklAvgPool")
+ .Input("value: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn(shape_inference::AvgPoolShape)
+ .Doc(R"doc(
+MKL version of AvgPool
+Performs average pooling on the input.
+
+Each entry in `output` is the mean of the corresponding size `ksize`
+window in `value`.
+
+value: 4-D with shape `[batch, height, width, channels]`.
+ksize: The size of the sliding window for each dimension of `value`.
+strides: The stride of the sliding window for each dimension of `value`.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+output: The average pooled output tensor.
+)doc");
+
+REGISTER_OP("MklAvgPoolGrad")
+ .Input("orig_input_shape: int32")
+ .Input("mkl_orig_input: uint8")
+ .Input("grad: T")
+ .Input("mkl_grad: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ })
+ .Doc(R"doc(
+MKL version of AvgPoolGrad
+Computes gradients of the average pooling function.
+
+orig_input_shape: 1-D. Shape of the original input to `avg_pool`.
+grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t.
+ the output of `avg_pool`.
+ksize: The size of the sliding window for each dimension of the input.
+strides: The stride of the sliding window for each dimension of the input.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+output: 4-D. Gradients w.r.t. the input of `avg_pool`.
+)doc");
+
+#endif // INTEL_MKL
+// --------------------------------------------------------------------------
+
REGISTER_OP("BatchNormWithGlobalNormalization")
.Input("t: T")
.Input("m: T")
@@ -1327,6 +1395,39 @@ input: 4-D input to pool over.
output: The max pooled output tensor.
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("MklMaxPool")
+ .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr("workspace_enabled: bool = false")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Output("workspace: T")
+ .Output("mkl_workspace: uint8")
+ .SetShapeFn(shape_inference::MaxPoolShape)
+ .Doc(R"doc(
+MKL version of MaxPool
+Performs max pooling on the input.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+input: 4-D input to pool over.
+output: The max pooled output tensor.
+)doc");
+#endif  // INTEL_MKL
+
REGISTER_OP("MaxPoolGrad")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
@@ -1358,6 +1459,47 @@ grad: 4-D. Gradients w.r.t. the output of `max_pool`.
output: Gradients w.r.t. the input to `max_pool`.
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("MklMaxPoolGrad")
+ .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr("workspace_enabled: bool = false")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Input("orig_input: T")
+ .Input("mkl_orig_input: uint8")
+ .Input("orig_output: T")
+ .Input("mkl_orig_output: uint8")
+ .Input("grad: T")
+ .Input("mkl_grad: uint8")
+ .Input("workspace: T")
+ .Input("mkl_workspace: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 4);
+ })
+ .Doc(R"doc(
+MKL version of MaxPoolGrad
+Computes gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients w.r.t. the output of `max_pool`.
+output: Gradients w.r.t. the input to `max_pool`.
+)doc");
+#endif  // INTEL_MKL
+
REGISTER_OP("MaxPoolWithArgmax")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
@@ -1597,6 +1739,19 @@ REGISTER_OP("Relu")
Computes rectified linear: `max(features, 0)`.
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("MklRelu")
+ .Input("features: T")
+ .Input("mkl_features: uint8")
+ .Output("activations: T")
+ .Output("mkl_activations: uint8")
+ .Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+MKL version of Relu
+Computes rectified linear: `max(features, 0)`.
+)doc");
+#endif  // INTEL_MKL
+
REGISTER_OP("ReluGrad")
.Input("gradients: T")
.Input("features: T")
@@ -1612,6 +1767,26 @@ features: The features passed as input to the corresponding Relu operation, OR
backprops: `gradients * (features > 0)`.
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("MklReluGrad")
+ .Input("gradients: T")
+ .Input("mkl_gradients: uint8")
+ .Input("features: T")
+ .Input("mkl_features: uint8")
+ .Output("backprops: T")
+ .Output("mkl_backprops: uint8")
+ .Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
+ .Doc(R"doc(
+MKL version of ReluGrad
+Computes rectified linear gradients for a Relu operation.
+
+gradients: The backpropagated gradients to the corresponding Relu operation.
+features: The features passed as input to the corresponding Relu operation, OR
+ the outputs of that operation (both work equivalently).
+backprops: `gradients * (features > 0)`.
+)doc");
+#endif  // INTEL_MKL
+
REGISTER_OP("Relu6")
.Input("features: T")
.Output("activations: T")