aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-31 13:39:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-31 13:43:08 -0700
commit0b8070253d6c62ad395a42c3f496c3f21ae5d975 (patch)
tree991360e089b2a102645a53e4d7aa3f04c4535fba /tensorflow/core/framework/shape_inference.cc
parentbc236cfc3bb5496607a030ff2ae456a8449afb7f (diff)
Support negative axis for Split op
PiperOrigin-RevId: 157628162
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc56
1 files changed, 46 insertions, 10 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index b30a90027c..2cbbf966b8 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return MakeShapeFromPartialTensorShape(partial_shape, out);
}
-// Returns a new dimension whose value is given by a scalar input tensor.
-Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
- const Tensor* t = input_tensor(idx);
- if (t == nullptr) {
- *out = UnknownDim();
- return Status::OK();
- }
+Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
+ // Caller must ensure that <t> is not NULL.
const int rank = t->dims();
if (rank != 0) {
return errors::InvalidArgument("Input must be scalar but has rank ", rank);
}
- int64 val;
if (t->dtype() == DT_INT32) {
- val = t->scalar<int32>()();
+ *val = t->scalar<int32>()();
+ return Status::OK();
} else if (t->dtype() == DT_INT64) {
- val = t->scalar<int64>()();
+ *val = t->scalar<int64>()();
+ return Status::OK();
} else {
return errors::InvalidArgument(
"Scalar input for dim size must be int32 or int64");
}
+}
+
+// Returns a new dimension whose value is given by a scalar input tensor.
+Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
+ int64 val;
+ const Tensor* t = input_tensor(idx);
+ if (t == nullptr) {
+ *out = UnknownDim();
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
if (val < 0) {
return errors::InvalidArgument("Dimension size, given by scalar input ",
idx, ", must be non-negative but is ", val);
@@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
return Status::OK();
}
+Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
+ int idx, int input_rank, DimensionHandle* out) {
+ int64 val;
+ const Tensor* t = input_tensor(idx);
+ if (t == nullptr) {
+ *out = UnknownDim();
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
+ if (val < 0) {
+ if (input_rank < 0) {
+ *out = UnknownDim();
+ return Status::OK();
+ } else if (val + input_rank < 0) {
+ return errors::InvalidArgument("Dimension size, given by scalar input ",
+ val, " must be in range [-", input_rank,
+ ", ", input_rank, ")");
+ } else {
+ val += input_rank;
+ }
+ } else if (input_rank >= 0 && val >= input_rank) {
+ return errors::InvalidArgument("Dimension size, given by scalar input ",
+ val, " must be in range [-", input_rank,
+ ", ", input_rank, ")");
+ }
+ *out = MakeDim(val);
+ return Status::OK();
+}
+
Status InferenceContext::Divide(DimensionHandle dividend,
DimensionOrConstant divisor,
bool evenly_divisible, DimensionHandle* out) {