aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_reshape_op.cc
diff options
context:
space:
mode:
authorGravatar Dandelion Mané <dandelion@google.com>2017-12-15 17:12:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 17:16:29 -0800
commitd55f532867a3670d66460c5ee3b774519542adc1 (patch)
tree7de4d85bcd61e93401459276b4d371ab0be23c1f /tensorflow/core/kernels/mkl_reshape_op.cc
parent32d5048ae96116202f2aa0fa739ef37514ee8a54 (diff)
Merge changes from github.
PiperOrigin-RevId: 179258973
Diffstat (limited to 'tensorflow/core/kernels/mkl_reshape_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_reshape_op.cc182
1 file changed, 182 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index 5e98582475..11c92ebdb4 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -28,6 +28,11 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+using mkldnn::stream;
+#endif
+
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
template <typename Device, typename T>
@@ -35,6 +40,7 @@ class MklReshapeOp : public OpKernel {
public:
explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
+#ifndef INTEL_MKL_DNN
void Compute(OpKernelContext* context) override {
const Tensor& input = MklGetInput(context, 0);
const Tensor& sizes = MklGetInput(context, 1);
@@ -129,7 +135,183 @@ class MklReshapeOp : public OpKernel {
}
}
+#else
+
private:
+ // When the input tensor is in MKL layout and we are reshaping the tensor to a
+ // different shape than its actual shape, then we use MKLDNN reorder primitive
+ // to put tensor back in Tensorflow layout. But we can skip this reordering
+ // some times. This function checks for all such cases.
+ bool SkipReorder(const MklDnnShape& mkl_shape_input,
+ const TensorShape& reshape_to) {
+ CHECK_EQ(mkl_shape_input.IsMklTensor(), true);
+ bool ret = false;
+
+ // If Tensorflow's data format and the underlying format maintained by
+ // MKLDNN are equivalent (both are NHWC or both are NCHW), then we can
+ // safely return true.
+ auto input_mkl_md = mkl_shape_input.GetMklLayout();
+ if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) {
+ ret = true;
+ }
+
+ return ret;
+ }
+
+ public:
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ const Tensor& sizes = MklGetInput(context, 1);
+
+ MklDnnShape mkl_shape_input;
+ GetMklShape(context, kInputSlotIdx, &mkl_shape_input);
+ bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
+ const int64 nelems = input_in_mkl_format ?
+ mkl_shape_input.GetTfShape().num_elements()
+ : input_tensor.NumElements();
+
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
+ errors::InvalidArgument("sizes input must be 1-D, not shape ",
+ sizes.shape().DebugString()));
+
+ // Compute the output shape. Determine product of specified
+ // dimensions, and find the index of the unspecified one.
+ TensorShape shape;
+ int64 product = 1;
+ int unknown_index = -1;
+ switch (sizes.dtype()) {
+ case DT_INT32:
+ OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
+ &unknown_index, &shape));
+ break;
+ case DT_INT64:
+ OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
+ &unknown_index, &shape));
+ break;
+ default:
+ context->CtxFailure(errors::InvalidArgument(
+ "desired shape must be a DT_INT32 or DT_INT64 vector, not a ",
+ DataTypeString(sizes.dtype())));
+ return;
+ }
+ if (unknown_index != -1) {
+ OP_REQUIRES(
+ context, product > 0,
+ errors::InvalidArgument("Reshape cannot infer the missing input size "
+ "for an empty tensor unless all specified "
+ "input sizes are non-zero"));
+ const int64 missing = nelems / product;
+ OP_REQUIRES(
+ context, product * missing == nelems,
+ errors::InvalidArgument(
+ "Input to reshape is a tensor with ", nelems,
+ " values, but the requested shape requires a multiple of ",
+ product));
+ shape.set_dim(unknown_index, missing);
+ }
+ OP_REQUIRES(context, shape.num_elements() == nelems,
+ errors::InvalidArgument("Input to reshape is a tensor with ",
+ nelems,
+ " values, but the requested shape has ",
+ shape.num_elements()));
+
+ if (input_in_mkl_format) {
+ TensorShape& shape_to = shape;
+ TensorShape shape_from = mkl_shape_input.GetTfShape();
+ if (shape_from == shape_to) {
+ CopyMklTensorInToOut(context, kInputSlotIdx, kOutputSlotIdx);
+ return;
+ } else {
+ try {
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> dnn_data_input(&cpu_engine);
+ // Reshape is just a logical view change operation for a tensor.
+ // It does not change underlying layout. But MKLDNN may maintain
+ // tensor data in different layout than that specified by Tensorflow.
+ // If MKLDNN maintains input tensor in different layout than that
+ // specified by Tensorflow, we will need to reorder tensor and then
+ // put it in the shape expected by Tensorflow. But if MKLDNN has
+ // maintained input tensor in the same layout as it is expected by
+ // Tensorflow, we don't need to reorder tensor contents, we just
+ // need to update MklDnnShape object associated with the input
+ // tensor to reflect the shape change expected by reshape.
+ if (!SkipReorder(mkl_shape_input, shape_to)) {
+ // If dimensions that are being expanded or collapsed are not
+ // maintained contiguously by MKLDNN, then we use reorder.
+
+ // Get Mkl layout of input tensor.
+ auto input_mkl_md = mkl_shape_input.GetMklLayout();
+ // Set input Mkl layout as the user layout.
+ dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor);
+ // Get expected Tensorflow layout of input tensor.
+ auto output_tf_md = mkl_shape_input.GetTfLayout();
+ auto output_tf_pd = memory::primitive_desc(output_tf_md,
+ cpu_engine);
+
+ Tensor* output_tensor = nullptr;
+ MklShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(false);
+ // We allocate output tensor in the shape expected by Reshape.
+ AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor,
+ shape_to, mkl_shape_output);
+
+ // Insert reorder between Mkl layout and TensorFlow layout.
+ std::vector<primitive> net;
+ CHECK_EQ(dnn_data_input.CheckReorderToOpMem(output_tf_pd,
+ output_tensor, &net), true);
+ stream(stream::kind::eager).submit(net).wait();
+ return;
+ } else {
+ // If dimensions that are being expanded or collapsed are
+ // maintained contiguously by MKLDNN, then we skip reorder, just
+ // update MklDnnShape object for the tensorflow tensor, and forward
+ // Tensorflow tensor as it is to the output.
+ auto output_dims = TFShapeToMklDnnDims(shape_to);
+ auto output_strides = CalculateTFStrides(output_dims);
+ auto output_tf_md = MklDnnData<T>::CreateBlockedMemDesc(output_dims,
+ output_strides);
+ auto output_tf_pd = memory::primitive_desc(output_tf_md,
+ cpu_engine);
+
+ // Set MklDnnShape
+ MklDnnShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(true);
+ mkl_shape_output.SetMklLayout(&output_tf_pd);
+ mkl_shape_output.SetElemType(MklDnnType<T>());
+ mkl_shape_output.SetTfLayout(output_dims.size(), output_dims,
+ memory::format::blocked);
+
+ // We now simply forward input Mkl tensor to output and change its
+ // output MklDnnShape object.
+ ForwardMklTensorInToOutWithMklShape(context, kInputSlotIdx,
+ kOutputSlotIdx, mkl_shape_output);
+ return;
+ }
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
+ }
+ }
+ } else {
+ // If input tensor is not in Mkl format, then just copy Tensorflow tensor
+ // to output with specified shape.
+ CopyTfTensorInToOutWithShape(context, kInputSlotIdx, kOutputSlotIdx,
+ shape);
+ }
+ }
+
+#endif // INTEL_MKL_DNN
+
+ private:
+ const int kInputSlotIdx = 0;
+ const int kOutputSlotIdx = 0;
+
template <typename Tshape>
Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
TensorShape* shape) {