diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-27 16:33:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-27 16:37:09 -0700 |
commit | 50b999a8336d19400ab75aea66fe46eca2f5fe0b (patch) | |
tree | 7cba4f4af6b131c253b65ff9f2923e851184668c /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | d6d58a3a1785785679af56c0f8f131e7312b8226 (diff) |
Merge changes from github.
PiperOrigin-RevId: 160344052
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 598b341002..9367c1ef22 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -318,7 +318,7 @@ class TensorArrayGatherOp : public XlaOpKernel { for (int i = 0; i < num_indices; ++i) { // Slices the i-th index out of `indices`, and pads it with zeros in the // minor dimensions to form an index into the TensorArray storage. - auto index = b->Slice(indices, {i}, {i + 1}); + auto index = b->Slice(indices, {i}, {i + 1}, {1}); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); @@ -381,16 +381,18 @@ class TensorArrayScatterOp : public XlaOpKernel { std::vector<int64> value_starts(value_shape.dims(), 0); auto value_ends = value_shape.dim_sizes(); + std::vector<int64> value_strides(value_shape.dims(), 1); + // For every (index, value) pair, update the corresponding TensorArray // storage. for (int i = 0; i < num_indices; ++i) { // Slice out part of the value. value_starts[0] = i; value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends); + auto slice = b->Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}); + auto index = b->Slice(indices, {i}, {i + 1}, {1}); auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } |