about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Peter Hawkins <phawkins@google.com> 2017-10-02 14:38:34 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-02 14:49:11 -0700
commit 553d10cfe42edcb6b3b8d748b315f13925fcf28f (patch)
tree 4ac0714f0fbb82e3ce27741450f9e6b27df0dc7c
parent 061897179e9f576380f72fe2131cd48d4af3b581 (diff)
[TF:XLA] Add support for negative values of "split_dim" argument to Split operator.
PiperOrigin-RevId: 170755169
-rw-r--r-- tensorflow/compiler/tests/binary_ops_test.py   46
-rw-r--r-- tensorflow/compiler/tests/randomized_tests.cc   3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/split_op.cc 36
3 files changed, 46 insertions(+), 39 deletions(-)
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index f3ea57596e..792c01327c 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -790,28 +790,30 @@ class BinaryOpsTest(XLATestCase):
def testSplit(self):
for dtype in self.numeric_types:
- self._testBinary(
- lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
- np.int32(0),
- np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
- dtype=dtype),
- expected=[
- np.array([[[1], [2]]], dtype=dtype),
- np.array([[[3], [4]]], dtype=dtype),
- np.array([[[5], [6]]], dtype=dtype),
- ],
- equality_test=self.ListsAreClose)
-
- self._testBinary(
- lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
- np.int32(1),
- np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
- dtype=dtype),
- expected=[
- np.array([[[1]], [[3]], [[5]]], dtype=dtype),
- np.array([[[2]], [[4]], [[6]]], dtype=dtype),
- ],
- equality_test=self.ListsAreClose)
+ for axis in [0, -3]:
+ self._testBinary(
+ lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
+ np.int32(axis),
+ np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
+ dtype=dtype),
+ expected=[
+ np.array([[[1], [2]]], dtype=dtype),
+ np.array([[[3], [4]]], dtype=dtype),
+ np.array([[[5], [6]]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
+
+ for axis in [1, -2]:
+ self._testBinary(
+ lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
+ np.int32(axis),
+ np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
+ dtype=dtype),
+ expected=[
+ np.array([[[1]], [[3]], [[5]]], dtype=dtype),
+ np.array([[[2]], [[4]], [[6]]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
def testTile(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index b3ec9424c7..7e307f16af 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -2653,7 +2653,8 @@ TEST_F(OpTest, Split) {
std::vector<int64> dims = RandomDims(1);
std::uniform_int_distribution<int> ud;
int32 dim = std::uniform_int_distribution<int32>(
- 0, static_cast<int32>(dims.size()) - 1)(generator());
+ -static_cast<int32>(dims.size()),
+ static_cast<int32>(dims.size()) - 1)(generator());
int n = std::uniform_int_distribution<int>(1, 5)(generator());
// Ensure 'dim' is evenly divisible by 'n'.
dims[dim] /= n;
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 44ee81461e..795eb1794f 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -33,13 +33,16 @@ class SplitOp : public XlaOpKernel {
explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
+ const int32 num_split = num_outputs();
const TensorShape index_shape = ctx->InputShape(0);
+ const TensorShape input_shape = ctx->InputShape(1);
+
xla::Literal literal_index;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index));
- int32 split_dim;
+ int32 split_dim_orig;
if (index_shape.dims() == 0) {
- split_dim = literal_index.Get<int>({});
+ split_dim_orig = literal_index.Get<int>({});
} else {
OP_REQUIRES(
ctx, index_shape.dims() == 1,
@@ -49,27 +52,28 @@ class SplitOp : public XlaOpKernel {
ctx, index_shape.dim_size(0) == 1,
errors::InvalidArgument("split_index input to Split Op must be a "
"scalar or a vector with 1 element"));
- split_dim = literal_index.Get<int>({0});
+ split_dim_orig = literal_index.Get<int>({0});
}
- const int32 num_split = num_outputs();
- const TensorShape input_shape = ctx->InputShape(1);
-
- OP_REQUIRES(
- ctx, 0 <= split_dim && split_dim < input_shape.dims(),
- errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
- input_shape.dims(), "), but got ", split_dim));
+ int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
+ : split_dim_orig;
+ OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
+ errors::InvalidArgument("-input rank(-", input_shape.dims(),
+ ") <= split_dim < input rank (",
+ input_shape.dims(), "), but got ",
+ split_dim_orig));
OP_REQUIRES(
ctx, num_split > 0,
errors::InvalidArgument(
"Number of ways to split should be > 0, but got ", num_split));
- OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0,
- errors::InvalidArgument(
- "Number of ways to split should evenly divide the split "
- "dimension, but got split_dim ",
- split_dim, " (size = ", input_shape.dim_size(split_dim),
- ") ", "and num_split ", num_split));
+ OP_REQUIRES(
+ ctx, input_shape.dim_size(split_dim) % num_split == 0,
+ errors::InvalidArgument(
+ "Number of ways to split should evenly divide the split "
+ "dimension, but got split_dim ",
+ split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ",
+ "and num_split ", num_split));
// All the slices are the same size: this is the size along the
// split dimension.