aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_input_conversion_op.cc
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/core/kernels/mkl_input_conversion_op.cc
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/core/kernels/mkl_input_conversion_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc62
1 files changed, 53 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 73d41efce1..5a8799ae93 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/mkl_tfconv_op.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// convert the TF format input to MKL format
///////////////////////////////////////////////////////////
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
public:
@@ -293,14 +293,58 @@ class MklInputConversionOp : public OpKernel {
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// If both inputs are in MKL format
if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
- // If both have the same shape, pass them through
if (tf_shapes_are_same) {
- VLOG(1) << "MklInputConversionOp: No conversion needed, "
- << "copying MKL inputs with identical shapes to output";
-
- ForwardMklTensorInToOut(context, 0, 0);
- ForwardMklTensorInToOut(context, 1, 1);
- return;
+ auto input0_md = input_shape_0.GetMklLayout();
+ auto input1_md = input_shape_1.GetMklLayout();
+
+ // If both have the same shape and same format, pass them through
+ if (input0_md.data.format == input1_md.data.format) {
+ VLOG(1) << "MklInputConversionOp: No conversion needed, "
+ << "copying MKL inputs with identical shapes to output";
+
+ ForwardMklTensorInToOut(context, 0, 0);
+ ForwardMklTensorInToOut(context, 1, 1);
+ return;
+ } else {
+ VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
+ "different, "
+ << "need to convert to same format";
+
+ // Convert input0, and keep input1 unchanged
+ // Create MklDnnShape for output mkl tensor based on input0
+ Tensor* tensor_out;
+ MklDnnShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(true);
+ mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
+ mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
+ input_shape_0.GetSizesAsMklDnnDims(),
+ input_shape_0.GetTfDataFormat());
+
+ // Get MKL layout from input1 as destination layout
+ mkl_output_mkl_shape.SetMklLayout(&input1_md);
+
+ // Create output Mkl tensor for index 0
+ AllocateOutputSetMklShape(context, 0, &tensor_out,
+ input_tensor_0.shape(),
+ mkl_output_mkl_shape);
+
+ // Create MklDnnData object for input0 tesnsor
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> input(&cpu_engine);
+ input.SetUsrMem(input0_md, &input_tensor_0);
+
+ // Create reorder from input0's layout to input1's layout
+ std::vector<primitive> net;
+ CHECK_EQ(input.CheckReorderToOpMem(
+ memory::primitive_desc(input1_md, cpu_engine),
+ tensor_out, &net),
+ true);
+ stream(stream::kind::eager).submit(net).wait();
+
+ // Input1 will be passed through
+ ForwardMklTensorInToOut(context, 1, 1);
+ return;
+ }
}
// Sanity check