aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/concat_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/concat_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..2d225ad226 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase):
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
array_ops.concat([scalar, scalar, scalar], dim)
+ # The purpose of this is to ensure that XLA on GPU will not run out of memory
+ # with too many arguments.
+ def testConcatLargeNumberOfTensors(self):
+ with self.cached_session():
+ with self.test_scope():
+ for concat_dim in range(2):
+ params = {}
+ p = []
+ shape = np.array([7, 13])
+ num_tensors = 1001
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ placeholder = array_ops.placeholder(
+ dtypes.float32, shape=input_shape)
+ p.append(placeholder)
+ params[placeholder] = np.random.rand(*input_shape).astype(
+ np.float32)
+
+ concat_inputs = p
+ c = array_ops.concat(concat_inputs, concat_dim)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+ index[concat_dim] = slice(
+ cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[index], params[p[i]])
+
class ConcatOffsetTest(xla_test.XLATestCase):