author      2018-09-12 01:49:11 -0700
committer   2018-09-12 01:53:01 -0700
commit      6bb429b7772bead4e386cb22b6ab2aefa520442e (patch)
tree        3e8850d0ee4db999a281d15418431efdc6aba942 /tensorflow/compiler/tf2xla
parent      4c936f1b220676d0d427f5f38b4111cfb9011b5a (diff)
Automated rollback of commit 4c936f1b220676d0d427f5f38b4111cfb9011b5a
PiperOrigin-RevId: 212600364
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/concat_op.cc | 33
1 file changed, 1 insertion, 32 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 0ae23aa6df..f410605104 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,16 +37,6 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-// Used to determine the number of Tensors allowed in a Concat op to prevent
-// going over the max gpu parameter memory size. This is an issue because concat
-// is variadic and can have an unlimited number of arguments when called.
-// Concat ops with more Tensors than this will be split into multiple concat
-// ops.
-//
-// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
-// along with boxing large numbers of parameters.
-constexpr int64 kMaxConcatArgsPerOp = 500;
-
 // --------------------------------------------------------------------------
 class ConcatBaseOp : public XlaOpKernel {
  public:
@@ -84,7 +74,6 @@ class ConcatBaseOp : public XlaOpKernel {
     // Make a vector holding the XlaOp for each of the inputs that has non-zero
     // elements.
     std::vector<xla::XlaOp> input_data;
-    std::vector<xla::XlaOp> partial_concats;
     int output_concat_dim = 0;
     const bool input_is_scalar = IsLegacyScalar(input_shape);
     for (int i = 0; i < N; ++i) {
@@ -105,30 +94,10 @@ class ConcatBaseOp : public XlaOpKernel {
         input_data.push_back(handle);
       }
       output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
-
-      // Concat is associative, so it can be split into many operations when too
-      // many arguments are in a single op. This is a temporary workaround for
-      // b/112613927 where too many parameters in an XlaLaunchOp later result in
-      // too many parameters to a single GPU kernel.
-      if (i && i % kMaxConcatArgsPerOp == 0) {
-        partial_concats.push_back(
-            xla::ConcatInDim(ctx->builder(), input_data, axis));
-        input_data.clear();
-      }
     }
-    // Add any inputs that have not been put into another concat yet.
-    partial_concats.insert(partial_concats.end(), input_data.begin(),
-                           input_data.end());
 
     VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
-    // Don't add an additional "identity" concatenate for better readibility of
-    // IR.
-    if (partial_concats.size() == 1) {
-      ctx->SetOutput(0, partial_concats.front());
-    } else {
-      ctx->SetOutput(0,
-                     xla::ConcatInDim(ctx->builder(), partial_concats, axis));
-    }
+    ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
   }
 
  private:
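Editor's note: the workaround removed by this rollback relied on concatenation being associative: rather than emitting one xla::ConcatInDim over all N inputs, it emitted a partial concat after every kMaxConcatArgsPerOp inputs and then concatenated the partial results, bounding the number of arguments any single op receives. The standalone sketch below illustrates that chunk-and-recombine pattern only; it is not TensorFlow code. ChunkedConcat, ConcatGroup, and kMaxArgsPerChunk are hypothetical names, and plain std::vector<int> stands in for the xla::XlaOp handles used in the real kernel.

// Illustration of the chunk-and-recombine pattern used by the rolled-back
// workaround. A std::vector<int> plays the role of a tensor handle.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for kMaxConcatArgsPerOp (500 in the removed code).
constexpr std::size_t kMaxArgsPerChunk = 3;

// Concatenate one group of "tensors" into a single "tensor".
std::vector<int> ConcatGroup(const std::vector<std::vector<int>>& inputs) {
  std::vector<int> out;
  for (const auto& in : inputs) out.insert(out.end(), in.begin(), in.end());
  return out;
}

// Because concatenation is associative, concatenating the inputs in chunks of
// kMaxArgsPerChunk and then concatenating the partial results produces the
// same output as one variadic concat, while bounding the arity of each op.
std::vector<int> ChunkedConcat(const std::vector<std::vector<int>>& inputs) {
  std::vector<std::vector<int>> partials;
  std::vector<std::vector<int>> pending;
  for (const auto& in : inputs) {
    pending.push_back(in);
    if (pending.size() == kMaxArgsPerChunk) {
      partials.push_back(ConcatGroup(pending));
      pending.clear();
    }
  }
  // Flush any inputs not yet folded into a partial concat.
  for (auto& p : pending) partials.push_back(std::move(p));
  // Avoid an extra "identity" concat when only one partial remains.
  return partials.size() == 1 ? partials.front() : ConcatGroup(partials);
}

int main() {
  std::vector<std::vector<int>> inputs = {{1, 2}, {3}, {4, 5}, {6}, {7, 8, 9}};
  assert(ChunkedConcat(inputs) == ConcatGroup(inputs));
  std::cout << "chunked concat matches single concat\n";
}

Per the TODO in the removed comment (b/112613927), this splitting was considered a temporary fix; the rollback restores the single ConcatInDim call, leaving parameter-count limits to be handled elsewhere (e.g., an HLO pass).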