diff options
author | 2018-09-14 02:30:05 -0700 | |
---|---|---|
committer | 2018-09-14 02:36:29 -0700 | |
commit | 54cbee5d034af8693aa39cc5877c3dfcd62d3740 (patch) | |
tree | 716ed89052251374c030978107a71da81f7693c8 | |
parent | c335f3ae6872715c4873eb8af3ff2e42833bc6c0 (diff) |
[TF:XLA] Split XLA Concat Ops that fail on large sets of inputs.
Make the test large to prevent occasional timeouts on CPU. This should normally complete in well under a minute.
PiperOrigin-RevId: 212953337
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tests/concat_ops_test.py | 35 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/concat_op.cc | 33 |
3 files changed, 69 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 2176eaebe4..97ed554171 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -277,9 +277,10 @@ tf_xla_py_test( ], ) +# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "medium", + size = "large", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 37e5318bb5..2d225ad226 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase): ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"): array_ops.concat([scalar, scalar, scalar], dim) + # The purpose of this is to ensure that XLA on GPU will not run out of memory + # with too many arguments. + def testConcatLargeNumberOfTensors(self): + with self.cached_session(): + with self.test_scope(): + for concat_dim in range(2): + params = {} + p = [] + shape = np.array([7, 13]) + num_tensors = 1001 + for i in np.arange(num_tensors): + input_shape = shape + placeholder = array_ops.placeholder( + dtypes.float32, shape=input_shape) + p.append(placeholder) + params[placeholder] = np.random.rand(*input_shape).astype( + np.float32) + + concat_inputs = p + c = array_ops.concat(concat_inputs, concat_dim) + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + cur_offset = 0 + + for i in np.arange(num_tensors): + # The index into the result is the ':' along all dimensions + # except the concat_dim. slice(0, size) is used for ':', and + # a list of slices is used to index into result. + index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)] + index[concat_dim] = slice( + cur_offset, cur_offset + params[p[i]].shape[concat_dim]) + cur_offset += params[p[i]].shape[concat_dim] + self.assertAllEqual(result[index], params[p[i]]) + class ConcatOffsetTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index f410605104..0ae23aa6df 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -37,6 +37,16 @@ limitations under the License. namespace tensorflow { namespace { +// Used to determine the number of Tensors allowed in a Concat op to prevent +// going over the max gpu parameter memory size. This is an issue because concat +// is variadic and can have an unlimited number of arguments when called. +// Concat ops with more Tensors than this will be split into multiple concat +// ops. +// +// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass +// along with boxing large numbers of parameters. +constexpr int64 kMaxConcatArgsPerOp = 500; + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. std::vector<xla::XlaOp> input_data; + std::vector<xla::XlaOp> partial_concats; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { @@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + + // Concat is associative, so it can be split into many operations when too + // many arguments are in a single op. This is a temporary workaround for + // b/112613927 where too many parameters in an XlaLaunchOp later result in + // too many parameters to a single GPU kernel. + if (i && i % kMaxConcatArgsPerOp == 0) { + partial_concats.push_back( + xla::ConcatInDim(ctx->builder(), input_data, axis)); + input_data.clear(); + } } + // Add any inputs that have not been put into another concat yet. + partial_concats.insert(partial_concats.end(), input_data.begin(), + input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + // Don't add an additional "identity" concatenate for better readibility of + // IR. + if (partial_concats.size() == 1) { + ctx->SetOutput(0, partial_concats.front()); + } else { + ctx->SetOutput(0, + xla::ConcatInDim(ctx->builder(), partial_concats, axis)); + } } private: |