aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-17 04:30:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-17 04:35:02 -0700
commit2381b6b0efabbccab1a3099500b28d48956c6573 (patch)
treefa71d779915a5ba7de25a9720bb487d3cef55ce9
parent0cffe2dba0c2000a8c719c2ed499a3ee72d6a2b6 (diff)
[XLA] Correct assertions in tf2xla tile_ops.
Zero inputs for the multiples are supported but were being disallowed by an assertion. PiperOrigin-RevId: 209132264
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tile_ops.cc2
2 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 0aafda7fb4..5b7001b5a4 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1167,6 +1167,16 @@ class BinaryOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types:
self._testBinary(
array_ops.tile,
+ np.array([[6], [3], [4]], dtype=dtype),
+ np.array([2, 0], dtype=np.int32),
+ expected=np.empty([6, 0], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
+ np.array([[6, 3, 4]], dtype=dtype),
+ np.array([2, 0], dtype=np.int32),
+ expected=np.empty([2, 0], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
np.array([[6]], dtype=dtype),
np.array([1, 2], dtype=np.int32),
expected=np.array([[6, 6]], dtype=dtype))
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 1233a37565..2c7213f322 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -70,7 +70,7 @@ class TileOp : public XlaOpKernel {
bool one_dimension_is_broadcasted_without_multiple = true;
for (int i = 0; i < input_dims; ++i) {
int multiple = literal.Get<int>({i});
- OP_REQUIRES(ctx, multiple,
+ OP_REQUIRES(ctx, multiple >= 0,
errors::InvalidArgument("Expected multiples[", i,
"] >= 0, but got ", multiple));
int64 new_dim = input_shape.dim_size(i) * multiple;