aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-27 16:33:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-27 16:37:09 -0700
commit50b999a8336d19400ab75aea66fe46eca2f5fe0b (patch)
tree7cba4f4af6b131c253b65ff9f2923e851184668c /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parentd6d58a3a1785785679af56c0f8f131e7312b8226 (diff)
Merge changes from github.
PiperOrigin-RevId: 160344052
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 598b341002..9367c1ef22 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -318,7 +318,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
for (int i = 0; i < num_indices; ++i) {
// Slices the i-th index out of `indices`, and pads it with zeros in the
// minor dimensions to form an index into the TensorArray storage.
- auto index = b->Slice(indices, {i}, {i + 1});
+ auto index = b->Slice(indices, {i}, {i + 1}, {1});
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
@@ -381,16 +381,18 @@ class TensorArrayScatterOp : public XlaOpKernel {
std::vector<int64> value_starts(value_shape.dims(), 0);
auto value_ends = value_shape.dim_sizes();
+ std::vector<int64> value_strides(value_shape.dims(), 1);
+
// For every (index, value) pair, update the corresponding TensorArray
// storage.
for (int i = 0; i < num_indices; ++i) {
// Slice out part of the value.
value_starts[0] = i;
value_ends[0] = i + 1;
- auto slice = b->Slice(value, value_starts, value_ends);
+ auto slice = b->Slice(value, value_starts, value_ends, value_strides);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto index = b->Slice(indices, {i}, {i + 1});
+ auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}