diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_input_conversion_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_input_conversion_op.cc | 52 |
1 files changed, 25 insertions, 27 deletions
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index d91f7107c5..68d3e1c9ab 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -263,21 +263,18 @@ class MklInputConversionOp : public OpKernel { private: void Compute(OpKernelContext* context) override { - const Tensor& input_tensor_0 = MklGetInput(context, 0); + const int kInputIndex_0 = 0, kInputIndex_1 = 1; + const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0); MklDnnShape input_shape_0; - GetMklShape(context, 0, &input_shape_0); + GetMklShape(context, kInputIndex_0, &input_shape_0); - const Tensor& input_tensor_1 = MklGetInput(context, 1); + const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1); MklDnnShape input_shape_1; - GetMklShape(context, 1, &input_shape_1); - - bool tf_shapes_are_same = - context->input(0).shape() == context->input(1).shape(); + GetMklShape(context, kInputIndex_1, &input_shape_1); - VLOG(1) << "MklInputConversionOp: Input shapes are " - << (tf_shapes_are_same ? "*same*" : "*different*") << ": " - << context->input(0).shape().DebugString() << " and " - << context->input(1).shape().DebugString(); + VLOG(1) << "MklInputConversionOp: Input shapes are: " + << context->input(kInputIndex_0).shape().DebugString() << " and " + << context->input(kInputIndex_1).shape().DebugString(); // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // if both inputs are in TF format, just copy input tensors to output. @@ -285,15 +282,19 @@ class MklInputConversionOp : public OpKernel { VLOG(1) << "MklInputConversionOp: No conversion needed, " << "copying TF inputs to output"; - ForwardTfTensorInToOut(context, 0, 0); - ForwardTfTensorInToOut(context, 1, 1); + ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0); + ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1); return; } // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // If both inputs are in MKL format if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { - if (tf_shapes_are_same) { + // It is safer to compare the original TensorFlow shapes than to compare + // Mkl shapes since element wise ops are forwarded to Eigen implementation. + TensorShape tf_shape0 = input_shape_0.GetTfShape(); + TensorShape tf_shape1 = input_shape_1.GetTfShape(); + if (tf_shape0 == tf_shape1) { auto input0_md = input_shape_0.GetMklLayout(); auto input1_md = input_shape_1.GetMklLayout(); @@ -302,8 +303,8 @@ class MklInputConversionOp : public OpKernel { VLOG(1) << "MklInputConversionOp: No conversion needed, " << "copying MKL inputs with identical shapes to output"; - ForwardMklTensorInToOut(context, 0, 0); - ForwardMklTensorInToOut(context, 1, 1); + ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0); + ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1); return; } else { VLOG(1) << "MklInputConversionOp: Shape is same, but format is " @@ -324,7 +325,7 @@ class MklInputConversionOp : public OpKernel { mkl_output_mkl_shape.SetMklLayout(&input1_md); // Create output Mkl tensor for index 0 - AllocateOutputSetMklShape(context, 0, &tensor_out, + AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out, input_tensor_0.shape(), mkl_output_mkl_shape); @@ -342,7 +343,7 @@ class MklInputConversionOp : public OpKernel { stream(stream::kind::eager).submit(net).wait(); // Input1 will be passed through - ForwardMklTensorInToOut(context, 1, 1); + ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1); return; } } @@ -361,11 +362,11 @@ class MklInputConversionOp : public OpKernel { << "converted MKL inputs to TF format"; MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str, - op_data_type, has_avx512f_, 0); + op_data_type, has_avx512f_, kInputIndex_0); MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str, - op_data_type, has_avx512f_, 1); - SetDummyMklShapeOutput(context, 0); - SetDummyMklShapeOutput(context, 1); + op_data_type, has_avx512f_, kInputIndex_1); + SetDummyMklShapeOutput(context, kInputIndex_0); + SetDummyMklShapeOutput(context, kInputIndex_1); return; } @@ -377,7 +378,6 @@ class MklInputConversionOp : public OpKernel { const Tensor* mkl_tensor; const MklDnnShape* mkl_shape; const Tensor* tf_tensor; - MklDnnShape* tf_mkl_shape; uint mkl_tensor_index; uint tf_tensor_index; if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) { @@ -385,14 +385,12 @@ class MklInputConversionOp : public OpKernel { mkl_shape = &input_shape_0; mkl_tensor_index = 0; tf_tensor = &input_tensor_1; - tf_mkl_shape = &input_shape_1; tf_tensor_index = 1; } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { mkl_tensor = &input_tensor_1; mkl_shape = &input_shape_1; mkl_tensor_index = 1; tf_tensor = &input_tensor_0; - tf_mkl_shape = &input_shape_0; tf_tensor_index = 0; } else { CHECK(false) << "MklInputConversionOp: Unexpected combination of input " @@ -466,8 +464,8 @@ class MklInputConversionOp : public OpKernel { } VLOG(1) << "MklInputConversionOp: Shapes (output): " - << context->mutable_output(0)->shape().DebugString() << " and " - << context->mutable_output(1)->shape().DebugString(); + << context->mutable_output(kInputIndex_0)->shape().DebugString() << " and " + << context->mutable_output(kInputIndex_1)->shape().DebugString(); VLOG(1) << "MklInputConversion completed successfully."; } |