aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-12 01:49:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 01:53:01 -0700
commit6bb429b7772bead4e386cb22b6ab2aefa520442e (patch)
tree3e8850d0ee4db999a281d15418431efdc6aba942 /tensorflow/compiler/tf2xla
parent4c936f1b220676d0d427f5f38b4111cfb9011b5a (diff)
Automated rollback of commit 4c936f1b220676d0d427f5f38b4111cfb9011b5a
PiperOrigin-RevId: 212600364
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc33
1 files changed, 1 insertions, 32 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 0ae23aa6df..f410605104 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,16 +37,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Used to determine the number of Tensors allowed in a Concat op to prevent
-// going over the max gpu parameter memory size. This is an issue because concat
-// is variadic and can have an unlimited number of arguments when called.
-// Concat ops with more Tensors than this will be split into multiple concat
-// ops.
-//
-// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
-// along with boxing large numbers of parameters.
-constexpr int64 kMaxConcatArgsPerOp = 500;
-
// --------------------------------------------------------------------------
class ConcatBaseOp : public XlaOpKernel {
public:
@@ -84,7 +74,6 @@ class ConcatBaseOp : public XlaOpKernel {
// Make a vector holding the XlaOp for each of the inputs that has non-zero
// elements.
std::vector<xla::XlaOp> input_data;
- std::vector<xla::XlaOp> partial_concats;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
@@ -105,30 +94,10 @@ class ConcatBaseOp : public XlaOpKernel {
input_data.push_back(handle);
}
output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
-
- // Concat is associative, so it can be split into many operations when too
- // many arguments are in a single op. This is a temporary workaround for
- // b/112613927 where too many parameters in an XlaLaunchOp later result in
- // too many parameters to a single GPU kernel.
- if (i && i % kMaxConcatArgsPerOp == 0) {
- partial_concats.push_back(
- xla::ConcatInDim(ctx->builder(), input_data, axis));
- input_data.clear();
- }
}
- // Add any inputs that have not been put into another concat yet.
- partial_concats.insert(partial_concats.end(), input_data.begin(),
- input_data.end());
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- // Don't add an additional "identity" concatenate for better readibility of
- // IR.
- if (partial_concats.size() == 1) {
- ctx->SetOutput(0, partial_concats.front());
- } else {
- ctx->SetOutput(0,
- xla::ConcatInDim(ctx->builder(), partial_concats, axis));
- }
+ ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
}
private: