diff options
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index eee9961b28..e56b27b0c0 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -2502,4 +2502,45 @@ scale_after_normalization: A bool indicating whether the resulted tensor needs to be multiplied with gamma. )doc"); +#ifdef INTEL_MKL +REGISTER_OP("MklConv2D") + .Input("input: T") + .Input("mkl_input: uint8") + .Input("filter: T") + .Input("mkl_filter: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .SetShapeFn(shape_inference::Conv2DShape) + .Doc(R"doc( +MKL version of Conv2D +)doc"); + +REGISTER_OP("MklConv2DWithBias") + .Input("input: T") + .Input("mkl_input: uint8") + .Input("filter: T") + .Input("mkl_filter: uint8") + .Input("bias: T") + .Input("mkl_bias: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()); + +REGISTER_OP("MklToTf") + .Input("input: T") + .Input("mkl_input: uint8") + .Output("output: T") + .Attr("T: {half, float, double}") + .Attr(GetConvnetDataFormatAttrString()); +#endif // INTEL_MKL + } // namespace tensorflow |