author    | 2018-09-07 02:03:58 -0700
committer | 2018-09-07 02:10:34 -0700
commit    | 5a635e3472e16007830fca533c35b2f63fc4f898 (patch)
tree      | c691c87638289e04ee1165d80fabc122e8f61dd2 /tensorflow/compiler/tests
parent    | 0358d57043067c6ea97ea7dd6c246806f56cd22a (diff)
[TF:XLA] Split XLA Concat Ops that fail on large sets of inputs.
Because Concat's signature is variadic and can take an unbounded number of inputs, the GPU backend would fail when a single Concat had too many parameters to fit in memory.
PiperOrigin-RevId: 211942734
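The actual split is implemented in the XLA Concat op kernel (C++), which caps how many operands any one concatenate receives. The sketch below only illustrates the chunking idea in TensorFlow Python; the helper name `chunked_concat` and the `max_args=500` limit are illustrative assumptions, not the names or threshold used by the kernel.

    import tensorflow as tf

    def chunked_concat(tensors, axis, max_args=500):
      """Sketch: concatenate `tensors` along `axis` without handing more than
      `max_args` operands to any single concat op. `max_args` is an assumed,
      illustrative limit, not the value used by the XLA rewrite."""
      while len(tensors) > max_args:
        # Concatenate fixed-size groups first; each pass shrinks the operand
        # list by roughly a factor of `max_args`, so the loop terminates.
        tensors = [
            tf.concat(tensors[i:i + max_args], axis)
            for i in range(0, len(tensors), max_args)
        ]
      return tf.concat(tensors, axis)

Concatenating group-wise along the same axis and then concatenating the group results is equivalent to a single concatenation of all inputs, so the chunking changes only how many operands each op sees, not the result.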
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/concat_ops_test.py | 37
1 file changed, 37 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..3f68b482d8 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,43 @@ class ConcatTest(xla_test.XLATestCase):
           ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
         array_ops.concat([scalar, scalar, scalar], dim)
 
+  def testConcatLargeNumberOfTensors(self):
+    # CPU is too slow on the large data set.
+    if self.device in ["XLA_CPU"]:
+      return
+
+    with self.cached_session():
+      with self.test_scope():
+        for concat_dim in range(2):
+          params = {}
+          p = []
+          shape = np.array([7, 13])
+          num_tensors = 3999
+          for i in np.arange(num_tensors):
+            input_shape = shape
+            placeholder = array_ops.placeholder(
+                dtypes.float32, shape=input_shape)
+            p.append(placeholder)
+            params[placeholder] = np.random.rand(*input_shape).astype(
+                np.float32)
+
+          concat_inputs = p
+          c = array_ops.concat(concat_inputs, concat_dim)
+          result = c.eval(feed_dict=params)
+
+          self.assertEqual(result.shape, c.get_shape())
+          cur_offset = 0
+
+          for i in np.arange(num_tensors):
+            # The index into the result is the ':' along all dimensions
+            # except the concat_dim. slice(0, size) is used for ':', and
+            # a list of slices is used to index into result.
+            index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+            index[concat_dim] = slice(
+                cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+            cur_offset += params[p[i]].shape[concat_dim]
+            self.assertAllEqual(result[index], params[p[i]])
+
 
 class ConcatOffsetTest(xla_test.XLATestCase):
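The verification loop in the new test recovers each input from the concatenated result by slicing the full extent of every dimension except `concat_dim`, where it takes the window this input occupies. Below is a standalone NumPy sketch of that same check, outside the XLA test harness; the variable names (`pieces`, `result`) and the small sizes are illustrative, not taken from the test.

    import numpy as np

    concat_dim = 1
    pieces = [np.random.rand(7, 13).astype(np.float32) for _ in range(4)]
    result = np.concatenate(pieces, axis=concat_dim)

    cur_offset = 0
    for piece in pieces:
      # ':' (the full extent) on every axis except concat_dim, where we take
      # only the window occupied by this piece.
      index = [slice(0, dim) for dim in piece.shape]
      index[concat_dim] = slice(cur_offset, cur_offset + piece.shape[concat_dim])
      cur_offset += piece.shape[concat_dim]
      assert np.array_equal(result[tuple(index)], piece)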