diff options
author | Pan Daoxin <daoxin.pan@intel.com> | 2018-09-19 13:07:07 +0800 |
---|---|---|
committer | Pan Daoxin <daoxin.pan@intel.com> | 2018-09-19 13:07:07 +0800 |
commit | d7a8e852941e8cd856caafddf7c93d857e83e8b1 (patch) | |
tree | fce349d606189cc76f7c12d25c855c33e8727ca9 /tensorflow/core/framework | |
parent | 1f7e51560e26992e8e56f6426525c1df1e53b974 (diff) |
Move location of Slice shape function.
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 104 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 3 |
2 files changed, 107 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 20a07d86a2..20922d7884 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1306,6 +1306,110 @@ Status RandomShape(shape_inference::InferenceContext* c) { return Status::OK(); } +// This SliceHelper processes the output shape of the `slice` +// when the tensor of `sizes` is available. +template <typename T> +Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, + const Tensor* sizes_value, + std::vector<DimensionHandle>* dims) { + auto sizes_vec = sizes_value->vec<T>(); + for (int i = 0; i < sizes_value->NumElements(); ++i) { + DimensionHandle dim = c->Dim(c->input(0), i); + if (sizes_vec(i) != -1) { + auto dim_val = c->Value(dim); + if (sizes_vec(i) < 0) { + return errors::InvalidArgument( + "Out of bounds slicing on dimension ", i, " of length ", dim_val, + ": sizes vector cannot be < -1, but was ", sizes_vec(i)); + } + + dims->emplace_back(c->MakeDim(sizes_vec(i))); + } else { + DimensionHandle result; + TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); + dims->emplace_back(result); + } + } + + return Status::OK(); +} + +Status SliceShape(InferenceContext* c) { + ShapeHandle input = c->input(0); + ShapeHandle begin_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); + ShapeHandle sizes_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); + + // Merge to check compatibility of begin and sizes tensors. + TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); + + DimensionHandle ndims = c->Dim(begin_shape, 0); + if (c->ValueKnown(ndims)) { + TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); + } + + // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known + // values, even though the `begin` value does not represent a shape. + ShapeHandle begin_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); + + // We check the tensor value here and will only use + // `MakeShapeFromShapeTensor` when `sizes_value` is null. + // The reason is that `sizes`might contain -1, which can't + // be represented (-1 in the ShapeHandle would mean "unknown"). + const Tensor* sizes_value = c->input_tensor(2); + + if (sizes_value != nullptr) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); + std::vector<DimensionHandle> dims; + // If the begin and sizes tensors are available, then + // we can be precise about the shape of the output. + if (sizes_value->dtype() == DT_INT64) { + TF_RETURN_IF_ERROR( + SliceHelper<int64>(c, begin_value, sizes_value, &dims)); + } else { + TF_RETURN_IF_ERROR( + SliceHelper<int32>(c, begin_value, sizes_value, &dims)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } else { + // In case `sizes` is not available (`sizes_value` is null), + // we could try to use `MakeShapeFromShapeTensor` here. + // If sizes contain -1, we will simply consider it as `Unknown`. + // This is less than ideal but still an improvement of shape inference. + // The following is an example that returns [None, 1, None] with this + // code path: + // z = tf.zeros((1, 2, 3)) + // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) + // m.get_shape().as_list() + ShapeHandle sizes_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); + if (c->RankKnown(sizes_value)) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); + std::vector<DimensionHandle> dims; + dims.reserve(c->Rank(sizes_value)); + for (int i = 0; i < c->Rank(sizes_value); ++i) { + dims.emplace_back(c->Dim(sizes_value, i)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } + // We might know the rank of the input. + if (c->RankKnown(input)) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); + return Status::OK(); + } else { + return shape_inference::UnknownShape(c); + } + } + + return Status::OK(); +} + Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, ShapeHandle values_shape, ShapeHandle shape_shape) { // Validate ranks. diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index e6f9f935f9..478f796516 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { // Shape function for random operations. Status RandomShape(shape_inference::InferenceContext* c); +// Shape function for Slice operator. +Status SliceShape(shape_inference::InferenceContext* c); + // Validates the 3 component tensors of a sparse tensor have the proper // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, |