diff options
author | 2017-05-31 13:39:29 -0700 | |
---|---|---|
committer | 2017-05-31 13:43:08 -0700 | |
commit | 0b8070253d6c62ad395a42c3f496c3f21ae5d975 (patch) | |
tree | 991360e089b2a102645a53e4d7aa3f04c4535fba | |
parent | bc236cfc3bb5496607a030ff2ae456a8449afb7f (diff) |
Support negative axis for Split op
PiperOrigin-RevId: 157628162
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 56 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 13 | ||||
-rw-r--r-- | tensorflow/core/kernels/split_op.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/kernels/split_v_op.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 14 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops_test.cc | 23 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/split_op_test.py | 42 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 2 |
8 files changed, 145 insertions, 50 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index b30a90027c..2cbbf966b8 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, return MakeShapeFromPartialTensorShape(partial_shape, out); } -// Returns a new dimension whose value is given by a scalar input tensor. -Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { - const Tensor* t = input_tensor(idx); - if (t == nullptr) { - *out = UnknownDim(); - return Status::OK(); - } +Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) { + // Caller must ensure that <t> is not NULL. const int rank = t->dims(); if (rank != 0) { return errors::InvalidArgument("Input must be scalar but has rank ", rank); } - int64 val; if (t->dtype() == DT_INT32) { - val = t->scalar<int32>()(); + *val = t->scalar<int32>()(); + return Status::OK(); } else if (t->dtype() == DT_INT64) { - val = t->scalar<int64>()(); + *val = t->scalar<int64>()(); + return Status::OK(); } else { return errors::InvalidArgument( "Scalar input for dim size must be int32 or int64"); } +} + +// Returns a new dimension whose value is given by a scalar input tensor. +Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); if (val < 0) { return errors::InvalidArgument("Dimension size, given by scalar input ", idx, ", must be non-negative but is ", val); @@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { return Status::OK(); } +Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing( + int idx, int input_rank, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); + if (val < 0) { + if (input_rank < 0) { + *out = UnknownDim(); + return Status::OK(); + } else if (val + input_rank < 0) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } else { + val += input_rank; + } + } else if (input_rank >= 0 && val >= input_rank) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } + *out = MakeDim(val); + return Status::OK(); +} + Status InferenceContext::Divide(DimensionHandle dividend, DimensionOrConstant divisor, bool evenly_divisible, DimensionHandle* out) { diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 99bbed64b1..baeab93e30 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -401,11 +401,24 @@ class InferenceContext { inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } + // Returns in <val> a scalar value from an input tensor <t>. The input tensor + // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the + // input tensor is not NULL. + Status GetScalarFromTensor(const Tensor* t, int64* val); + // Returns a new dimension whose value is given by a scalar input tensor. // The input tensor must be in host memory, since it is dereferenced to get // the value. Status MakeDimForScalarInput(int idx, DimensionHandle* out); + // Returns a new dimension whose value is given by a scalar input tensor. + // This allows for a negative input dimension given the rank of a separate + // tensor. This rank can be negative if unknown. + // The input tensor must be in host memory, since it is dereferenced to get + // the value. + Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank, + DimensionHandle* out); + // Look up the attr for the NodeDef being evaluated with name attr_name and // set *value to its value. If no attr with attr_name is found in def(), or // the attr does not have a matching type, a non-ok status will be returned. diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index cf22a22fa3..5051e736f1 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -46,15 +46,18 @@ class SplitOpBase : public OpKernel { explicit SplitOpBase(OpKernelConstruction* c) : OpKernel(c) {} void ComputeEasyCases(OpKernelContext* context, bool* done) { - const int32 split_dim = context->input(0).flat<int32>()(0); - const int32 num_split = num_outputs(); const Tensor& input = context->input(1); const TensorShape& input_shape = input.shape(); + const int32 split_dim_orig = context->input(0).flat<int32>()(0); + const int32 split_dim = + split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; + const int32 num_split = num_outputs(); OP_REQUIRES( context, 0 <= split_dim && split_dim < input_shape.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input_shape.dims(), "), but got ", split_dim)); + errors::InvalidArgument("-input rank(-", input.dims(), + ") <= split_dim < input rank (", input.dims(), + "), but got ", split_dim_orig)); OP_REQUIRES( context, num_split > 0, @@ -129,10 +132,12 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> { if (!context->status().ok() || done) { return; } - const int32 split_dim = context->input(0).flat<int32>()(0); const int32 num_split = Base::num_outputs(); const Tensor& input = context->input(1); const TensorShape& input_shape = input.shape(); + const int32 split_dim_orig = context->input(0).flat<int32>()(0); + const int32 split_dim = + split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; // Android also uses int32 indexing, so check here also. OP_REQUIRES( @@ -204,15 +209,16 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> { if (!context->status().ok() || done) { return; } - const int32 split_dim = context->input(0).flat<int32>()(0); - const int32 num_split = Base::num_outputs(); const Tensor& input = context->input(1); const TensorShape& input_shape = input.shape(); + const int32 split_dim_orig = context->input(0).flat<int32>()(0); + const int32 split_dim = + split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; + const int32 num_split = Base::num_outputs(); OP_REQUIRES(context, FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()), errors::InvalidArgument("Split on GPU requires input size " "< max int32")); - int32 prefix_dim_size; int32 split_dim_size; int32 suffix_dim_size; @@ -260,10 +266,12 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> { if (!context->status().ok() || done) { return; } - const int32 split_dim = context->input(0).flat<int32>()(0); - const int32 num_split = Base::num_outputs(); const Tensor& input = context->input(1); const TensorShape& input_shape = input.shape(); + const int32 split_dim_orig = context->input(0).flat<int32>()(0); + const int32 split_dim = + split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; + const int32 num_split = Base::num_outputs(); // Android also uses int32 indexing, so check here also. OP_REQUIRES( diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc index 4dff1ea046..0eae0328bd 100644 --- a/tensorflow/core/kernels/split_v_op.cc +++ b/tensorflow/core/kernels/split_v_op.cc @@ -55,7 +55,9 @@ class SplitVOpBase : public OpKernel { const TensorShape& input_shape = input.shape(); const Tensor& split_tensor = context->input(1); - const int32 split_dim = context->input(2).flat<int32>()(0); + const int32 split_dim_orig = context->input(2).flat<int32>()(0); + const int32 split_dim = + split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; OP_REQUIRES( context, @@ -79,8 +81,9 @@ class SplitVOpBase : public OpKernel { OP_REQUIRES( context, 0 <= split_dim && split_dim < input.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input.dims(), "), but got ", split_dim)); + errors::InvalidArgument("-input rank(-", input.dims(), + ") <= split_dim < input rank (", input.dims(), + "), but got ", split_dim_orig)); Tlen input_size_split_dim = input_shape.dim_size(split_dim); @@ -187,7 +190,9 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> { const int32 num_split = Base::num_outputs(); const Tensor& input = context->input(0); const TensorShape& input_shape = input.shape(); - const int32 split_dim = context->input(2).flat<int32>()(0); + const int32 split_dim_orig = context->input(2).flat<int32>()(0); + const int32 split_dim = + split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; // Android also uses int32 indexing, so check here also. OP_REQUIRES( @@ -257,7 +262,9 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> { const int32 num_split = Base::num_outputs(); const Tensor& input = context->input(0); const TensorShape& input_shape = input.shape(); - const int32 split_dim = context->input(2).flat<int32>()(0); + const int32 split_dim_orig = context->input(2).flat<int32>()(0); + const int32 split_dim = + split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; OP_REQUIRES(context, FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()), errors::InvalidArgument("Split on GPU requires input size " diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index b7d97e50e1..40400255a1 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -456,9 +456,10 @@ REGISTER_OP("Split") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { DimensionHandle split_dimension; - TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension)); - int num_split = c->num_outputs(); ShapeHandle input = c->input(1); + TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( + 0, c->Rank(input), &split_dimension)); + int num_split = c->num_outputs(); ShapeHandle out; if (!c->ValueKnown(split_dimension)) { if (c->RankKnown(input)) { @@ -484,7 +485,7 @@ REGISTER_OP("Split") Splits a tensor into `num_split` tensors along one dimension. split_dim: 0-D. The dimension along which to split. Must be in the range - `[0, rank(value))`. + `[-rank(value), rank(value))`. num_split: The number of ways to split. Must evenly divide `value.shape[split_dim]`. value: The tensor to split. @@ -503,9 +504,10 @@ REGISTER_OP("SplitV") .Attr("Tlen: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { DimensionHandle split_dimension; - TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension)); - int32 num_outputs = c->num_outputs(); ShapeHandle input = c->input(0); + TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( + 2, c->Rank(input), &split_dimension)); + int32 num_outputs = c->num_outputs(); int32 rank = c->Rank(input); ShapeHandle output_shape; const Tensor* size_splits = c->input_tensor(1); @@ -594,7 +596,7 @@ size_splits: list containing the sizes of each output tensor along the split dimension. Must sum to the dimension of value along split_dim. Can contain one -1 indicating that dimension is to be inferred. split_dim: 0-D. The dimension along which to split. Must be in the range - `[0, rank(value))`. + `[-rank(value), rank(value))`. output: Tensors whose shape matches that of `value` except along `split_dim`, where their sizes are `size_splits[i]`. diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 97842b24da..1be68b6000 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -993,6 +993,9 @@ TEST(ArrayOpsTest, Split_ShapeFn) { // If the rank is known, we know the rank of each output. INFER_OK(op, "?;[?,?]", "[?,?];[?,?]"); + // split_dim is unknown but other inputs are known. + INFER_OK(op, "?;[1,4]", "[?,?];[?,?]"); + // split_dim is known. Tensor split_dim = test::AsTensor<int32>({1, 2}); op.input_tensors[0] = &split_dim; @@ -1004,6 +1007,26 @@ TEST(ArrayOpsTest, Split_ShapeFn) { INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]"); INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op, "?;[1,5]"); + + // split_dim too large. + split_dim = test::AsScalar<int32>(3); + INFER_ERROR( + "Dimension size, given by scalar input 3 must be in range [-3, 3)", op, + "?;[1,4,8]"); + + // Negative split_dim. + split_dim = test::AsScalar<int32>(-1); + INFER_OK(op, "?;?", "?;?"); + INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]"); + INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]"); + INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]"); + INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]"); + split_dim = test::AsScalar<int32>(-2); + INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]"); + split_dim = test::AsScalar<int32>(-4); + INFER_ERROR( + "Dimension size, given by scalar input -4 must be in range [-3, 3)", op, + "?;[1,4,8]"); } TEST(ArrayOpsTest, Tile_ShapeFn) { diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py index 3dcafd2496..b44dc037f1 100644 --- a/tensorflow/python/kernel_tests/split_op_test.py +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -53,20 +53,21 @@ class SplitOpTest(test.TestCase): model_input = array_ops.placeholder(dtypes.float32) inp = np.zeros((1, 10)) # check that we still fail at runtime if the shapes were unknown - with self.test_session(use_gpu=False) as sess: + with self.test_session(use_gpu=True) as sess: with self.assertRaises(errors_impl.InvalidArgumentError): sess.run(array_ops.split(model_input, [4]), {model_input: inp}) # test that we can pass a scalar Tensor as num_splits - with self.test_session(use_gpu=False) as sess: - result = sess.run( - array_ops.split( - array_ops.ones([4, 4]), - num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1], - axis=0)) + for axis in [0, -2]: + with self.test_session(use_gpu=True) as sess: + result = sess.run( + array_ops.split( + array_ops.ones([4, 4]), + num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1], + axis=axis)) - self.assertEqual(result[0].shape, (2, 4)) - self.assertEqual(result[1].shape, (2, 4)) + self.assertEqual(result[0].shape, (2, 4)) + self.assertEqual(result[1].shape, (2, 4)) # test that none split dimensions remain, even if we don't know how # the split_dim will be split, but we do know the axis @@ -80,7 +81,7 @@ class SplitOpTest(test.TestCase): model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) result = array_ops.split(model_input2, [2, 2], axis=0)[0] - with self.test_session(use_gpu=False) as sess: + with self.test_session(use_gpu=True) as sess: sess.run(result, feed_dict={model_input2: np.ones([4, 2])}) def testExplicitNum(self): @@ -116,7 +117,7 @@ class SplitOpTest(test.TestCase): def _RunAndVerifyVariable(self, dtype, large_num_splits=False): # Random dims of rank 5 shape = np.random.randint(1, 5, size=5) - split_dim = np.random.randint(0, 5) + split_dim = np.random.randint(-5, 5) if large_num_splits: num_split = np.random.randint(16, 25) else: @@ -180,12 +181,13 @@ class SplitOpTest(test.TestCase): self.assertAllEqual(result[:, 1:4], inp_grads[1]) def testOutputShape(self): - with self.test_session(use_gpu=True): - tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12]) - size_splits = [3, 7, 2] - outputs = array_ops.split(tensor, size_splits, 1) - for i, output in enumerate(outputs): - self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]]) + for axis in [1, -1]: + with self.test_session(use_gpu=True): + tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12]) + size_splits = [3, 7, 2] + outputs = array_ops.split(tensor, size_splits, axis) + for i, output in enumerate(outputs): + self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]]) def _compare(self, x, dim, num): np_ans = np.split(x, num, dim) @@ -246,7 +248,7 @@ class SplitOpTest(test.TestCase): def _RunAndVerify(self, dtype, large_num_splits=False): # Random dims of rank 5 shape = np.random.randint(0, 5, size=5) - split_dim = np.random.randint(0, 5) + split_dim = np.random.randint(-5, 5) if large_num_splits: num_split = np.random.randint(9, 15) else: @@ -295,6 +297,10 @@ class SplitOpTest(test.TestCase): with self.assertRaises(ValueError): array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=2) + # split dim less than -(rank of input) + with self.assertRaises(ValueError): + array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3) + # num_split does not evenly divide the size in split_dim. with self.assertRaisesRegexp(ValueError, "should evenly divide"): array_ops.split(value=[0, 1, 2, 3], num_or_size_splits=3, axis=0) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 8e5cb8e251..4d1a260ffe 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1196,7 +1196,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"): evenly divide `value.shape[axis]`; otherwise the sum of sizes along the split dimension must match that of the `value`. axis: A 0-D `int32` `Tensor`. The dimension along which to split. - Must be in the range `[0, rank(value))`. Defaults to 0. + Must be in the range `[-rank(value), rank(value))`. Defaults to 0. num: Optional, used to specify the number of outputs when it cannot be inferred from the shape of `size_splits`. name: A name for the operation (optional). |