aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_input_conversion_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_input_conversion_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc52
1 files changed, 25 insertions, 27 deletions
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index d91f7107c5..68d3e1c9ab 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -263,21 +263,18 @@ class MklInputConversionOp : public OpKernel {
private:
void Compute(OpKernelContext* context) override {
- const Tensor& input_tensor_0 = MklGetInput(context, 0);
+ const int kInputIndex_0 = 0, kInputIndex_1 = 1;
+ const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0);
MklDnnShape input_shape_0;
- GetMklShape(context, 0, &input_shape_0);
+ GetMklShape(context, kInputIndex_0, &input_shape_0);
- const Tensor& input_tensor_1 = MklGetInput(context, 1);
+ const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1);
MklDnnShape input_shape_1;
- GetMklShape(context, 1, &input_shape_1);
-
- bool tf_shapes_are_same =
- context->input(0).shape() == context->input(1).shape();
+ GetMklShape(context, kInputIndex_1, &input_shape_1);
- VLOG(1) << "MklInputConversionOp: Input shapes are "
- << (tf_shapes_are_same ? "*same*" : "*different*") << ": "
- << context->input(0).shape().DebugString() << " and "
- << context->input(1).shape().DebugString();
+ VLOG(1) << "MklInputConversionOp: Input shapes are: "
+ << context->input(kInputIndex_0).shape().DebugString() << " and "
+ << context->input(kInputIndex_1).shape().DebugString();
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// if both inputs are in TF format, just copy input tensors to output.
@@ -285,15 +282,19 @@ class MklInputConversionOp : public OpKernel {
VLOG(1) << "MklInputConversionOp: No conversion needed, "
<< "copying TF inputs to output";
- ForwardTfTensorInToOut(context, 0, 0);
- ForwardTfTensorInToOut(context, 1, 1);
+ ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0);
+ ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1);
return;
}
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// If both inputs are in MKL format
if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
- if (tf_shapes_are_same) {
+ // It is safer to compare the original TensorFlow shapes than to compare
+ // Mkl shapes since element wise ops are forwarded to Eigen implementation.
+ TensorShape tf_shape0 = input_shape_0.GetTfShape();
+ TensorShape tf_shape1 = input_shape_1.GetTfShape();
+ if (tf_shape0 == tf_shape1) {
auto input0_md = input_shape_0.GetMklLayout();
auto input1_md = input_shape_1.GetMklLayout();
@@ -302,8 +303,8 @@ class MklInputConversionOp : public OpKernel {
VLOG(1) << "MklInputConversionOp: No conversion needed, "
<< "copying MKL inputs with identical shapes to output";
- ForwardMklTensorInToOut(context, 0, 0);
- ForwardMklTensorInToOut(context, 1, 1);
+ ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0);
+ ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
return;
} else {
VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
@@ -324,7 +325,7 @@ class MklInputConversionOp : public OpKernel {
mkl_output_mkl_shape.SetMklLayout(&input1_md);
// Create output Mkl tensor for index 0
- AllocateOutputSetMklShape(context, 0, &tensor_out,
+ AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out,
input_tensor_0.shape(),
mkl_output_mkl_shape);
@@ -342,7 +343,7 @@ class MklInputConversionOp : public OpKernel {
stream(stream::kind::eager).submit(net).wait();
// Input1 will be passed through
- ForwardMklTensorInToOut(context, 1, 1);
+ ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
return;
}
}
@@ -361,11 +362,11 @@ class MklInputConversionOp : public OpKernel {
<< "converted MKL inputs to TF format";
MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
- op_data_type, has_avx512f_, 0);
+ op_data_type, has_avx512f_, kInputIndex_0);
MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
- op_data_type, has_avx512f_, 1);
- SetDummyMklShapeOutput(context, 0);
- SetDummyMklShapeOutput(context, 1);
+ op_data_type, has_avx512f_, kInputIndex_1);
+ SetDummyMklShapeOutput(context, kInputIndex_0);
+ SetDummyMklShapeOutput(context, kInputIndex_1);
return;
}
@@ -377,7 +378,6 @@ class MklInputConversionOp : public OpKernel {
const Tensor* mkl_tensor;
const MklDnnShape* mkl_shape;
const Tensor* tf_tensor;
- MklDnnShape* tf_mkl_shape;
uint mkl_tensor_index;
uint tf_tensor_index;
if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
@@ -385,14 +385,12 @@ class MklInputConversionOp : public OpKernel {
mkl_shape = &input_shape_0;
mkl_tensor_index = 0;
tf_tensor = &input_tensor_1;
- tf_mkl_shape = &input_shape_1;
tf_tensor_index = 1;
} else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
mkl_tensor = &input_tensor_1;
mkl_shape = &input_shape_1;
mkl_tensor_index = 1;
tf_tensor = &input_tensor_0;
- tf_mkl_shape = &input_shape_0;
tf_tensor_index = 0;
} else {
CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
@@ -466,8 +464,8 @@ class MklInputConversionOp : public OpKernel {
}
VLOG(1) << "MklInputConversionOp: Shapes (output): "
- << context->mutable_output(0)->shape().DebugString() << " and "
- << context->mutable_output(1)->shape().DebugString();
+ << context->mutable_output(kInputIndex_0)->shape().DebugString() << " and "
+ << context->mutable_output(kInputIndex_1)->shape().DebugString();
VLOG(1) << "MklInputConversion completed successfully.";
}