author    Pan Daoxin <daoxin.pan@intel.com>  2018-09-19 13:07:07 +0800
committer Pan Daoxin <daoxin.pan@intel.com>  2018-09-19 13:07:07 +0800
commit    d7a8e852941e8cd856caafddf7c93d857e83e8b1 (patch)
tree      fce349d606189cc76f7c12d25c855c33e8727ca9 /tensorflow/core/framework
parent    1f7e51560e26992e8e56f6426525c1df1e53b974 (diff)
Move location of Slice shape function.
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc  | 104
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h   |   3
2 files changed, 107 insertions, 0 deletions
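
Note (not part of the commit): with this change the function is exported as shape_inference::SliceShape, so an op registration can point its shape function at it directly. The sketch below is illustrative only; the real Slice registration lives in tensorflow/core/ops/array_ops.cc and is not part of this diff, and the op name here is hypothetical.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Illustrative only: a Slice-like op wired to the relocated shape function.
REGISTER_OP("ExampleSliceLikeOp")  // hypothetical name, not the real Slice op
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .SetShapeFn(shape_inference::SliceShape);

}  // namespace tensorflow
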
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 20a07d86a2..20922d7884 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1306,6 +1306,110 @@ Status RandomShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+// This SliceHelper computes the output shape of `Slice`
+// when the `sizes` tensor is available.
+template <typename T>
+Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
+ const Tensor* sizes_value,
+ std::vector<DimensionHandle>* dims) {
+ auto sizes_vec = sizes_value->vec<T>();
+ for (int i = 0; i < sizes_value->NumElements(); ++i) {
+ DimensionHandle dim = c->Dim(c->input(0), i);
+ if (sizes_vec(i) != -1) {
+ auto dim_val = c->Value(dim);
+ if (sizes_vec(i) < 0) {
+ return errors::InvalidArgument(
+ "Out of bounds slicing on dimension ", i, " of length ", dim_val,
+ ": sizes vector cannot be < -1, but was ", sizes_vec(i));
+ }
+
+ dims->emplace_back(c->MakeDim(sizes_vec(i)));
+ } else {
+ DimensionHandle result;
+ TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
+ dims->emplace_back(result);
+ }
+ }
+
+ return Status::OK();
+}
+
+Status SliceShape(InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ ShapeHandle begin_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
+ ShapeHandle sizes_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
+
+ // Merge to check compatibility of begin and sizes tensors.
+ TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
+
+ DimensionHandle ndims = c->Dim(begin_shape, 0);
+ if (c->ValueKnown(ndims)) {
+ TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
+ }
+
+ // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
+ // values, even though the `begin` value does not represent a shape.
+ ShapeHandle begin_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
+
+ // We check the tensor value here and will only use
+ // `MakeShapeFromShapeTensor` when `sizes_value` is null.
+ // The reason is that `sizes` might contain -1, which can't
+ // be represented (-1 in the ShapeHandle would mean "unknown").
+ const Tensor* sizes_value = c->input_tensor(2);
+
+ if (sizes_value != nullptr) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
+ std::vector<DimensionHandle> dims;
+ // If the begin and sizes tensors are available, then
+ // we can be precise about the shape of the output.
+ if (sizes_value->dtype() == DT_INT64) {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int64>(c, begin_value, sizes_value, &dims));
+ } else {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int32>(c, begin_value, sizes_value, &dims));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ } else {
+ // In case `sizes` is not available (`sizes_value` is null),
+ // we could try to use `MakeShapeFromShapeTensor` here.
+ // If `sizes` contains -1, we simply treat that dimension as `Unknown`.
+ // This is less than ideal but still an improvement to shape inference.
+ // The following is an example that returns [None, 1, None] with this
+ // code path:
+ // z = tf.zeros((1, 2, 3))
+ // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
+ // m.get_shape().as_list()
+ ShapeHandle sizes_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
+ if (c->RankKnown(sizes_value)) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
+ std::vector<DimensionHandle> dims;
+ dims.reserve(c->Rank(sizes_value));
+ for (int i = 0; i < c->Rank(sizes_value); ++i) {
+ dims.emplace_back(c->Dim(sizes_value, i));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ }
+ // We might know the rank of the input.
+ if (c->RankKnown(input)) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
+ }
+
+ return Status::OK();
+}
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index e6f9f935f9..478f796516 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);
+// Shape function for the Slice operator.
+Status SliceShape(shape_inference::InferenceContext* c);
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
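
To make the SliceHelper arithmetic above concrete, here is a minimal standalone sketch (plain C++, independent of the TensorFlow API), assuming the input dims, begin, and sizes are all fully known: a size of -1 means "everything from begin to the end of that dimension", i.e. dim - begin; any other non-negative size is used as-is.

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone illustration of the size rule used by SliceHelper: -1 expands to
// dim - begin, otherwise the requested size is used directly.
std::vector<int64_t> SliceOutputDims(const std::vector<int64_t>& input_dims,
                                     const std::vector<int64_t>& begin,
                                     const std::vector<int64_t>& sizes) {
  std::vector<int64_t> out;
  out.reserve(input_dims.size());
  for (size_t i = 0; i < input_dims.size(); ++i) {
    out.push_back(sizes[i] == -1 ? input_dims[i] - begin[i] : sizes[i]);
  }
  return out;
}

int main() {
  // Slice a [2, 3, 5] tensor starting at [0, 1, 2] with sizes [-1, 2, -1].
  for (int64_t d : SliceOutputDims({2, 3, 5}, {0, 1, 2}, {-1, 2, -1})) {
    std::cout << d << " ";  // prints: 2 2 3
  }
  std::cout << "\n";
  return 0;
}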