aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-14 02:30:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 02:36:29 -0700
commit54cbee5d034af8693aa39cc5877c3dfcd62d3740 (patch)
tree716ed89052251374c030978107a71da81f7693c8 /tensorflow/compiler/tests
parentc335f3ae6872715c4873eb8af3ff2e42833bc6c0 (diff)
[TF:XLA] Split XLA Concat Ops that fail on large sets of inputs.
Make the test large to prevent occasional timeouts on CPU. This should normally complete in well under a minute. PiperOrigin-RevId: 212953337
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/BUILD3
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py35
2 files changed, 37 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 2176eaebe4..97ed554171 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -277,9 +277,10 @@ tf_xla_py_test(
],
)
+# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors
tf_xla_py_test(
name = "concat_ops_test",
- size = "medium",
+ size = "large",
srcs = ["concat_ops_test.py"],
deps = [
":xla_test",
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..2d225ad226 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase):
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
array_ops.concat([scalar, scalar, scalar], dim)
+ # The purpose of this is to ensure that XLA on GPU will not run out of memory
+ # with too many arguments.
+ def testConcatLargeNumberOfTensors(self):
+ with self.cached_session():
+ with self.test_scope():
+ for concat_dim in range(2):
+ params = {}
+ p = []
+ shape = np.array([7, 13])
+ num_tensors = 1001
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ placeholder = array_ops.placeholder(
+ dtypes.float32, shape=input_shape)
+ p.append(placeholder)
+ params[placeholder] = np.random.rand(*input_shape).astype(
+ np.float32)
+
+ concat_inputs = p
+ c = array_ops.concat(concat_inputs, concat_dim)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+ index[concat_dim] = slice(
+ cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[index], params[p[i]])
+
class ConcatOffsetTest(xla_test.XLATestCase):