diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-12 01:02:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 01:06:45 -0700 |
commit | 4c936f1b220676d0d427f5f38b4111cfb9011b5a (patch) | |
tree | 3476cbb85c63e6058d28851fb8eca378fa98159a /tensorflow/compiler/tf2xla | |
parent | a5d649045cf60f8b3dde5ba1ff86285c4bcc1695 (diff) |
Automated rollback of commit c5267a54a63a08234a0314888f6cfe842647a73b
PiperOrigin-RevId: 212595533
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/concat_op.cc | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index f410605104..0ae23aa6df 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -37,6 +37,16 @@ limitations under the License. namespace tensorflow { namespace { +// Used to determine the number of Tensors allowed in a Concat op to prevent +// going over the max gpu parameter memory size. This is an issue because concat +// is variadic and can have an unlimited number of arguments when called. +// Concat ops with more Tensors than this will be split into multiple concat +// ops. +// +// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass +// along with boxing large numbers of parameters. +constexpr int64 kMaxConcatArgsPerOp = 500; + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. std::vector<xla::XlaOp> input_data; + std::vector<xla::XlaOp> partial_concats; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { @@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + + // Concat is associative, so it can be split into many operations when too + // many arguments are in a single op. This is a temporary workaround for + // b/112613927 where too many parameters in an XlaLaunchOp later result in + // too many parameters to a single GPU kernel. + if (i && i % kMaxConcatArgsPerOp == 0) { + partial_concats.push_back( + xla::ConcatInDim(ctx->builder(), input_data, axis)); + input_data.clear(); + } } + // Add any inputs that have not been put into another concat yet. + partial_concats.insert(partial_concats.end(), input_data.begin(), + input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + // Don't add an additional "identity" concatenate for better readibility of + // IR. + if (partial_concats.size() == 1) { + ctx->SetOutput(0, partial_concats.front()); + } else { + ctx->SetOutput(0, + xla::ConcatInDim(ctx->builder(), partial_concats, axis)); + } } private: |