diff options
author | 2018-09-19 13:07:07 +0800 | |
---|---|---|
committer | 2018-09-19 13:07:07 +0800 | |
commit | d7a8e852941e8cd856caafddf7c93d857e83e8b1 (patch) | |
tree | fce349d606189cc76f7c12d25c855c33e8727ca9 /tensorflow/core/ops/array_ops.cc | |
parent | 1f7e51560e26992e8e56f6426525c1df1e53b974 (diff) |
Move location of Slice shape function.
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 182 |
1 files changed, 2 insertions, 180 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 2dec430710..325690eded 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1531,37 +1531,6 @@ REGISTER_OP("Size") .Attr("out_type: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ScalarShape); -namespace { - -// This SliceHelper processes the output shape of the `slice` -// when the tensor of `sizes` is available. -template <typename T> -Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, - const Tensor* sizes_value, - std::vector<DimensionHandle>* dims) { - auto sizes_vec = sizes_value->vec<T>(); - for (int i = 0; i < sizes_value->NumElements(); ++i) { - DimensionHandle dim = c->Dim(c->input(0), i); - if (sizes_vec(i) != -1) { - auto dim_val = c->Value(dim); - if (sizes_vec(i) < 0) { - return errors::InvalidArgument( - "Out of bounds slicing on dimension ", i, " of length ", dim_val, - ": sizes vector cannot be < -1, but was ", sizes_vec(i)); - } - - dims->emplace_back(c->MakeDim(sizes_vec(i))); - } else { - DimensionHandle result; - TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); - dims->emplace_back(result); - } - } - - return Status::OK(); -} -} // namespace - // -------------------------------------------------------------------------- REGISTER_OP("Slice") .Input("input: T") @@ -1571,81 +1540,7 @@ REGISTER_OP("Slice") .Attr("T: type") .Attr("Index: {int32,int64}") .SetShapeFn([](InferenceContext* c) { - ShapeHandle input = c->input(0); - ShapeHandle begin_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); - ShapeHandle sizes_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); - - // Merge to check compatibility of begin and sizes tensors. - TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); - - DimensionHandle ndims = c->Dim(begin_shape, 0); - if (c->ValueKnown(ndims)) { - TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); - } - - // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known - // values, even though the `begin` value does not represent a shape. - ShapeHandle begin_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); - - // We check the tensor value here and will only use - // `MakeShapeFromShapeTensor` when `sizes_value` is null. - // The reason is that `sizes`might contain -1, which can't - // be represented (-1 in the ShapeHandle would mean "unknown". - const Tensor* sizes_value = c->input_tensor(2); - - if (sizes_value != nullptr) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); - std::vector<DimensionHandle> dims; - // If the begin and sizes tensors are available, then - // we can be precise about the shape of the output. - if (sizes_value->dtype() == DT_INT64) { - TF_RETURN_IF_ERROR( - SliceHelper<int64>(c, begin_value, sizes_value, &dims)); - } else { - TF_RETURN_IF_ERROR( - SliceHelper<int32>(c, begin_value, sizes_value, &dims)); - } - - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } else { - // In case `sizes` is not available (`sizes_value` is null), - // we could try to use `MakeShapeFromShapeTensor` here. - // If sizes contain -1, we will simply consider it as `Unknown`. - // This is less than ideal but still an improvement of shape inference. - // The following is an example that returns [None, 1, None] with this - // code path: - // z = tf.zeros((1, 2, 3)) - // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) - // m.get_shape().as_list() - ShapeHandle sizes_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); - if (c->RankKnown(sizes_value)) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); - std::vector<DimensionHandle> dims; - dims.reserve(c->Rank(sizes_value)); - for (int i = 0; i < c->Rank(sizes_value); ++i) { - dims.emplace_back(c->Dim(sizes_value, i)); - } - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } - - // We might know the rank of the input. - if (c->RankKnown(input)) { - c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return Status::OK(); - } else { - return shape_inference::UnknownShape(c); - } - } - - return Status::OK(); + return shape_inference::SliceShape(c); }); #ifdef INTEL_MKL @@ -1661,80 +1556,7 @@ REGISTER_OP("_MklSlice") .Attr("T: type") .Attr("Index: {int32,int64}") .SetShapeFn([](InferenceContext* c) { - ShapeHandle input = c->input(0); - ShapeHandle begin_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); - ShapeHandle sizes_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); - - // Merge to check compatibility of begin and sizes tensors. - TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); - - DimensionHandle ndims = c->Dim(begin_shape, 0); - if (c->ValueKnown(ndims)) { - TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); - } - - // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known - // values, even though the `begin` value does not represent a shape. - ShapeHandle begin_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); - - // NOTE(mrry): We can't use `MakeShapeFromShapeTensor` for `sizes` because - // it might contain -1, which can't be represented. (-1 in the ShapeHandle - // would mean "unknown".) - const Tensor* sizes_value = c->input_tensor(3); - - if (sizes_value != nullptr) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); - std::vector<DimensionHandle> dims; - // If the begin and sizes tensors are available, then - // we can be precise about the shape of the output. - if (sizes_value->dtype() == DT_INT64) { - TF_RETURN_IF_ERROR( - SliceHelper<int64>(c, begin_value, sizes_value, &dims)); - } else { - TF_RETURN_IF_ERROR( - SliceHelper<int32>(c, begin_value, sizes_value, &dims)); - } - - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } else { - // In case `sizes` is not available (`sizes_value` is null), - // we could try to use `MakeShapeFromShapeTensor` here. - // If sizes contain -1, we will simply consider it as `Unknown`. - // This is less than ideal but still an improvement of shape inference. - // The following is an example that returns [None, 1, None] with this - // code path: - // z = tf.zeros((1, 2, 3)) - // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) - // m.get_shape().as_list() - ShapeHandle sizes_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); - if (c->RankKnown(sizes_value)) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); - std::vector<DimensionHandle> dims; - dims.reserve(c->Rank(sizes_value)); - for (int i = 0; i < c->Rank(sizes_value); ++i) { - dims.emplace_back(c->Dim(sizes_value, i)); - } - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } - - // We might know the rank of the input. - if (c->RankKnown(input)) { - c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return Status::OK(); - } else { - return shape_inference::UnknownShape(c); - } - } - - return Status::OK(); + return shape_inference::SliceShape(c); }); #endif |