diff options
author | 2017-05-31 13:39:29 -0700 | |
---|---|---|
committer | 2017-05-31 13:43:08 -0700 | |
commit | 0b8070253d6c62ad395a42c3f496c3f21ae5d975 (patch) | |
tree | 991360e089b2a102645a53e4d7aa3f04c4535fba /tensorflow/core/framework/shape_inference.cc | |
parent | bc236cfc3bb5496607a030ff2ae456a8449afb7f (diff) |
Support negative axis for Split op
PiperOrigin-RevId: 157628162
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 56 |
1 files changed, 46 insertions, 10 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index b30a90027c..2cbbf966b8 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, return MakeShapeFromPartialTensorShape(partial_shape, out); } -// Returns a new dimension whose value is given by a scalar input tensor. -Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { - const Tensor* t = input_tensor(idx); - if (t == nullptr) { - *out = UnknownDim(); - return Status::OK(); - } +Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) { + // Caller must ensure that <t> is not NULL. const int rank = t->dims(); if (rank != 0) { return errors::InvalidArgument("Input must be scalar but has rank ", rank); } - int64 val; if (t->dtype() == DT_INT32) { - val = t->scalar<int32>()(); + *val = t->scalar<int32>()(); + return Status::OK(); } else if (t->dtype() == DT_INT64) { - val = t->scalar<int64>()(); + *val = t->scalar<int64>()(); + return Status::OK(); } else { return errors::InvalidArgument( "Scalar input for dim size must be int32 or int64"); } +} + +// Returns a new dimension whose value is given by a scalar input tensor. +Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); if (val < 0) { return errors::InvalidArgument("Dimension size, given by scalar input ", idx, ", must be non-negative but is ", val); @@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { return Status::OK(); } +Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing( + int idx, int input_rank, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); + if (val < 0) { + if (input_rank < 0) { + *out = UnknownDim(); + return Status::OK(); + } else if (val + input_rank < 0) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } else { + val += input_rank; + } + } else if (input_rank >= 0 && val >= input_rank) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } + *out = MakeDim(val); + return Status::OK(); +} + Status InferenceContext::Divide(DimensionHandle dividend, DimensionOrConstant divisor, bool evenly_divisible, DimensionHandle* out) { |