diff options
author | 2018-09-26 23:14:39 -0700 | |
---|---|---|
committer | 2018-09-26 23:14:39 -0700 | |
commit | 08a6cfed1cf0cccc8ff35448266f44fbc55be0bc (patch) | |
tree | 73f61074984cd9dcf05e5d65b454a6ce08484f4a /tensorflow/core | |
parent | d3f14ef70cdf113f9d330c1f7c638003429a1dc4 (diff) | |
parent | d1ab8b71c2115caacfec19d849ddabf7f1f4287b (diff) |
Merge pull request #22076 from Intel-tensorflow:feature/daoxin/slice
PiperOrigin-RevId: 214726180
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 107 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 3 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass_test.cc | 20 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/mkl_slice_op.cc | 358 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 122 |
8 files changed, 530 insertions, 107 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8bf53958b6..d575604a56 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1367,6 +1367,7 @@ cc_library( "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", + "//tensorflow/core/kernels:mkl_slice_op", "//tensorflow/core/kernels:mkl_softmax_op", "//tensorflow/core/kernels:mkl_transpose_op", "//tensorflow/core/kernels:mkl_tfconv_op", @@ -3827,6 +3828,7 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", + "//tensorflow/core/kernels:mkl_slice_op", "//tensorflow/core/kernels:mkl_softmax_op", "//tensorflow/core/kernels:mkl_tfconv_op", ]), diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 20a07d86a2..50403b4004 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) { return Status::OK(); } +namespace { + +// This SliceHelper processes the output shape of the `slice` +// when the tensor of `sizes` is available. 
+template <typename T> +Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, + const Tensor* sizes_value, + std::vector<DimensionHandle>* dims) { + auto sizes_vec = sizes_value->vec<T>(); + for (int i = 0; i < sizes_value->NumElements(); ++i) { + DimensionHandle dim = c->Dim(c->input(0), i); + if (sizes_vec(i) != -1) { + auto dim_val = c->Value(dim); + if (sizes_vec(i) < 0) { + return errors::InvalidArgument( + "Out of bounds slicing on dimension ", i, " of length ", dim_val, + ": sizes vector cannot be < -1, but was ", sizes_vec(i)); + } + + dims->emplace_back(c->MakeDim(sizes_vec(i))); + } else { + DimensionHandle result; + TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); + dims->emplace_back(result); + } + } + + return Status::OK(); +} +} // namespace + +Status SliceShape(InferenceContext* c) { + ShapeHandle input = c->input(0); + ShapeHandle begin_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); + ShapeHandle sizes_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); + + // Merge to check compatibility of begin and sizes tensors. + TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); + + DimensionHandle ndims = c->Dim(begin_shape, 0); + if (c->ValueKnown(ndims)) { + TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); + } + + // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known + // values, even though the `begin` value does not represent a shape. + ShapeHandle begin_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); + + // We check the tensor value here and will only use + // `MakeShapeFromShapeTensor` when `sizes_value` is null. + // The reason is that `sizes` might contain -1, which can't + // be represented (-1 in the ShapeHandle would mean "unknown"). 
+ const Tensor* sizes_value = c->input_tensor(2); + + if (sizes_value != nullptr) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); + std::vector<DimensionHandle> dims; + // If the begin and sizes tensors are available, then + // we can be precise about the shape of the output. + if (sizes_value->dtype() == DT_INT64) { + TF_RETURN_IF_ERROR( + SliceHelper<int64>(c, begin_value, sizes_value, &dims)); + } else { + TF_RETURN_IF_ERROR( + SliceHelper<int32>(c, begin_value, sizes_value, &dims)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } else { + // In case `sizes` is not available (`sizes_value` is null), + // we could try to use `MakeShapeFromShapeTensor` here. + // If sizes contain -1, we will simply consider it as `Unknown`. + // This is less than ideal but still an improvement of shape inference. + // The following is an example that returns [None, 1, None] with this + // code path: + // z = tf.zeros((1, 2, 3)) + // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) + // m.get_shape().as_list() + ShapeHandle sizes_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); + if (c->RankKnown(sizes_value)) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); + std::vector<DimensionHandle> dims; + dims.reserve(c->Rank(sizes_value)); + for (int i = 0; i < c->Rank(sizes_value); ++i) { + dims.emplace_back(c->Dim(sizes_value, i)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } + // We might know the rank of the input. + if (c->RankKnown(input)) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); + return Status::OK(); + } else { + return shape_inference::UnknownShape(c); + } + } + + return Status::OK(); +} + Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, ShapeHandle values_shape, ShapeHandle shape_shape) { // Validate ranks. 
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index e6f9f935f9..3a496e06ae 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { // Shape function for random operations. Status RandomShape(shape_inference::InferenceContext* c); +// Shape function for Slice operations. +Status SliceShape(shape_inference::InferenceContext* c); + // Validates the 3 component tensors of a sparse tensor have the proper // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 37b88f1728..06d3fefef1 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2450,6 +2450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.tanh = "Tanh"; csinfo_.tanh_grad = "TanhGrad"; csinfo_.reshape = "Reshape"; + csinfo_.slice = "Slice"; csinfo_.softmax = "Softmax"; csinfo_.split = "Split"; // Element-wise ops. 
Ensure you also add any new ops to IsOpElementWise @@ -2557,6 +2558,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsReshape, AlwaysRewrite}); + rinfo_.push_back({csinfo_.slice, + mkl_op_registry::GetMklOpName(csinfo_.slice), + CopyAttrsSlice, AlwaysRewrite}); rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), CopyAttrsDataType, AlwaysRewrite}); @@ -2676,6 +2680,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string tanh; string tanh_grad; string reshape; + string slice; string softmax; string split; string squared_difference; @@ -3134,6 +3139,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, @@ -3739,6 +3745,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, nb->Attr("Tshape", Tshape); } +void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + DataType Index; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index)); + // Add attributes to new node. 
+ nb->Attr("T", T); + nb->Attr("Index", Index); +} + void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb) { DataType T; diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index f42a4ee98b..77640e287c 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); } +TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Int32Input'}" + "node { name: 'C' op: 'Int32Input'}" + "node { name: 'D' op: 'Slice'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'Index' value { type: DT_INT32 } }" + " input: ['A', 'B', 'C'] }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Int32Input);C(Int32Input);" + "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/" + "_2(Const);E(Zeta)|A->D;A->E;" + "A:control->DMT/_0:control;A:control->DMT/" + "_1:control;A:control->DMT/_2:control;" + "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); +} + ///////////////////////////////////////////////////////////////////// // Post-rewrite fixup pass test diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index b08562d7d1..0534b1829d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6416,6 +6416,12 @@ tf_mkl_kernel_library( ) tf_mkl_kernel_library( + name = "mkl_slice_op", + prefix = "mkl_slice_op", + deps = ARRAY_DEPS + mkl_deps(), +) + +tf_mkl_kernel_library( name = "mkl_identity_op", prefix = "mkl_identity_op", deps = ARRAY_DEPS + mkl_deps(), diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc new file mode 100644 index 
0000000000..d63e14adf6 --- /dev/null +++ b/tensorflow/core/kernels/mkl_slice_op.cc @@ -0,0 +1,358 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/array_ops.cc. + +#ifdef INTEL_MKL +#ifndef INTEL_MKL_ML_ONLY + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/prefetch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "mkldnn.hpp" +#include "tensorflow/core/util/mkl_util.h" + +using mkldnn::stream; +using mkldnn::view; + +namespace tensorflow { + +namespace { + +gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) { + gtl::InlinedVector<int64, 4> out; + if (tensor.dtype() == DT_INT32) { + for (int64 i = 0; i < tensor.NumElements(); ++i) { + out.push_back(tensor.flat<int32>()(i)); + } + } else if (tensor.dtype() == DT_INT64) { + for (int64 i = 0; i < tensor.NumElements(); ++i) { + out.push_back(tensor.flat<int64>()(i)); + } + } else { + // tensor must be either int32 or int64 + DCHECK(false); + } + return out; +} + +} // namespace + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// A version of 
SharedValidation (slice_op.h) written for input that is in +// either Mkl layout or Tensorflow layout. +// A shared code to validate input shapes and check for identity, which is not dependent on the type of T. +// We do this to reduce code size by not duplicating all this for all T (float, double, int32, etc.) +static void ValidateMklInputs(OpKernelContext* context, bool* is_identity, + gtl::InlinedVector<int64, 4>* begin, + gtl::InlinedVector<int64, 4>* size) { + const int kInputTensorIndex = 0; + const int kInputBeginIndex = 1; + const int kInputSizeIndex = 2; + const Tensor& input = MklGetInput(context, kInputTensorIndex); + const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex); + const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex); + + MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape; + GetMklShape(context, kInputTensorIndex, &input_mkl_shape); + GetMklShape(context, kInputBeginIndex, &begin_mkl_shape); + GetMklShape(context, kInputSizeIndex, &size_mkl_shape); + + // Begin and size tensors cannot be in MklDnn layout. + DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false); + DCHECK_EQ(size_mkl_shape.IsMklTensor(), false); + + TensorShape input_tf_shape = input_mkl_shape.IsMklTensor() + ? 
input_mkl_shape.GetTfShape() + : input.shape(); + const int input_dims = input_tf_shape.dims(); + + OP_REQUIRES( + context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) && + context->op_kernel().IsLegacyVector(size_tensor.shape()) && + begin_tensor.NumElements() == input_dims && + size_tensor.NumElements() == input_dims, + errors::InvalidArgument( + "Expected begin and size arguments to be 1-D tensors of size ", + input_dims, ", but got shapes ", begin_tensor.shape().DebugString(), + " and ", size_tensor.shape().DebugString(), " instead.")); + + *begin = IntTensorToInt64Vec(begin_tensor); + *size = IntTensorToInt64Vec(size_tensor); + for (int i = 0; i < input_dims; ++i) { + if ((*size)[i] == -1) { + // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". + (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i]; + } + } + + *is_identity = true; + for (int i = 0; i < input_dims; ++i) { + int64 b = (*begin)[i]; + int64 s = (*size)[i]; + if (input_tf_shape.dim_size(i) == 0) { + OP_REQUIRES( + context, b == 0 && s == 0, + errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b, + ") and size[", i, "] == 0 ", "(got ", s, + ") when ", "input.dim_size(", i, ") == 0")); + } else { + OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i), + errors::InvalidArgument("Expected begin[", i, "] in [0, ", + input_tf_shape.dim_size(i), + "], but got ", b)); + OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input_tf_shape.dim_size(i) - b, + "], but ", "got ", s)); + } + const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i)); + (*is_identity) &= take_all; + } +} + +// A version of SharedSliceCommonCases function written for input tensor +// that may be in MklDnn layout or in Tensorflow layout. 
+template <typename T> +static void CheckCommonCasesForMklInputs(OpKernelContext* context, + gtl::InlinedVector<int64, 4>* begin, + gtl::InlinedVector<int64, 4>* size, + bool* done) { + bool is_identity = true; + *done = false; + + ValidateMklInputs(context, &is_identity, begin, size); + if (!context->status().ok()) return; + + const Tensor& input = MklGetInput(context, 0); + MklDnnShape input_mkl_shape; + GetMklShape(context, 0, &input_mkl_shape); + + if (is_identity) { + VLOG(1) << "Slice identity"; + context->set_output(0, input); + // Mkl metadata tensor in this case can just be forwarded from input to + // output. + AllocateOutputSetMklShape(context, 0, input_mkl_shape); + *done = true; + } +} + +// MKL-DNN implementation of Slice +template <typename Device, typename T> +class MklDnnSliceOp : public OpKernel { + public: + explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {} + + ~MklDnnSliceOp() {} + + void Compute(OpKernelContext* context) override { + gtl::InlinedVector<int64, 4> begin; + gtl::InlinedVector<int64, 4> size; + bool done = false; + + CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done); + if (!context->status().ok() || done == true) return; + + // MKL-DNN itself supports tensors with 8 or more (but fewer than 12) + // dimensions; however, we mimic the functionality of the Eigen Slice op + // for CPU here, so inputs with 8 or more dimensions are rejected. + if (begin.size() >= 8) { + OP_REQUIRES( + context, false, + errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions")); + } + + ComputeMklDnnSlice(context, begin, size); + } + + private: + // Slice op implemented using MKL-DNN APIs. 
+ void ComputeMklDnnSlice(OpKernelContext* context, + const gtl::InlinedVector<int64, 4>& begin, + const gtl::InlinedVector<int64, 4>& size) { + try { + // MKL-DNN API usage below is guided by description at: + // https://github.com/01org/mkl-dnn/issues/69 + // + // Relevant part of the description is copied below: + // + // Let's say you want to copy a part of memory into another buffer (and + // probably change the format). Then your steps are: + // + // 1. create memory primitive descriptor in_mem_pd and memory primitive + // in_mem_p for the entire source data. + // 2. create view primitive descriptor in_submem_pd based on in_mem_pd, + // initial offsets, and sub-sizes + // 3. create memory primitive descriptor out_mem_pd and memory primitive + // out_mem_p for the output (the logical sizes should match sub-sizes + // used in step 2, but the format might be arbitrary) + // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd + // and out_mem_pd + // 5. create reorder primitive itself based on reorder_pd, in_mem_p, and + // out_mem_p. + // + // Please notice that there is no view primitive. There is only view + // primitive descriptor. And the reorder uses source memory as input but + // traverses it according to a view in_submem_pd. + + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData<T> src(&cpu_engine); + MklDnnData<T> output(&cpu_engine); + + // Populate offsets and sizes in memory::dims format based on vector. + memory::dims begin_dims = {}; + begin_dims.resize(begin.size()); + for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i]; + memory::dims size_dims = {}; + bool empty = false; + size_dims.resize(size.size()); + for (size_t i = 0; i < size.size(); ++i) { + size_dims[i] = size[i]; + if (size_dims[i] == 0) empty = true; + } + + Tensor* output_tensor = nullptr; + MklDnnShape output_mkl_shape; + + // If no dimension is selected in slice, the result should be empty. 
+ // Just return an empty output tensor, and a dummy Mkl-shape tensor. + if (empty) { // for empty dims + auto shape_to = MklDnnDimsToTFShape(size_dims); + AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to, + output_mkl_shape); + return; + } + + // Step 1 (as per above description) - Create memory for user data. + // We use blocked format here to describe input tensor. + const Tensor& input_tensor = MklGetInput(context, 0); + MklDnnShape input_mkl_shape; + GetMklShape(context, 0, &input_mkl_shape); + + if (input_mkl_shape.IsMklTensor()) { + auto input_mkl_format = input_mkl_shape.GetTfDataFormat(); + auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format); + begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format); + size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format); + auto input_md = input_mkl_shape.GetMklLayout(); + src.SetUsrMem(input_md, &input_tensor); + } else { + // Initialize input dimensions and strides to be used when input is not + // in MklDnn layout. + memory::dims input_dims, input_strides; + input_dims = TFShapeToMklDnnDims(input_tensor.shape()); + input_strides = CalculateTFStrides(input_dims); + // Create input memory descriptor. + auto input_md = + MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides); + src.SetUsrMem(input_md, &input_tensor); + } + + // Step 2 - create view primitive descriptor + auto view_pd = + view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims) + .dst_primitive_desc(); + auto output_strides = CalculateTFStrides(size_dims); + auto output_md = + MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides); + auto output_pd = memory::primitive_desc(output_md, cpu_engine); + + // Step 3 - Create memory for output. If input is in MklDnn layout, then + // output is also in MklDnn layout. Otherwise, output is in Tensorflow + // layout. 
+ AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims, + &output_tensor, &output_mkl_shape); + DCHECK(output_tensor); + DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor()); + output.SetUsrMem(output_md, output_tensor); + + std::vector<primitive> net; + // Step 4 - create reorder primitive desc between view_pd and output_pd. + auto reorder_pd = + reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc()); + // Step 5 - create reorder primitive itself. + net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem())); + // Execute the reorder primitive. + stream(stream::kind::eager).submit(net).wait(); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); + } + } + + private: + void AllocateOutputTensor(OpKernelContext* context, + const MklDnnShape& input_mkl_shape, + memory::primitive_desc* output_pd, + const memory::dims& output_dims, + Tensor** output_tensor, + MklDnnShape* output_mkl_shape) { + DCHECK(output_tensor); + DCHECK(output_mkl_shape); + + TensorShape output_tf_shape; + + if (input_mkl_shape.IsMklTensor()) { + // Since input tensor is in Mkl layout, output tensor will be in Mkl + // layout. + + // Allocate shape of Mkl tensor. + output_mkl_shape->SetMklTensor(true); + output_mkl_shape->SetMklLayout(output_pd); + output_mkl_shape->SetElemType(MklDnnType<T>()); + output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims, + input_mkl_shape.GetTfDataFormat()); + + output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1); + } else { + // If input is not in Mkl layout, then output won't be in Mkl layout. 
+ output_mkl_shape->SetMklTensor(false); + output_tf_shape = MklDnnDimsToTFShape(output_dims); + } + + AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, + *output_mkl_shape); + } +}; + +// MKL-DNN Slice registration +#define REGISTER_MKL_SLICE(type) \ + REGISTER_KERNEL_BUILDER(Name("_MklSlice") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("begin") \ + .HostMemory("size") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklDnnSliceOp<CPUDevice, type>); + +TF_CALL_float(REGISTER_MKL_SLICE); +#undef REGISTER_MKL_SLICE + +} // namespace tensorflow + +#endif // !INTEL_MKL_ML_ONLY +#endif // INTEL_MKL diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 442686c92a..c9f80df5e4 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1531,37 +1531,6 @@ REGISTER_OP("Size") .Attr("out_type: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ScalarShape); -namespace { - -// This SliceHelper processes the output shape of the `slice` -// when the tensor of `sizes` is available. 
-template <typename T> -Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, - const Tensor* sizes_value, - std::vector<DimensionHandle>* dims) { - auto sizes_vec = sizes_value->vec<T>(); - for (int i = 0; i < sizes_value->NumElements(); ++i) { - DimensionHandle dim = c->Dim(c->input(0), i); - if (sizes_vec(i) != -1) { - auto dim_val = c->Value(dim); - if (sizes_vec(i) < 0) { - return errors::InvalidArgument( - "Out of bounds slicing on dimension ", i, " of length ", dim_val, - ": sizes vector cannot be < -1, but was ", sizes_vec(i)); - } - - dims->emplace_back(c->MakeDim(sizes_vec(i))); - } else { - DimensionHandle result; - TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); - dims->emplace_back(result); - } - } - - return Status::OK(); -} -} // namespace - // -------------------------------------------------------------------------- REGISTER_OP("Slice") .Input("input: T") @@ -1570,83 +1539,22 @@ REGISTER_OP("Slice") .Output("output: T") .Attr("T: type") .Attr("Index: {int32,int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle input = c->input(0); - ShapeHandle begin_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); - ShapeHandle sizes_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); - - // Merge to check compatibility of begin and sizes tensors. - TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); - - DimensionHandle ndims = c->Dim(begin_shape, 0); - if (c->ValueKnown(ndims)) { - TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); - } - - // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known - // values, even though the `begin` value does not represent a shape. - ShapeHandle begin_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); - - // We check the tensor value here and will only use - // `MakeShapeFromShapeTensor` when `sizes_value` is null. 
- // The reason is that `sizes`might contain -1, which can't - // be represented (-1 in the ShapeHandle would mean "unknown". - const Tensor* sizes_value = c->input_tensor(2); - - if (sizes_value != nullptr) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); - std::vector<DimensionHandle> dims; - // If the begin and sizes tensors are available, then - // we can be precise about the shape of the output. - if (sizes_value->dtype() == DT_INT64) { - TF_RETURN_IF_ERROR( - SliceHelper<int64>(c, begin_value, sizes_value, &dims)); - } else { - TF_RETURN_IF_ERROR( - SliceHelper<int32>(c, begin_value, sizes_value, &dims)); - } - - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } else { - // In case `sizes` is not available (`sizes_value` is null), - // we could try to use `MakeShapeFromShapeTensor` here. - // If sizes contain -1, we will simply consider it as `Unknown`. - // This is less than ideal but still an improvement of shape inference. - // The following is an example that returns [None, 1, None] with this - // code path: - // z = tf.zeros((1, 2, 3)) - // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) - // m.get_shape().as_list() - ShapeHandle sizes_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); - if (c->RankKnown(sizes_value)) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); - std::vector<DimensionHandle> dims; - dims.reserve(c->Rank(sizes_value)); - for (int i = 0; i < c->Rank(sizes_value); ++i) { - dims.emplace_back(c->Dim(sizes_value, i)); - } - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } - - // We might know the rank of the input. 
- if (c->RankKnown(input)) { - c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return Status::OK(); - } else { - return shape_inference::UnknownShape(c); - } - } + .SetShapeFn(shape_inference::SliceShape); - return Status::OK(); - }); +#ifdef INTEL_MKL +REGISTER_OP("_MklSlice") + .Input("input: T") + .Input("begin: Index") + .Input("size: Index") + .Input("mkl_input: uint8") + .Input("mkl_begin: uint8") + .Input("mkl_size: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: type") + .Attr("Index: {int32,int64}") + .SetShapeFn(shape_inference::SliceShape); +#endif REGISTER_OP("StridedSlice") .Input("input: T") |