aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py14
-rw-r--r--tensorflow/compiler/tf2xla/kernels/split_op.cc23
2 files changed, 24 insertions, 13 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 30a6d3a74d..0e4efaed86 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1045,6 +1045,20 @@ class BinaryOpsTest(XLATestCase):
],
equality_test=self.ListsAreClose)
+ def splitvOp(x, y): # pylint: disable=invalid-name
+ return array_ops.split(value=y, num_or_size_splits=[2, 3], axis=x)
+ for axis in [1, -1]:
+ self._testBinary(
+ splitvOp,
+ np.int32(axis),
+ np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
+ dtype=dtype),
+ expected=[
+ np.array([[0, 1], [5, 6]], dtype=dtype),
+ np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
+
def testTile(self):
for dtype in self.numeric_types:
self._testBinary(
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 79c435c90a..43c15e7538 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -111,28 +111,25 @@ class SplitVOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
const int32 num_split = num_outputs();
+ const TensorShape input_shape = ctx->InputShape(0);
const TensorShape index_shape = ctx->InputShape(2);
- xla::Literal literal_index;
- OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index));
- int32 split_dim;
- OP_REQUIRES(ctx, index_shape.dims() == 0,
- errors::InvalidArgument("split_dim input to Split Op must be a "
- "scalar"));
- split_dim = literal_index.Get<int>({});
+ int64 split_dim_orig;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig));
+ int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
+ : split_dim_orig;
+ OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
+ errors::InvalidArgument("-input rank(-", input_shape.dims(),
+ ") <= split_dim < input rank (",
+ input_shape.dims(), "), but got ",
+ split_dim_orig));
xla::ComputationDataHandle input = ctx->Input(0);
- const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() > 0,
errors::InvalidArgument("Can't split a 0 dimensional input"));
OP_REQUIRES(
- ctx, 0 <= split_dim && split_dim < input_shape.dims(),
- errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
- input_shape.dims(), "), but got ", split_dim));
-
- OP_REQUIRES(
ctx, num_split > 0,
errors::InvalidArgument(
"Number of ways to split should be > 0, but got ", num_split));