diff options
author | 2017-10-02 14:38:34 -0700 | |
---|---|---|
committer | 2017-10-02 14:49:11 -0700 | |
commit | 553d10cfe42edcb6b3b8d748b315f13925fcf28f (patch) | |
tree | 4ac0714f0fbb82e3ce27741450f9e6b27df0dc7c | |
parent | 061897179e9f576380f72fe2131cd48d4af3b581 (diff) |
[TF:XLA] Add support for negative values of "split_dim" argument to Split operator.
PiperOrigin-RevId: 170755169
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 46 | ||||
-rw-r--r-- | tensorflow/compiler/tests/randomized_tests.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/split_op.cc | 36 |
3 files changed, 46 insertions, 39 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index f3ea57596e..792c01327c 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -790,28 +790,30 @@ class BinaryOpsTest(XLATestCase): def testSplit(self): for dtype in self.numeric_types: - self._testBinary( - lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), - np.int32(0), - np.array([[[1], [2]], [[3], [4]], [[5], [6]]], - dtype=dtype), - expected=[ - np.array([[[1], [2]]], dtype=dtype), - np.array([[[3], [4]]], dtype=dtype), - np.array([[[5], [6]]], dtype=dtype), - ], - equality_test=self.ListsAreClose) - - self._testBinary( - lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), - np.int32(1), - np.array([[[1], [2]], [[3], [4]], [[5], [6]]], - dtype=dtype), - expected=[ - np.array([[[1]], [[3]], [[5]]], dtype=dtype), - np.array([[[2]], [[4]], [[6]]], dtype=dtype), - ], - equality_test=self.ListsAreClose) + for axis in [0, -3]: + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), + np.int32(axis), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1], [2]]], dtype=dtype), + np.array([[[3], [4]]], dtype=dtype), + np.array([[[5], [6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + + for axis in [1, -2]: + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), + np.int32(axis), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1]], [[3]], [[5]]], dtype=dtype), + np.array([[[2]], [[4]], [[6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) def testTile(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index b3ec9424c7..7e307f16af 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2653,7 +2653,8 @@ TEST_F(OpTest, Split) { std::vector<int64> dims = RandomDims(1); std::uniform_int_distribution<int> ud; int32 dim = std::uniform_int_distribution<int32>( - 0, static_cast<int32>(dims.size()) - 1)(generator()); + -static_cast<int32>(dims.size()), + static_cast<int32>(dims.size()) - 1)(generator()); int n = std::uniform_int_distribution<int>(1, 5)(generator()); // Ensure 'dim' is evenly divisible by 'n'. dims[dim] /= n; diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 44ee81461e..795eb1794f 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -33,13 +33,16 @@ class SplitOp : public XlaOpKernel { explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + const int32 num_split = num_outputs(); const TensorShape index_shape = ctx->InputShape(0); + const TensorShape input_shape = ctx->InputShape(1); + xla::Literal literal_index; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); - int32 split_dim; + int32 split_dim_orig; if (index_shape.dims() == 0) { - split_dim = literal_index.Get<int>({}); + split_dim_orig = literal_index.Get<int>({}); } else { OP_REQUIRES( ctx, index_shape.dims() == 1, @@ -49,27 +52,28 @@ class SplitOp : public XlaOpKernel { ctx, index_shape.dim_size(0) == 1, errors::InvalidArgument("split_index input to Split Op must be a " "scalar or a vector with 1 element")); - split_dim = literal_index.Get<int>({0}); + split_dim_orig = literal_index.Get<int>({0}); } - const int32 num_split = num_outputs(); - const TensorShape input_shape = ctx->InputShape(1); - - OP_REQUIRES( - ctx, 0 <= split_dim && split_dim < input_shape.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input_shape.dims(), "), but got ", split_dim)); + int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() + : split_dim_orig; + OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("-input rank(-", input_shape.dims(), + ") <= split_dim < input rank (", + input_shape.dims(), "), but got ", + split_dim_orig)); OP_REQUIRES( ctx, num_split > 0, errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); - OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0, - errors::InvalidArgument( - "Number of ways to split should evenly divide the split " - "dimension, but got split_dim ", - split_dim, " (size = ", input_shape.dim_size(split_dim), - ") ", "and num_split ", num_split)); + OP_REQUIRES( + ctx, input_shape.dim_size(split_dim) % num_split == 0, + errors::InvalidArgument( + "Number of ways to split should evenly divide the split " + "dimension, but got split_dim ", + split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ", + "and num_split ", num_split)); // All the slices are the same size: this is the size along the // split dimension. |