diff options
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 56 |
1 files changed, 46 insertions, 10 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index b30a90027c..2cbbf966b8 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, return MakeShapeFromPartialTensorShape(partial_shape, out); } -// Returns a new dimension whose value is given by a scalar input tensor. -Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { - const Tensor* t = input_tensor(idx); - if (t == nullptr) { - *out = UnknownDim(); - return Status::OK(); - } +Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) { + // Caller must ensure that <t> is not NULL. const int rank = t->dims(); if (rank != 0) { return errors::InvalidArgument("Input must be scalar but has rank ", rank); } - int64 val; if (t->dtype() == DT_INT32) { - val = t->scalar<int32>()(); + *val = t->scalar<int32>()(); + return Status::OK(); } else if (t->dtype() == DT_INT64) { - val = t->scalar<int64>()(); + *val = t->scalar<int64>()(); + return Status::OK(); } else { return errors::InvalidArgument( "Scalar input for dim size must be int32 or int64"); } +} + +// Returns a new dimension whose value is given by a scalar input tensor. +Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); if (val < 0) { return errors::InvalidArgument("Dimension size, given by scalar input ", idx, ", must be non-negative but is ", val); @@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { return Status::OK(); } +Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing( + int idx, int input_rank, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); + if (val < 0) { + if (input_rank < 0) { + *out = UnknownDim(); + return Status::OK(); + } else if (val + input_rank < 0) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } else { + val += input_rank; + } + } else if (input_rank >= 0 && val >= input_rank) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } + *out = MakeDim(val); + return Status::OK(); +} + Status InferenceContext::Divide(DimensionHandle dividend, DimensionOrConstant divisor, bool evenly_divisible, DimensionHandle* out) { |