diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 23:14:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 23:14:39 -0700 |
commit | 08a6cfed1cf0cccc8ff35448266f44fbc55be0bc (patch) | |
tree | 73f61074984cd9dcf05e5d65b454a6ce08484f4a /tensorflow/core/framework | |
parent | d3f14ef70cdf113f9d330c1f7c638003429a1dc4 (diff) | |
parent | d1ab8b71c2115caacfec19d849ddabf7f1f4287b (diff) |
Merge pull request #22076 from Intel-tensorflow:feature/daoxin/slice
PiperOrigin-RevId: 214726180
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 107 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 3 |
2 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 20a07d86a2..50403b4004 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) { return Status::OK(); } +namespace { + +// This SliceHelper processes the output shape of the `slice` +// when the tensor of `sizes` is available. +template <typename T> +Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, + const Tensor* sizes_value, + std::vector<DimensionHandle>* dims) { + auto sizes_vec = sizes_value->vec<T>(); + for (int i = 0; i < sizes_value->NumElements(); ++i) { + DimensionHandle dim = c->Dim(c->input(0), i); + if (sizes_vec(i) != -1) { + auto dim_val = c->Value(dim); + if (sizes_vec(i) < 0) { + return errors::InvalidArgument( + "Out of bounds slicing on dimension ", i, " of length ", dim_val, + ": sizes vector cannot be < -1, but was ", sizes_vec(i)); + } + + dims->emplace_back(c->MakeDim(sizes_vec(i))); + } else { + DimensionHandle result; + TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); + dims->emplace_back(result); + } + } + + return Status::OK(); +} +} // namespace + +Status SliceShape(InferenceContext* c) { + ShapeHandle input = c->input(0); + ShapeHandle begin_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); + ShapeHandle sizes_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); + + // Merge to check compatibility of begin and sizes tensors. + TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); + + DimensionHandle ndims = c->Dim(begin_shape, 0); + if (c->ValueKnown(ndims)) { + TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); + } + + // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known + // values, even though the `begin` value does not represent a shape. + ShapeHandle begin_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); + + // We check the tensor value here and will only use + // `MakeShapeFromShapeTensor` when `sizes_value` is null. + // The reason is that `sizes` might contain -1, which can't + // be represented (-1 in the ShapeHandle would mean "unknown"). + const Tensor* sizes_value = c->input_tensor(2); + + if (sizes_value != nullptr) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); + std::vector<DimensionHandle> dims; + // If the begin and sizes tensors are available, then + // we can be precise about the shape of the output. + if (sizes_value->dtype() == DT_INT64) { + TF_RETURN_IF_ERROR( + SliceHelper<int64>(c, begin_value, sizes_value, &dims)); + } else { + TF_RETURN_IF_ERROR( + SliceHelper<int32>(c, begin_value, sizes_value, &dims)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } else { + // In case `sizes` is not available (`sizes_value` is null), + // we could try to use `MakeShapeFromShapeTensor` here. + // If sizes contain -1, we will simply consider it as `Unknown`. + // This is less than ideal but still an improvement of shape inference. + // The following is an example that returns [None, 1, None] with this + // code path: + // z = tf.zeros((1, 2, 3)) + // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) + // m.get_shape().as_list() + ShapeHandle sizes_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); + if (c->RankKnown(sizes_value)) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); + std::vector<DimensionHandle> dims; + dims.reserve(c->Rank(sizes_value)); + for (int i = 0; i < c->Rank(sizes_value); ++i) { + dims.emplace_back(c->Dim(sizes_value, i)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } + // We might know the rank of the input. + if (c->RankKnown(input)) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); + return Status::OK(); + } else { + return shape_inference::UnknownShape(c); + } + } + + return Status::OK(); +} + Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, ShapeHandle values_shape, ShapeHandle shape_shape) { // Validate ranks. diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index e6f9f935f9..3a496e06ae 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { // Shape function for random operations. Status RandomShape(shape_inference::InferenceContext* c); +// Shape function for Slice opertaions. +Status SliceShape(shape_inference::InferenceContext* c); + // Validates the 3 component tensors of a sparse tensor have the proper // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, |