aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_tfconv_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_tfconv_op.h')
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.h80
1 files changed, 75 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index a240ee44fb..0a5be4fec9 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
-
#ifndef TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
#define TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
+#ifdef INTEL_MKL
+
#include <algorithm>
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
@@ -35,6 +35,10 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
+#ifdef INTEL_MKL_DNN
+using mkldnn::stream;
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -57,6 +61,71 @@ class MklToTfOp : public OpKernel {
VLOG(1) << "MKLToTFConversion complete successfully.";
}
+#ifdef INTEL_MKL_DNN
+ static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
+ string data_format_str, DataType op_data_type,
+ bool has_avx512f, uint input_number) {
+ try {
+ // Check that input tensor is in MKL format.
+ const Tensor& input_tensor = MklGetInput(context, input_number);
+ MklDnnShape input_shape;
+ GetMklShape(context, input_number, &input_shape);
+
+ // if input is already in Tf format, then copy input tensor to output.
+ if (!input_shape.IsMklTensor()) {
+ context->set_output(input_number, input_tensor);
+ VLOG(1) << "MKLToTFConversion: No conversion needed, "
+ << "copying input to output";
+ return;
+ }
+
+ // Check that input data type is same as operator data type and that it
+ // is same as output data type.
+ DataType input_data_type = op_kernel->input_type(input_number);
+ DataType output_data_type = op_kernel->output_type(input_number);
+ CHECK_EQ(op_data_type, input_data_type);
+ CHECK_EQ(op_data_type, output_data_type);
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> input(&cpu_engine);
+
+ // Get Mkl layout of input tensor.
+ auto input_mkl_md = input_shape.GetMklLayout();
+ // Get TensorFlow layout of input tensor. Expected output of conversion
+ // has same layout as Tensorflow layout of input tensor.
+ auto output_tf_md = input_shape.GetTfLayout();
+ auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
+ // Set input Mkl layout as the user layout.
+ input.SetUsrMem(input_mkl_md, &input_tensor);
+
+ // Allocate output tensor.
+ TensorShape output_shape = input_shape.GetTfShape();
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(context, context->allocate_output(input_number,
+ output_shape, &output_tensor));
+ CHECK_NOTNULL(output_tensor);
+
+ // Do we need to reorder Mkl layout into TensorFlow layout?
+ if (input.IsReorderNeeded(output_tf_pd)) {
+ // Insert reorder between Mkl layout and TensorFlow layout.
+ std::vector<primitive> net;
+ CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, output_tensor, &net),
+ true);
+ stream(stream::kind::eager).submit(net).wait();
+ } else {
+ // If not, just forward input tensor to output tensor.
+ CHECK(output_tensor->CopyFrom(input_tensor, output_shape));
+ }
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + std::string(e.message) +
+ ", in file " + std::string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+#else
static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
string data_format_str, DataType op_data_type,
bool has_avx512f, uint input_number) {
@@ -91,8 +160,8 @@ class MklToTfOp : public OpKernel {
// Allocate output tensor.
Tensor* output_tensor = NULL;
- OP_REQUIRES_OK(context,
- context->allocate_output(input_number, output_shape, &output_tensor));
+ OP_REQUIRES_OK(context, context->allocate_output(input_number,
+ output_shape, &output_tensor));
dnnLayout_t output_layout =
static_cast<dnnLayout_t>(input_shape.GetTfLayout());
@@ -106,6 +175,7 @@ class MklToTfOp : public OpKernel {
output_buffer);
VLOG(1) << "MKLToTFConversion complete successfully.";
}
+#endif
private:
/// Data format of the operation
@@ -132,5 +202,5 @@ class MklToTfOp : public OpKernel {
TF_CALL_NUMBER_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
#endif // INTEL_MKL
+#endif // TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_