aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-07 02:03:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 02:10:34 -0700
commit5a635e3472e16007830fca533c35b2f63fc4f898 (patch)
treec691c87638289e04ee1165d80fabc122e8f61dd2 /tensorflow
parent0358d57043067c6ea97ea7dd6c246806f56cd22a (diff)
[TF:XLA] Split XLA Concat Ops that fail on large sets of inputs.
GPU would fail due to having too many parameters to fit in memory because Concat's signature is variadic and can have an unlimited number of inputs. PiperOrigin-RevId: 211942734
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py37
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc33
2 files changed, 69 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..3f68b482d8 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,43 @@ class ConcatTest(xla_test.XLATestCase):
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
array_ops.concat([scalar, scalar, scalar], dim)
+ def testConcatLargeNumberOfTensors(self):
+ # CPU is too slow on the large data set.
+ if self.device in ["XLA_CPU"]:
+ return
+
+ with self.cached_session():
+ with self.test_scope():
+ for concat_dim in range(2):
+ params = {}
+ p = []
+ shape = np.array([7, 13])
+ num_tensors = 3999
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ placeholder = array_ops.placeholder(
+ dtypes.float32, shape=input_shape)
+ p.append(placeholder)
+ params[placeholder] = np.random.rand(*input_shape).astype(
+ np.float32)
+
+ concat_inputs = p
+ c = array_ops.concat(concat_inputs, concat_dim)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+ index[concat_dim] = slice(
+ cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[index], params[p[i]])
+
class ConcatOffsetTest(xla_test.XLATestCase):
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index f410605104..0ae23aa6df 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,6 +37,16 @@ limitations under the License.
namespace tensorflow {
namespace {
+// Used to determine the number of Tensors allowed in a Concat op to prevent
+// going over the max gpu parameter memory size. This is an issue because concat
+// is variadic and can have an unlimited number of arguments when called.
+// Concat ops with more Tensors than this will be split into multiple concat
+// ops.
+//
+// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
+// along with boxing large numbers of parameters.
+constexpr int64 kMaxConcatArgsPerOp = 500;
+
// --------------------------------------------------------------------------
class ConcatBaseOp : public XlaOpKernel {
public:
@@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel {
// Make a vector holding the XlaOp for each of the inputs that has non-zero
// elements.
std::vector<xla::XlaOp> input_data;
+ std::vector<xla::XlaOp> partial_concats;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
@@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel {
input_data.push_back(handle);
}
output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+
+ // Concat is associative, so it can be split into many operations when too
+ // many arguments are in a single op. This is a temporary workaround for
+ // b/112613927 where too many parameters in an XlaLaunchOp later result in
+ // too many parameters to a single GPU kernel.
+ if (i && i % kMaxConcatArgsPerOp == 0) {
+ partial_concats.push_back(
+ xla::ConcatInDim(ctx->builder(), input_data, axis));
+ input_data.clear();
+ }
}
+ // Add any inputs that have not been put into another concat yet.
+ partial_concats.insert(partial_concats.end(), input_data.begin(),
+ input_data.end());
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
+ // Don't add an additional "identity" concatenate for better readibility of
+ // IR.
+ if (partial_concats.size() == 1) {
+ ctx->SetOutput(0, partial_concats.front());
+ } else {
+ ctx->SetOutput(0,
+ xla::ConcatInDim(ctx->builder(), partial_concats, axis));
+ }
}
private: