aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-12 01:02:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 01:06:45 -0700
commit4c936f1b220676d0d427f5f38b4111cfb9011b5a (patch)
tree3476cbb85c63e6058d28851fb8eca378fa98159a /tensorflow/compiler/tests
parenta5d649045cf60f8b3dde5ba1ff86285c4bcc1695 (diff)
Automated rollback of commit c5267a54a63a08234a0314888f6cfe842647a73b
PiperOrigin-RevId: 212595533
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..2d225ad226 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase):
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
array_ops.concat([scalar, scalar, scalar], dim)
+ # The purpose of this is to ensure that XLA on GPU will not run out of memory
+ # with too many arguments.
+ def testConcatLargeNumberOfTensors(self):
+ with self.cached_session():
+ with self.test_scope():
+ for concat_dim in range(2):
+ params = {}
+ p = []
+ shape = np.array([7, 13])
+ num_tensors = 1001
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ placeholder = array_ops.placeholder(
+ dtypes.float32, shape=input_shape)
+ p.append(placeholder)
+ params[placeholder] = np.random.rand(*input_shape).astype(
+ np.float32)
+
+ concat_inputs = p
+ c = array_ops.concat(concat_inputs, concat_dim)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+ index[concat_dim] = slice(
+ cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[index], params[p[i]])
+
class ConcatOffsetTest(xla_test.XLATestCase):