aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
diff options
context:
space:
mode:
authorGravatar Dandelion Man? <dandelion@google.com>2017-12-15 17:12:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 17:16:29 -0800
commitd55f532867a3670d66460c5ee3b774519542adc1 (patch)
tree7de4d85bcd61e93401459276b4d371ab0be23c1f /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent32d5048ae96116202f2aa0fa739ef37514ee8a54 (diff)
Merge changes from github.
PiperOrigin-RevId: 179258973
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc87
1 files changed, 66 insertions, 21 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 351fda2517..03c22354a9 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -311,6 +311,32 @@ class TensorArrayGatherOp : public XlaOpKernel {
xla::ComputationDataHandle ta = resource->value;
+ // Look for the case where the gather takes a simple slice from the
+ // tensor array (0, 1, 2, 3, 4, ..., N)
+ std::vector<int64> const_indices;
+ Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
+ if (status.ok()) {
+ bool gather_is_dense_slice = true;
+ for (auto i = 0; i < const_indices.size(); i++) {
+ if (const_indices[i] != i) {
+ gather_is_dense_slice = false;
+ break;
+ }
+ }
+
+ if (gather_is_dense_slice) {
+ std::vector<int64> begin(ta_shape.dims(), 0);
+ std::vector<int64> strides(ta_shape.dims(), 1);
+ std::vector<int64> end(ta_shape.dims(), 1);
+ end[0] = const_indices.size();
+ for (auto i = 1; i < ta_shape.dims(); i++) {
+ end[i] = ta_shape.dim_size(i);
+ }
+ ctx->SetOutput(0, b->Slice(ta, begin, end, strides));
+ return;
+ }
+ }
+
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b);
ctx->SetOutput(0, gather);
@@ -352,28 +378,47 @@ class TensorArrayScatterOp : public XlaOpKernel {
const xla::ComputationDataHandle value = ctx->Input(2);
const xla::ComputationDataHandle flow = ctx->Input(3);
- auto slice_dims = value_shape.dim_sizes();
- slice_dims[0] = 1LL;
-
- std::vector<int64> value_starts(value_shape.dims(), 0);
- auto value_ends = value_shape.dim_sizes();
-
- std::vector<int64> value_strides(value_shape.dims(), 1);
-
- // For every (index, value) pair, update the corresponding TensorArray
- // storage.
- for (int i = 0; i < num_indices; ++i) {
- // Slice out part of the value.
- value_starts[0] = i;
- value_ends[0] = i + 1;
- auto slice = b->Slice(value, value_starts, value_ends, value_strides);
+ // Look for the case where the scatter is for each sub-tensor in order. The
+ // tensor array implementation allows for this to be a straight addition.
+ bool scatter_all_elements_in_order = false;
+ std::vector<int64> const_indices;
+ Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
+ if (status.ok() && num_indices == value_shape.dim_size(0)) {
+ scatter_all_elements_in_order = true;
+ for (auto i = 0; i < num_indices; i++) {
+ if (const_indices[i] != i) {
+ scatter_all_elements_in_order = false;
+ break;
+ }
+ }
+ }
- // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto index = b->Slice(indices, {i}, {i + 1}, {1});
- auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
- ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
+ if (scatter_all_elements_in_order) {
+ ta = b->Add(ta, value);
+ } else {
+ auto slice_dims = value_shape.dim_sizes();
+ slice_dims[0] = 1LL;
+
+ std::vector<int64> value_starts(value_shape.dims(), 0);
+ auto value_ends = value_shape.dim_sizes();
+
+ std::vector<int64> value_strides(value_shape.dims(), 1);
+
+ // For every (index, value) pair, update the corresponding TensorArray
+ // storage.
+ for (int i = 0; i < num_indices; ++i) {
+ // Slice out part of the value.
+ value_starts[0] = i;
+ value_ends[0] = i + 1;
+ auto slice = b->Slice(value, value_starts, value_ends, value_strides);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto index = b->Slice(indices, {i}, {i + 1}, {1});
+ auto start_indices =
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
+ }
}
resource->value = ta;