author    TensorFlower Gardener <gardener@tensorflow.org>  2018-09-26 23:14:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-26 23:14:39 -0700
commit    08a6cfed1cf0cccc8ff35448266f44fbc55be0bc (patch)
tree      73f61074984cd9dcf05e5d65b454a6ce08484f4a /tensorflow/core
parent    d3f14ef70cdf113f9d330c1f7c638003429a1dc4 (diff)
parent    d1ab8b71c2115caacfec19d849ddabf7f1f4287b (diff)
Merge pull request #22076 from Intel-tensorflow:feature/daoxin/slice
PiperOrigin-RevId: 214726180
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/BUILD                         |   2
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc | 107
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h  |   3
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc      |  19
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass_test.cc |  20
-rw-r--r--  tensorflow/core/kernels/BUILD                 |   6
-rw-r--r--  tensorflow/core/kernels/mkl_slice_op.cc       | 358
-rw-r--r--  tensorflow/core/ops/array_ops.cc              | 122
8 files changed, 530 insertions, 107 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 8bf53958b6..d575604a56 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1367,6 +1367,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
@@ -3827,6 +3828,7 @@ tf_cc_test_mkl(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
]),
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 20a07d86a2..50403b4004 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+namespace {
+
+// This SliceHelper processes the output shape of the `slice`
+// when the tensor of `sizes` is available.
+template <typename T>
+Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
+ const Tensor* sizes_value,
+ std::vector<DimensionHandle>* dims) {
+ auto sizes_vec = sizes_value->vec<T>();
+ for (int i = 0; i < sizes_value->NumElements(); ++i) {
+ DimensionHandle dim = c->Dim(c->input(0), i);
+ if (sizes_vec(i) != -1) {
+ auto dim_val = c->Value(dim);
+ if (sizes_vec(i) < 0) {
+ return errors::InvalidArgument(
+ "Out of bounds slicing on dimension ", i, " of length ", dim_val,
+ ": sizes vector cannot be < -1, but was ", sizes_vec(i));
+ }
+
+ dims->emplace_back(c->MakeDim(sizes_vec(i)));
+ } else {
+ DimensionHandle result;
+ TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
+ dims->emplace_back(result);
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status SliceShape(InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ ShapeHandle begin_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
+ ShapeHandle sizes_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
+
+ // Merge to check compatibility of begin and sizes tensors.
+ TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
+
+ DimensionHandle ndims = c->Dim(begin_shape, 0);
+ if (c->ValueKnown(ndims)) {
+ TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
+ }
+
+ // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
+ // values, even though the `begin` value does not represent a shape.
+ ShapeHandle begin_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
+
+ // We check the tensor value here and will only use
+ // `MakeShapeFromShapeTensor` when `sizes_value` is null.
+ // The reason is that `sizes` might contain -1, which can't
+ // be represented (-1 in the ShapeHandle would mean "unknown").
+ const Tensor* sizes_value = c->input_tensor(2);
+
+ if (sizes_value != nullptr) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
+ std::vector<DimensionHandle> dims;
+ // If the begin and sizes tensors are available, then
+ // we can be precise about the shape of the output.
+ if (sizes_value->dtype() == DT_INT64) {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int64>(c, begin_value, sizes_value, &dims));
+ } else {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int32>(c, begin_value, sizes_value, &dims));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ } else {
+ // In case `sizes` is not available (`sizes_value` is null),
+ // we could try to use `MakeShapeFromShapeTensor` here.
+ // If sizes contain -1, we will simply consider it as `Unknown`.
+ // This is less than ideal but still an improvement of shape inference.
+ // The following is an example that returns [None, 1, None] with this
+ // code path:
+ // z = tf.zeros((1, 2, 3))
+ // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
+ // m.get_shape().as_list()
+ ShapeHandle sizes_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
+ if (c->RankKnown(sizes_value)) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
+ std::vector<DimensionHandle> dims;
+ dims.reserve(c->Rank(sizes_value));
+ for (int i = 0; i < c->Rank(sizes_value); ++i) {
+ dims.emplace_back(c->Dim(sizes_value, i));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ }
+ // We might know the rank of the input.
+ if (c->RankKnown(input)) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
+ }
+
+ return Status::OK();
+}
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
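For reference, the behavior the new SliceShape function implements can be reproduced from Python: when the `sizes` tensor is a known constant, SliceHelper computes every output dimension exactly (a -1 size resolves to dim_size(i) - begin[i]); otherwise it falls back to MakeShapeFromShapeTensor and leaves -1 entries unknown. A minimal sketch, assuming TensorFlow 1.x graph mode where shapes are inferred symbolically at graph-construction time:

import tensorflow as tf

z = tf.zeros((1, 2, 3))

# `begin` and `size` are constants, so SliceHelper computes every output
# dimension exactly; a size of -1 resolves to dim_size(i) - begin[i].
a = tf.slice(z, [0, 1, 0], [1, -1, 2])
print(a.get_shape().as_list())  # [1, 1, 2]

# `size` is not a constant tensor here, so the shape function falls back to
# MakeShapeFromShapeTensor and the -1 entry stays unknown (this example is
# taken from the code comment above).
m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
print(m.get_shape().as_list())  # [None, 1, None]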
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index e6f9f935f9..3a496e06ae 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);
+// Shape function for Slice operations.
+Status SliceShape(shape_inference::InferenceContext* c);
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 37b88f1728..06d3fefef1 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2450,6 +2450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.tanh = "Tanh";
csinfo_.tanh_grad = "TanhGrad";
csinfo_.reshape = "Reshape";
+ csinfo_.slice = "Slice";
csinfo_.softmax = "Softmax";
csinfo_.split = "Split";
// Element-wise ops. Ensure you also add any new ops to IsOpElementWise
@@ -2557,6 +2558,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.slice,
+ mkl_op_registry::GetMklOpName(csinfo_.slice),
+ CopyAttrsSlice, AlwaysRewrite});
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsDataType, AlwaysRewrite});
@@ -2676,6 +2680,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string tanh;
string tanh_grad;
string reshape;
+ string slice;
string softmax;
string split;
string squared_difference;
@@ -3134,6 +3139,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
@@ -3739,6 +3745,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
nb->Attr("Tshape", Tshape);
}
+void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ DataType Index;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("Index", Index);
+}
+
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index f42a4ee98b..77640e287c 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Int32Input'}"
+ "node { name: 'D' op: 'Slice'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Index' value { type: DT_INT32 } }"
+ " input: ['A', 'B', 'C'] }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Int32Input);"
+ "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/"
+ "_2(Const);E(Zeta)|A->D;A->E;"
+ "A:control->DMT/_0:control;A:control->DMT/"
+ "_1:control;A:control->DMT/_2:control;"
+ "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
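On an MKL-enabled build, the layout pass exercised by the test above rewrites a qualifying Slice node into _MklSlice during graph optimization; user code is unchanged and the results should match a plain NumPy slice. A small hedged check, assuming a TensorFlow 1.x build with MKL support:

import numpy as np
import tensorflow as tf

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

with tf.Session() as sess:
    # On MKL builds the Slice node may execute as _MklSlice after the
    # rewrite; either way the result must equal the equivalent NumPy slice.
    y = sess.run(tf.slice(tf.constant(x), [0, 1, 0], [2, 2, 3]))

np.testing.assert_array_equal(y, x[0:2, 1:3, 0:3])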
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index b08562d7d1..0534b1829d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6416,6 +6416,12 @@ tf_mkl_kernel_library(
)
tf_mkl_kernel_library(
+ name = "mkl_slice_op",
+ prefix = "mkl_slice_op",
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
+tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
deps = ARRAY_DEPS + mkl_deps(),
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
new file mode 100644
index 0000000000..d63e14adf6
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/array_ops.cc.
+
+#ifdef INTEL_MKL
+#ifndef INTEL_MKL_ML_ONLY
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/util/mkl_util.h"
+
+using mkldnn::stream;
+using mkldnn::view;
+
+namespace tensorflow {
+
+namespace {
+
+gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
+ gtl::InlinedVector<int64, 4> out;
+ if (tensor.dtype() == DT_INT32) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int32>()(i));
+ }
+ } else if (tensor.dtype() == DT_INT64) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int64>()(i));
+ }
+ } else {
+ // tensor must be either int32 or int64
+ DCHECK(false);
+ }
+ return out;
+}
+
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// A version of SharedValidation (slice_op.h) written for input that is in
+// either Mkl layout or Tensorflow layout.
+// Shared code to validate input shapes and check for identity; it does not
+// depend on the type T, which reduces code size (no duplication per T).
+static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size) {
+ const int kInputTensorIndex = 0;
+ const int kInputBeginIndex = 1;
+ const int kInputSizeIndex = 2;
+ const Tensor& input = MklGetInput(context, kInputTensorIndex);
+ const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
+ const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);
+
+ MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
+ GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
+ GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
+ GetMklShape(context, kInputSizeIndex, &size_mkl_shape);
+
+ // Begin and size tensors cannot be in MklDnn layout.
+ DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
+ DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);
+
+ TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetTfShape()
+ : input.shape();
+ const int input_dims = input_tf_shape.dims();
+
+ OP_REQUIRES(
+ context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+ context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
+ begin_tensor.NumElements() == input_dims &&
+ size_tensor.NumElements() == input_dims,
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
+ " and ", size_tensor.shape().DebugString(), " instead."));
+
+ *begin = IntTensorToInt64Vec(begin_tensor);
+ *size = IntTensorToInt64Vec(size_tensor);
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input_tf_shape.dim_size(i) == 0) {
+ OP_REQUIRES(
+ context, b == 0 && s == 0,
+ errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s,
+ ") when ", "input.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_tf_shape.dim_size(i),
+ "], but got ", b));
+ OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_tf_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
+ (*is_identity) &= take_all;
+ }
+}
+
+// A version of SharedSliceCommonCases function written for input tensor
+// that may be in MklDnn layout or in Tensorflow layout.
+template <typename T>
+static void CheckCommonCasesForMklInputs(OpKernelContext* context,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ bool* done) {
+ bool is_identity = true;
+ *done = false;
+
+ ValidateMklInputs(context, &is_identity, begin, size);
+ if (!context->status().ok()) return;
+
+ const Tensor& input = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ // Mkl metadata tensor in this case can just be forwarded from input to
+ // output.
+ AllocateOutputSetMklShape(context, 0, input_mkl_shape);
+ *done = true;
+ }
+}
+
+// MKL-DNN implementation of Slice
+template <typename Device, typename T>
+class MklDnnSliceOp : public OpKernel {
+ public:
+ explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ ~MklDnnSliceOp() {}
+
+ void Compute(OpKernelContext* context) override {
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
+ bool done = false;
+
+ CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
+ if (!context->status().ok() || done == true) return;
+
+    // Although MKL-DNN supports tensors with more than 8 (but fewer than 12)
+    // dimensions, we mimic the Eigen Slice op for CPU and reject inputs with
+    // 8 or more dimensions.
+ if (begin.size() >= 8) {
+ OP_REQUIRES(
+ context, false,
+ errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions"));
+ }
+
+ ComputeMklDnnSlice(context, begin, size);
+ }
+
+ private:
+ // Slice op implemented using MKL-DNN APIs.
+ void ComputeMklDnnSlice(OpKernelContext* context,
+ const gtl::InlinedVector<int64, 4>& begin,
+ const gtl::InlinedVector<int64, 4>& size) {
+ try {
+ // MKL-DNN API usage below is guided by description at:
+ // https://github.com/01org/mkl-dnn/issues/69
+ //
+ // Relevant part of the description is copied below:
+ //
+ // Let's say you want to copy a part of memory into another buffer (and
+ // probably change the format). Then your steps are:
+ //
+ // 1. create memory primitive descriptor in_mem_pd and memory primitive
+ // in_mem_p for the entire source data.
+ // 2. create view primitive descriptor in_submem_pd based on in_mem_pd,
+ // initial offsets, and sub-sizes
+ // 3. create memory primitive descriptor out_mem_pd and memory primitive
+ // out_mem_p for the output (the logical sizes should match sub-sizes
+ // used in step 2, but the format might be arbitrary)
+ // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd
+ // and out_mem_pd
+ // 5. create reorder primitive itself based on reorder_pd, in_mem_p, and
+ // out_mem_p.
+ //
+ // Please notice that there is no view primitive. There is only view
+ // primitive descriptor. And the reorder uses source memory as input but
+ // traverses it according to a view in_submem_pd.
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Populate offsets and sizes in memory::dims format based on vector.
+ memory::dims begin_dims = {};
+ begin_dims.resize(begin.size());
+ for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
+ memory::dims size_dims = {};
+ bool empty = false;
+ size_dims.resize(size.size());
+ for (size_t i = 0; i < size.size(); ++i) {
+ size_dims[i] = size[i];
+ if (size_dims[i] == 0) empty = true;
+ }
+
+ Tensor* output_tensor = nullptr;
+ MklDnnShape output_mkl_shape;
+
+ // If no dimension is selected in slice, the result should be empty.
+ // Just return an empty output tensor, and a dummy Mkl-shape tensor.
+ if (empty) { // for empty dims
+ auto shape_to = MklDnnDimsToTFShape(size_dims);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+ output_mkl_shape);
+ return;
+ }
+
+ // Step 1 (as per above description) - Create memory for user data.
+ // We use blocked format here to describe input tensor.
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (input_mkl_shape.IsMklTensor()) {
+ auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
+ auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
+ begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
+ size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
+ auto input_md = input_mkl_shape.GetMklLayout();
+ src.SetUsrMem(input_md, &input_tensor);
+ } else {
+ // Initialize input dimensions and strides to be used when input is not
+ // in MklDnn layout.
+ memory::dims input_dims, input_strides;
+ input_dims = TFShapeToMklDnnDims(input_tensor.shape());
+ input_strides = CalculateTFStrides(input_dims);
+ // Create input memory descriptor.
+ auto input_md =
+ MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+ src.SetUsrMem(input_md, &input_tensor);
+ }
+
+ // Step 2 - create view primitive descriptor
+ auto view_pd =
+ view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims)
+ .dst_primitive_desc();
+ auto output_strides = CalculateTFStrides(size_dims);
+ auto output_md =
+ MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
+ auto output_pd = memory::primitive_desc(output_md, cpu_engine);
+
+ // Step 3 - Create memory for output. If input is in MklDnn layout, then
+ // output is also in MklDnn layout. Otherwise, output is in Tensorflow
+ // layout.
+ AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
+ &output_tensor, &output_mkl_shape);
+ DCHECK(output_tensor);
+ DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
+ output.SetUsrMem(output_md, output_tensor);
+
+ std::vector<primitive> net;
+ // Step 4 - create reorder primitive desc between view_pd and output_pd.
+ auto reorder_pd =
+ reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc());
+ // Step 5 - create reorder primitive itself.
+ net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem()));
+ // Execute the reorder primitive.
+ stream(stream::kind::eager).submit(net).wait();
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ void AllocateOutputTensor(OpKernelContext* context,
+ const MklDnnShape& input_mkl_shape,
+ memory::primitive_desc* output_pd,
+ const memory::dims& output_dims,
+ Tensor** output_tensor,
+ MklDnnShape* output_mkl_shape) {
+ DCHECK(output_tensor);
+ DCHECK(output_mkl_shape);
+
+ TensorShape output_tf_shape;
+
+ if (input_mkl_shape.IsMklTensor()) {
+ // Since input tensor is in Mkl layout, output tensor will be in Mkl
+ // layout.
+
+ // Allocate shape of Mkl tensor.
+ output_mkl_shape->SetMklTensor(true);
+ output_mkl_shape->SetMklLayout(output_pd);
+ output_mkl_shape->SetElemType(MklDnnType<T>());
+ output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
+ input_mkl_shape.GetTfDataFormat());
+
+ output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1);
+ } else {
+ // If input is not in Mkl layout, then output won't be in Mkl layout.
+ output_mkl_shape->SetMklTensor(false);
+ output_tf_shape = MklDnnDimsToTFShape(output_dims);
+ }
+
+ AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
+ *output_mkl_shape);
+ }
+};
+
+// MKL-DNN Slice registration
+#define REGISTER_MKL_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklSlice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDnnSliceOp<CPUDevice, type>);
+
+TF_CALL_float(REGISTER_MKL_SLICE);
+#undef REGISTER_MKL_SLICE
+
+} // namespace tensorflow
+
+#endif  // !INTEL_MKL_ML_ONLY
+#endif // INTEL_MKL
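The begin/size handling in ValidateMklInputs above matches the stock Slice semantics: a size of -1 expands to everything from begin[i] to the end of dimension i, begin and size are bounds-checked per dimension, and a slice that covers the whole input is treated as an identity and forwarded without a copy. A rough Python/NumPy sketch of that normalization, using a hypothetical helper name that is not part of TensorFlow:

import numpy as np

def normalize_slice_args(shape, begin, size):
    # Mirrors ValidateMklInputs: resolve -1 sizes, bounds-check begin/size,
    # and detect the identity case (the slice covers the entire input).
    begin, size = list(begin), list(size)
    for i, dim in enumerate(shape):
        if size[i] == -1:
            size[i] = dim - begin[i]  # -1 means "to the end of dimension i"
        assert 0 <= begin[i] <= dim, f"begin[{i}] out of range"
        assert 0 <= size[i] and begin[i] + size[i] <= dim, f"size[{i}] out of range"
    is_identity = all(b == 0 and s == d for b, s, d in zip(begin, size, shape))
    return begin, size, is_identity

x_shape = (2, 3, 4)
print(normalize_slice_args(x_shape, [0, 1, 0], [-1, 2, 3]))
# -> ([0, 1, 0], [2, 2, 3], False)
print(normalize_slice_args(x_shape, [0, 0, 0], [-1, -1, -1]))
# -> ([0, 0, 0], [2, 3, 4], True)   identity: the kernel just forwards the input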
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 442686c92a..c9f80df5e4 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1531,37 +1531,6 @@ REGISTER_OP("Size")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ScalarShape);
-namespace {
-
-// This SliceHelper processes the output shape of the `slice`
-// when the tensor of `sizes` is available.
-template <typename T>
-Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
- const Tensor* sizes_value,
- std::vector<DimensionHandle>* dims) {
- auto sizes_vec = sizes_value->vec<T>();
- for (int i = 0; i < sizes_value->NumElements(); ++i) {
- DimensionHandle dim = c->Dim(c->input(0), i);
- if (sizes_vec(i) != -1) {
- auto dim_val = c->Value(dim);
- if (sizes_vec(i) < 0) {
- return errors::InvalidArgument(
- "Out of bounds slicing on dimension ", i, " of length ", dim_val,
- ": sizes vector cannot be < -1, but was ", sizes_vec(i));
- }
-
- dims->emplace_back(c->MakeDim(sizes_vec(i)));
- } else {
- DimensionHandle result;
- TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
- dims->emplace_back(result);
- }
- }
-
- return Status::OK();
-}
-} // namespace
-
// --------------------------------------------------------------------------
REGISTER_OP("Slice")
.Input("input: T")
@@ -1570,83 +1539,22 @@ REGISTER_OP("Slice")
.Output("output: T")
.Attr("T: type")
.Attr("Index: {int32,int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle input = c->input(0);
- ShapeHandle begin_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
- ShapeHandle sizes_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
-
- // Merge to check compatibility of begin and sizes tensors.
- TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
-
- DimensionHandle ndims = c->Dim(begin_shape, 0);
- if (c->ValueKnown(ndims)) {
- TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
- }
-
- // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
- // values, even though the `begin` value does not represent a shape.
- ShapeHandle begin_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
-
- // We check the tensor value here and will only use
- // `MakeShapeFromShapeTensor` when `sizes_value` is null.
- // The reason is that `sizes`might contain -1, which can't
- // be represented (-1 in the ShapeHandle would mean "unknown".
- const Tensor* sizes_value = c->input_tensor(2);
-
- if (sizes_value != nullptr) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
- std::vector<DimensionHandle> dims;
- // If the begin and sizes tensors are available, then
- // we can be precise about the shape of the output.
- if (sizes_value->dtype() == DT_INT64) {
- TF_RETURN_IF_ERROR(
- SliceHelper<int64>(c, begin_value, sizes_value, &dims));
- } else {
- TF_RETURN_IF_ERROR(
- SliceHelper<int32>(c, begin_value, sizes_value, &dims));
- }
-
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- } else {
- // In case `sizes` is not available (`sizes_value` is null),
- // we could try to use `MakeShapeFromShapeTensor` here.
- // If sizes contain -1, we will simply consider it as `Unknown`.
- // This is less than ideal but still an improvement of shape inference.
- // The following is an example that returns [None, 1, None] with this
- // code path:
- // z = tf.zeros((1, 2, 3))
- // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
- // m.get_shape().as_list()
- ShapeHandle sizes_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
- if (c->RankKnown(sizes_value)) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
- std::vector<DimensionHandle> dims;
- dims.reserve(c->Rank(sizes_value));
- for (int i = 0; i < c->Rank(sizes_value); ++i) {
- dims.emplace_back(c->Dim(sizes_value, i));
- }
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- }
-
- // We might know the rank of the input.
- if (c->RankKnown(input)) {
- c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
- return Status::OK();
- } else {
- return shape_inference::UnknownShape(c);
- }
- }
+ .SetShapeFn(shape_inference::SliceShape);
- return Status::OK();
- });
+#ifdef INTEL_MKL
+REGISTER_OP("_MklSlice")
+ .Input("input: T")
+ .Input("begin: Index")
+ .Input("size: Index")
+ .Input("mkl_input: uint8")
+ .Input("mkl_begin: uint8")
+ .Input("mkl_size: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: type")
+ .Attr("Index: {int32,int64}")
+ .SetShapeFn(shape_inference::SliceShape);
+#endif  // INTEL_MKL
REGISTER_OP("StridedSlice")
.Input("input: T")