aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-12 01:02:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 01:06:45 -0700
commit4c936f1b220676d0d427f5f38b4111cfb9011b5a (patch)
tree3476cbb85c63e6058d28851fb8eca378fa98159a /tensorflow/compiler/tf2xla
parenta5d649045cf60f8b3dde5ba1ff86285c4bcc1695 (diff)
Automated rollback of commit c5267a54a63a08234a0314888f6cfe842647a73b
PiperOrigin-RevId: 212595533
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc33
1 files changed, 32 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index f410605104..0ae23aa6df 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,6 +37,16 @@ limitations under the License.
namespace tensorflow {
namespace {
+// Used to determine the number of Tensors allowed in a Concat op to prevent
+// going over the max gpu parameter memory size. This is an issue because concat
+// is variadic and can have an unlimited number of arguments when called.
+// Concat ops with more Tensors than this will be split into multiple concat
+// ops.
+//
+// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
+// along with boxing large numbers of parameters.
+constexpr int64 kMaxConcatArgsPerOp = 500;
+
// --------------------------------------------------------------------------
class ConcatBaseOp : public XlaOpKernel {
public:
@@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel {
// Make a vector holding the XlaOp for each of the inputs that has non-zero
// elements.
std::vector<xla::XlaOp> input_data;
+ std::vector<xla::XlaOp> partial_concats;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
@@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel {
input_data.push_back(handle);
}
output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+
+ // Concat is associative, so it can be split into many operations when too
+ // many arguments are in a single op. This is a temporary workaround for
+ // b/112613927 where too many parameters in an XlaLaunchOp later result in
+ // too many parameters to a single GPU kernel.
+ if (i && i % kMaxConcatArgsPerOp == 0) {
+ partial_concats.push_back(
+ xla::ConcatInDim(ctx->builder(), input_data, axis));
+ input_data.clear();
+ }
}
+ // Add any inputs that have not been put into another concat yet.
+ partial_concats.insert(partial_concats.end(), input_data.begin(),
+ input_data.end());
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
+ // Don't add an additional "identity" concatenate for better readibility of
+ // IR.
+ if (partial_concats.size() == 1) {
+ ctx->SetOutput(0, partial_concats.front());
+ } else {
+ ctx->SetOutput(0,
+ xla::ConcatInDim(ctx->builder(), partial_concats, axis));
+ }
}
private: