aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/nn_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r--tensorflow/core/ops/nn_ops.cc41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index eee9961b28..e56b27b0c0 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -2502,4 +2502,45 @@ scale_after_normalization: A bool indicating whether the resulted tensor
needs to be multiplied with gamma.
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("MklConv2D")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Input("filter: T")
+ .Input("mkl_filter: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .SetShapeFn(shape_inference::Conv2DShape)
+ .Doc(R"doc(
+MKL version of Conv2D
+)doc");
+
+REGISTER_OP("MklConv2DWithBias")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Input("filter: T")
+ .Input("mkl_filter: uint8")
+ .Input("bias: T")
+ .Input("mkl_bias: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString());
+
+REGISTER_OP("MklToTf")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Attr("T: {half, float, double}")
+ .Attr(GetConvnetDataFormatAttrString());
+#endif // INTEL_MKL
+
} // namespace tensorflow