diff options
author | 2017-12-15 17:12:41 -0800 | |
---|---|---|
committer | 2017-12-15 17:16:29 -0800 | |
commit | d55f532867a3670d66460c5ee3b774519542adc1 (patch) | |
tree | 7de4d85bcd61e93401459276b4d371ab0be23c1f /tensorflow/core/kernels/mkl_reshape_op.cc | |
parent | 32d5048ae96116202f2aa0fa739ef37514ee8a54 (diff) |
Merge changes from github.
PiperOrigin-RevId: 179258973
Diffstat (limited to 'tensorflow/core/kernels/mkl_reshape_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_reshape_op.cc | 182 |
1 file changed, 182 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index 5e98582475..11c92ebdb4 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -28,6 +28,11 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" +#ifdef INTEL_MKL_DNN +#include "mkldnn.hpp" +using mkldnn::stream; +#endif + namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; template <typename Device, typename T> @@ -35,6 +40,7 @@ class MklReshapeOp : public OpKernel { public: explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} +#ifndef INTEL_MKL_DNN void Compute(OpKernelContext* context) override { const Tensor& input = MklGetInput(context, 0); const Tensor& sizes = MklGetInput(context, 1); @@ -129,7 +135,183 @@ class MklReshapeOp : public OpKernel { } } +#else + private: + // When the input tensor is in MKL layout and we are reshaping the tensor to a + // different shape than its actual shape, then we use MKLDNN reorder primitive + // to put tensor back in Tensorflow layout. But we can skip this reordering + // some times. This function checks for all such cases. + bool SkipReorder(const MklDnnShape& mkl_shape_input, + const TensorShape& reshape_to) { + CHECK_EQ(mkl_shape_input.IsMklTensor(), true); + bool ret = false; + + // If Tensorflow's data format and the underlying format maintained by + // MKLDNN are equivalent (both are NHWC or both are NCHW), then we can + // safely return true. 
+ auto input_mkl_md = mkl_shape_input.GetMklLayout(); + if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) { + ret = true; + } + + return ret; + } + + public: + void Compute(OpKernelContext* context) override { + const Tensor& input_tensor = MklGetInput(context, 0); + const Tensor& sizes = MklGetInput(context, 1); + + MklDnnShape mkl_shape_input; + GetMklShape(context, kInputSlotIdx, &mkl_shape_input); + bool input_in_mkl_format = mkl_shape_input.IsMklTensor(); + const int64 nelems = input_in_mkl_format ? + mkl_shape_input.GetTfShape().num_elements() + : input_tensor.NumElements(); + + // Preliminary validation of sizes. + OP_REQUIRES(context, IsLegacyVector(sizes.shape()), + errors::InvalidArgument("sizes input must be 1-D, not shape ", + sizes.shape().DebugString())); + + // Compute the output shape. Determine product of specified + // dimensions, and find the index of the unspecified one. + TensorShape shape; + int64 product = 1; + int unknown_index = -1; + switch (sizes.dtype()) { + case DT_INT32: + OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product, + &unknown_index, &shape)); + break; + case DT_INT64: + OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product, + &unknown_index, &shape)); + break; + default: + context->CtxFailure(errors::InvalidArgument( + "desired shape must be a DT_INT32 or DT_INT64 vector, not a ", + DataTypeString(sizes.dtype()))); + return; + } + if (unknown_index != -1) { + OP_REQUIRES( + context, product > 0, + errors::InvalidArgument("Reshape cannot infer the missing input size " + "for an empty tensor unless all specified " + "input sizes are non-zero")); + const int64 missing = nelems / product; + OP_REQUIRES( + context, product * missing == nelems, + errors::InvalidArgument( + "Input to reshape is a tensor with ", nelems, + " values, but the requested shape requires a multiple of ", + product)); + shape.set_dim(unknown_index, missing); + } + OP_REQUIRES(context, shape.num_elements() == nelems, + 
errors::InvalidArgument("Input to reshape is a tensor with ", + nelems, + " values, but the requested shape has ", + shape.num_elements())); + + if (input_in_mkl_format) { + TensorShape& shape_to = shape; + TensorShape shape_from = mkl_shape_input.GetTfShape(); + if (shape_from == shape_to) { + CopyMklTensorInToOut(context, kInputSlotIdx, kOutputSlotIdx); + return; + } else { + try { + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData<T> dnn_data_input(&cpu_engine); + // Reshape is just a logical view change operation for a tensor. + // It does not change underlying layout. But MKLDNN may maintain + // tensor data in different layout than that specified by Tensorflow. + // If MKLDNN maintains input tensor in different layout than that + // specified by Tensorflow, we will need to reorder tensor and then + // put it in the shape expected by Tensorflow. But if MKLDNN has + // maintained input tensor in the same layout as it is expected by + // Tensorflow, we don't need to reorder tensor contents, we just + // need to update MklDnnShape object associated with the input + // tensor to reflect the shape change expected by reshape. + if (!SkipReorder(mkl_shape_input, shape_to)) { + // If dimensions that are being expanded or collapsed are not + // maintained contiguously by MKLDNN, then we use reorder. + + // Get Mkl layout of input tensor. + auto input_mkl_md = mkl_shape_input.GetMklLayout(); + // Set input Mkl layout as the user layout. + dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor); + // Get expected Tensorflow layout of input tensor. + auto output_tf_md = mkl_shape_input.GetTfLayout(); + auto output_tf_pd = memory::primitive_desc(output_tf_md, + cpu_engine); + + Tensor* output_tensor = nullptr; + MklShape mkl_shape_output; + mkl_shape_output.SetMklTensor(false); + // We allocate output tensor in the shape expected by Reshape. 
+ AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor, + shape_to, mkl_shape_output); + + // Insert reorder between Mkl layout and TensorFlow layout. + std::vector<primitive> net; + CHECK_EQ(dnn_data_input.CheckReorderToOpMem(output_tf_pd, + output_tensor, &net), true); + stream(stream::kind::eager).submit(net).wait(); + return; + } else { + // If dimensions that are being expanded or collapsed are + // maintained contiguously by MKLDNN, then we skip reorder, just + // update MklDnnShape object for the tensorflow tensor, and forward + // Tensorflow tensor as it is to the output. + auto output_dims = TFShapeToMklDnnDims(shape_to); + auto output_strides = CalculateTFStrides(output_dims); + auto output_tf_md = MklDnnData<T>::CreateBlockedMemDesc(output_dims, + output_strides); + auto output_tf_pd = memory::primitive_desc(output_tf_md, + cpu_engine); + + // Set MklDnnShape + MklDnnShape mkl_shape_output; + mkl_shape_output.SetMklTensor(true); + mkl_shape_output.SetMklLayout(&output_tf_pd); + mkl_shape_output.SetElemType(MklDnnType<T>()); + mkl_shape_output.SetTfLayout(output_dims.size(), output_dims, + memory::format::blocked); + + // We now simply forward input Mkl tensor to output and change its + // output MklDnnShape object. + ForwardMklTensorInToOutWithMklShape(context, kInputSlotIdx, + kOutputSlotIdx, mkl_shape_output); + return; + } + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, + errors::Aborted("Operation received an exception:", + error_msg)); + } + } + } else { + // If input tensor is not in Mkl format, then just copy Tensorflow tensor + // to output with specified shape. 
+ CopyTfTensorInToOutWithShape(context, kInputSlotIdx, kOutputSlotIdx, + shape); + } + } + +#endif // INTEL_MKL_DNN + + private: + const int kInputSlotIdx = 0; + const int kOutputSlotIdx = 0; + template <typename Tshape> Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index, TensorShape* shape) { |