diff options
author | Dandelion Man? <dandelion@google.com> | 2017-12-15 17:12:41 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-15 17:16:29 -0800 |
commit | d55f532867a3670d66460c5ee3b774519542adc1 (patch) | |
tree | 7de4d85bcd61e93401459276b4d371ab0be23c1f /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | 32d5048ae96116202f2aa0fa739ef37514ee8a54 (diff) |
Merge changes from github.
PiperOrigin-RevId: 179258973
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 87 |
1 files changed, 66 insertions, 21 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 351fda2517..03c22354a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -311,6 +311,32 @@ class TensorArrayGatherOp : public XlaOpKernel { xla::ComputationDataHandle ta = resource->value; + // Look for the case where the gather takes a simple slice from the + // tensor array (0, 1, 2, 3, 4, ..., N) + std::vector<int64> const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok()) { + bool gather_is_dense_slice = true; + for (auto i = 0; i < const_indices.size(); i++) { + if (const_indices[i] != i) { + gather_is_dense_slice = false; + break; + } + } + + if (gather_is_dense_slice) { + std::vector<int64> begin(ta_shape.dims(), 0); + std::vector<int64> strides(ta_shape.dims(), 1); + std::vector<int64> end(ta_shape.dims(), 1); + end[0] = const_indices.size(); + for (auto i = 1; i < ta_shape.dims(); i++) { + end[i] = ta_shape.dim_size(i); + } + ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + return; + } + } + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); ctx->SetOutput(0, gather); @@ -352,28 +378,47 @@ class TensorArrayScatterOp : public XlaOpKernel { const xla::ComputationDataHandle value = ctx->Input(2); const xla::ComputationDataHandle flow = ctx->Input(3); - auto slice_dims = value_shape.dim_sizes(); - slice_dims[0] = 1LL; - - std::vector<int64> value_starts(value_shape.dims(), 0); - auto value_ends = value_shape.dim_sizes(); - - std::vector<int64> value_strides(value_shape.dims(), 1); - - // For every (index, value) pair, update the corresponding TensorArray - // storage. - for (int i = 0; i < num_indices; ++i) { - // Slice out part of the value. - value_starts[0] = i; - value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + // Look for the case where the scatter is for each sub-tensor in order. The + // tensor array implementation allows for this to be a straight addition. + bool scatter_all_elements_in_order = false; + std::vector<int64> const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok() && num_indices == value_shape.dim_size(0)) { + scatter_all_elements_in_order = true; + for (auto i = 0; i < num_indices; i++) { + if (const_indices[i] != i) { + scatter_all_elements_in_order = false; + break; + } + } + } - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + if (scatter_all_elements_in_order) { + ta = b->Add(ta, value); + } else { + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector<int64> value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + std::vector<int64> value_strides(value_shape.dims(), 1); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends, value_strides); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto start_indices = + b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } } resource->value = ta; |