author     A. Unique TensorFlower <gardener@tensorflow.org>   2017-05-31 13:39:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-05-31 13:43:08 -0700
commit     0b8070253d6c62ad395a42c3f496c3f21ae5d975 (patch)
tree       991360e089b2a102645a53e4d7aa3f04c4535fba
parent     bc236cfc3bb5496607a030ff2ae456a8449afb7f (diff)
Support negative axis for Split op
PiperOrigin-RevId: 157628162
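
Editorial note, for illustration only (not part of the patch): after this change, the axis (split_dim) argument of tf.split may be negative and is counted from the last dimension. A minimal sketch, assuming the TF 1.x graph/session API of this era:

    # Illustrative sketch: negative axis for tf.split (assumes TF 1.x session API).
    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.arange(12).reshape(3, 4), dtype=tf.float32)

    # axis=-1 on a rank-2 tensor is interpreted as axis=1 (the last dimension).
    a, b = tf.split(x, num_or_size_splits=2, axis=-1)

    with tf.Session() as sess:
        a_val, b_val = sess.run([a, b])
        print(a_val.shape)  # (3, 2)
        print(b_val.shape)  # (3, 2)
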
-rw-r--r--  tensorflow/core/framework/shape_inference.cc     | 56
-rw-r--r--  tensorflow/core/framework/shape_inference.h      | 13
-rw-r--r--  tensorflow/core/kernels/split_op.cc              | 28
-rw-r--r--  tensorflow/core/kernels/split_v_op.cc            | 17
-rw-r--r--  tensorflow/core/ops/array_ops.cc                 | 14
-rw-r--r--  tensorflow/core/ops/array_ops_test.cc            | 23
-rw-r--r--  tensorflow/python/kernel_tests/split_op_test.py  | 42
-rw-r--r--  tensorflow/python/ops/array_ops.py               |  2
8 files changed, 145 insertions, 50 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index b30a90027c..2cbbf966b8 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return MakeShapeFromPartialTensorShape(partial_shape, out);
}
-// Returns a new dimension whose value is given by a scalar input tensor.
-Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
- const Tensor* t = input_tensor(idx);
- if (t == nullptr) {
- *out = UnknownDim();
- return Status::OK();
- }
+Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
+ // Caller must ensure that <t> is not NULL.
const int rank = t->dims();
if (rank != 0) {
return errors::InvalidArgument("Input must be scalar but has rank ", rank);
}
- int64 val;
if (t->dtype() == DT_INT32) {
- val = t->scalar<int32>()();
+ *val = t->scalar<int32>()();
+ return Status::OK();
} else if (t->dtype() == DT_INT64) {
- val = t->scalar<int64>()();
+ *val = t->scalar<int64>()();
+ return Status::OK();
} else {
return errors::InvalidArgument(
"Scalar input for dim size must be int32 or int64");
}
+}
+
+// Returns a new dimension whose value is given by a scalar input tensor.
+Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
+ int64 val;
+ const Tensor* t = input_tensor(idx);
+ if (t == nullptr) {
+ *out = UnknownDim();
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
if (val < 0) {
return errors::InvalidArgument("Dimension size, given by scalar input ",
idx, ", must be non-negative but is ", val);
@@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
return Status::OK();
}
+Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
+ int idx, int input_rank, DimensionHandle* out) {
+ int64 val;
+ const Tensor* t = input_tensor(idx);
+ if (t == nullptr) {
+ *out = UnknownDim();
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
+ if (val < 0) {
+ if (input_rank < 0) {
+ *out = UnknownDim();
+ return Status::OK();
+ } else if (val + input_rank < 0) {
+ return errors::InvalidArgument("Dimension size, given by scalar input ",
+ val, " must be in range [-", input_rank,
+ ", ", input_rank, ")");
+ } else {
+ val += input_rank;
+ }
+ } else if (input_rank >= 0 && val >= input_rank) {
+ return errors::InvalidArgument("Dimension size, given by scalar input ",
+ val, " must be in range [-", input_rank,
+ ", ", input_rank, ")");
+ }
+ *out = MakeDim(val);
+ return Status::OK();
+}
+
Status InferenceContext::Divide(DimensionHandle dividend,
DimensionOrConstant divisor,
bool evenly_divisible, DimensionHandle* out) {
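
Editorial aside (not part of the patch): the resolution rule implemented by the new MakeDimForScalarInputWithNegativeIndexing can be summarized with a small Python sketch. The function name and the use of None for an unknown rank are illustrative only; the C++ code signals an unknown rank with a negative input_rank and returns UnknownDim() in that case.

    # Illustrative Python sketch of the validation in
    # MakeDimForScalarInputWithNegativeIndexing (not the C++ implementation).
    def resolve_split_dim(val, input_rank):
        """Returns the resolved dimension index, or None if it cannot be known yet.

        val: the scalar split_dim value read from the input tensor.
        input_rank: rank of the tensor being split, or None if unknown.
        """
        if val < 0:
            if input_rank is None:
                return None  # Unknown rank: the dimension stays unknown for now.
            if val + input_rank < 0:
                raise ValueError(
                    "Dimension size, given by scalar input %d must be in range "
                    "[-%d, %d)" % (val, input_rank, input_rank))
            return val + input_rank
        if input_rank is not None and val >= input_rank:
            raise ValueError(
                "Dimension size, given by scalar input %d must be in range "
                "[-%d, %d)" % (val, input_rank, input_rank))
        return val

    assert resolve_split_dim(-1, 3) == 2
    assert resolve_split_dim(-2, 3) == 1
    assert resolve_split_dim(-1, None) is None
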
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 99bbed64b1..baeab93e30 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -401,11 +401,24 @@ class InferenceContext {
inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
+ // Returns in <val> a scalar value from an input tensor <t>. The input
+ // tensor must be a scalar (rank-0) int32 or int64 tensor. Caller must
+ // ensure that the input tensor is not NULL.
+ Status GetScalarFromTensor(const Tensor* t, int64* val);
+
// Returns a new dimension whose value is given by a scalar input tensor.
// The input tensor must be in host memory, since it is dereferenced to get
// the value.
Status MakeDimForScalarInput(int idx, DimensionHandle* out);
+ // Returns a new dimension whose value is given by a scalar input tensor.
+ // The scalar may be negative, in which case it is interpreted relative to
+ // input_rank (the rank of a separate tensor); input_rank may be negative
+ // if that rank is unknown. The input tensor must be in host memory, since
+ // it is dereferenced to get the value.
+ Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
+ DimensionHandle* out);
+
// Look up the attr for the NodeDef being evaluated with name attr_name and
// set *value to its value. If no attr with attr_name is found in def(), or
// the attr does not have a matching type, a non-ok status will be returned.
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index cf22a22fa3..5051e736f1 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -46,15 +46,18 @@ class SplitOpBase : public OpKernel {
explicit SplitOpBase(OpKernelConstruction* c) : OpKernel(c) {}
void ComputeEasyCases(OpKernelContext* context, bool* done) {
- const int32 split_dim = context->input(0).flat<int32>()(0);
- const int32 num_split = num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
+ const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const int32 split_dim =
+ split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+ const int32 num_split = num_outputs();
OP_REQUIRES(
context, 0 <= split_dim && split_dim < input_shape.dims(),
- errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
- input_shape.dims(), "), but got ", split_dim));
+ errors::InvalidArgument("-input rank(-", input.dims(),
+ ") <= split_dim < input rank (", input.dims(),
+ "), but got ", split_dim_orig));
OP_REQUIRES(
context, num_split > 0,
@@ -129,10 +132,12 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
if (!context->status().ok() || done) {
return;
}
- const int32 split_dim = context->input(0).flat<int32>()(0);
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
+ const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const int32 split_dim =
+ split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
// Android also uses int32 indexing, so check here also.
OP_REQUIRES(
@@ -204,15 +209,16 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
if (!context->status().ok() || done) {
return;
}
- const int32 split_dim = context->input(0).flat<int32>()(0);
- const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
+ const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const int32 split_dim =
+ split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+ const int32 num_split = Base::num_outputs();
OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
std::numeric_limits<int32>::max()),
errors::InvalidArgument("Split on GPU requires input size "
"< max int32"));
-
int32 prefix_dim_size;
int32 split_dim_size;
int32 suffix_dim_size;
@@ -260,10 +266,12 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
if (!context->status().ok() || done) {
return;
}
- const int32 split_dim = context->input(0).flat<int32>()(0);
- const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
+ const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const int32 split_dim =
+ split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+ const int32 num_split = Base::num_outputs();
// Android also uses int32 indexing, so check here also.
OP_REQUIRES(
diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc
index 4dff1ea046..0eae0328bd 100644
--- a/tensorflow/core/kernels/split_v_op.cc
+++ b/tensorflow/core/kernels/split_v_op.cc
@@ -55,7 +55,9 @@ class SplitVOpBase : public OpKernel {
const TensorShape& input_shape = input.shape();
const Tensor& split_tensor = context->input(1);
- const int32 split_dim = context->input(2).flat<int32>()(0);
+ const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+ const int32 split_dim =
+ split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
OP_REQUIRES(
context,
@@ -79,8 +81,9 @@ class SplitVOpBase : public OpKernel {
OP_REQUIRES(
context, 0 <= split_dim && split_dim < input.dims(),
- errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
- input.dims(), "), but got ", split_dim));
+ errors::InvalidArgument("-input rank(-", input.dims(),
+ ") <= split_dim < input rank (", input.dims(),
+ "), but got ", split_dim_orig));
Tlen input_size_split_dim = input_shape.dim_size(split_dim);
@@ -187,7 +190,9 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- const int32 split_dim = context->input(2).flat<int32>()(0);
+ const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+ const int32 split_dim =
+ split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
// Android also uses int32 indexing, so check here also.
OP_REQUIRES(
@@ -257,7 +262,9 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- const int32 split_dim = context->input(2).flat<int32>()(0);
+ const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+ const int32 split_dim =
+ split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
std::numeric_limits<int32>::max()),
errors::InvalidArgument("Split on GPU requires input size "
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index b7d97e50e1..40400255a1 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -456,9 +456,10 @@ REGISTER_OP("Split")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle split_dimension;
- TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension));
- int num_split = c->num_outputs();
ShapeHandle input = c->input(1);
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
+ 0, c->Rank(input), &split_dimension));
+ int num_split = c->num_outputs();
ShapeHandle out;
if (!c->ValueKnown(split_dimension)) {
if (c->RankKnown(input)) {
@@ -484,7 +485,7 @@ REGISTER_OP("Split")
Splits a tensor into `num_split` tensors along one dimension.
split_dim: 0-D. The dimension along which to split. Must be in the range
- `[0, rank(value))`.
+ `[-rank(value), rank(value))`.
num_split: The number of ways to split. Must evenly divide
`value.shape[split_dim]`.
value: The tensor to split.
@@ -503,9 +504,10 @@ REGISTER_OP("SplitV")
.Attr("Tlen: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle split_dimension;
- TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension));
- int32 num_outputs = c->num_outputs();
ShapeHandle input = c->input(0);
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
+ 2, c->Rank(input), &split_dimension));
+ int32 num_outputs = c->num_outputs();
int32 rank = c->Rank(input);
ShapeHandle output_shape;
const Tensor* size_splits = c->input_tensor(1);
@@ -594,7 +596,7 @@ size_splits: list containing the sizes of each output tensor along the split
dimension. Must sum to the dimension of value along split_dim.
Can contain one -1 indicating that dimension is to be inferred.
split_dim: 0-D. The dimension along which to split. Must be in the range
- `[0, rank(value))`.
+ `[-rank(value), rank(value))`.
output: Tensors whose shape matches that of `value`
except along `split_dim`, where their sizes are
`size_splits[i]`.
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 97842b24da..1be68b6000 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -993,6 +993,9 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
// If the rank is known, we know the rank of each output.
INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");
+ // split_dim is unknown but other inputs are known.
+ INFER_OK(op, "?;[1,4]", "[?,?];[?,?]");
+
// split_dim is known.
Tensor split_dim = test::AsTensor<int32>({1, 2});
op.input_tensors[0] = &split_dim;
@@ -1004,6 +1007,26 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
"?;[1,5]");
+
+ // split_dim too large.
+ split_dim = test::AsScalar<int32>(3);
+ INFER_ERROR(
+ "Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
+ "?;[1,4,8]");
+
+ // Negative split_dim.
+ split_dim = test::AsScalar<int32>(-1);
+ INFER_OK(op, "?;?", "?;?");
+ INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
+ INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
+ INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
+ INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
+ split_dim = test::AsScalar<int32>(-2);
+ INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
+ split_dim = test::AsScalar<int32>(-4);
+ INFER_ERROR(
+ "Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
+ "?;[1,4,8]");
}
TEST(ArrayOpsTest, Tile_ShapeFn) {
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 3dcafd2496..b44dc037f1 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -53,20 +53,21 @@ class SplitOpTest(test.TestCase):
model_input = array_ops.placeholder(dtypes.float32)
inp = np.zeros((1, 10))
# check that we still fail at runtime if the shapes were unknown
- with self.test_session(use_gpu=False) as sess:
+ with self.test_session(use_gpu=True) as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(array_ops.split(model_input, [4]), {model_input: inp})
# test that we can pass a scalar Tensor as num_splits
- with self.test_session(use_gpu=False) as sess:
- result = sess.run(
- array_ops.split(
- array_ops.ones([4, 4]),
- num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
- axis=0))
+ for axis in [0, -2]:
+ with self.test_session(use_gpu=True) as sess:
+ result = sess.run(
+ array_ops.split(
+ array_ops.ones([4, 4]),
+ num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
+ axis=axis))
- self.assertEqual(result[0].shape, (2, 4))
- self.assertEqual(result[1].shape, (2, 4))
+ self.assertEqual(result[0].shape, (2, 4))
+ self.assertEqual(result[1].shape, (2, 4))
# test that non-split dimensions remain, even if we don't know how
# the split_dim will be split, but we do know the axis
@@ -80,7 +81,7 @@ class SplitOpTest(test.TestCase):
model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
result = array_ops.split(model_input2, [2, 2], axis=0)[0]
- with self.test_session(use_gpu=False) as sess:
+ with self.test_session(use_gpu=True) as sess:
sess.run(result, feed_dict={model_input2: np.ones([4, 2])})
def testExplicitNum(self):
@@ -116,7 +117,7 @@ class SplitOpTest(test.TestCase):
def _RunAndVerifyVariable(self, dtype, large_num_splits=False):
# Random dims of rank 5
shape = np.random.randint(1, 5, size=5)
- split_dim = np.random.randint(0, 5)
+ split_dim = np.random.randint(-5, 5)
if large_num_splits:
num_split = np.random.randint(16, 25)
else:
@@ -180,12 +181,13 @@ class SplitOpTest(test.TestCase):
self.assertAllEqual(result[:, 1:4], inp_grads[1])
def testOutputShape(self):
- with self.test_session(use_gpu=True):
- tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
- size_splits = [3, 7, 2]
- outputs = array_ops.split(tensor, size_splits, 1)
- for i, output in enumerate(outputs):
- self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
+ for axis in [1, -1]:
+ with self.test_session(use_gpu=True):
+ tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
+ size_splits = [3, 7, 2]
+ outputs = array_ops.split(tensor, size_splits, axis)
+ for i, output in enumerate(outputs):
+ self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
def _compare(self, x, dim, num):
np_ans = np.split(x, num, dim)
@@ -246,7 +248,7 @@ class SplitOpTest(test.TestCase):
def _RunAndVerify(self, dtype, large_num_splits=False):
# Random dims of rank 5
shape = np.random.randint(0, 5, size=5)
- split_dim = np.random.randint(0, 5)
+ split_dim = np.random.randint(-5, 5)
if large_num_splits:
num_split = np.random.randint(9, 15)
else:
@@ -295,6 +297,10 @@ class SplitOpTest(test.TestCase):
with self.assertRaises(ValueError):
array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=2)
+ # split dim less than -(rank of input)
+ with self.assertRaises(ValueError):
+ array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3)
+
# num_split does not evenly divide the size in split_dim.
with self.assertRaisesRegexp(ValueError, "should evenly divide"):
array_ops.split(value=[0, 1, 2, 3], num_or_size_splits=3, axis=0)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 8e5cb8e251..4d1a260ffe 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1196,7 +1196,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
evenly divide `value.shape[axis]`; otherwise the sum of sizes along the
split dimension must match that of the `value`.
axis: A 0-D `int32` `Tensor`. The dimension along which to split.
- Must be in the range `[0, rank(value))`. Defaults to 0.
+ Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
num: Optional, used to specify the number of outputs when it cannot be
inferred from the shape of `size_splits`.
name: A name for the operation (optional).