diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/unpack_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/unpack_op.cc | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index a5ce78e520..f87586ba57 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -66,6 +66,7 @@ class UnpackOp : public XlaOpKernel { std::vector<int64> start_indices(input_shape.dims(), 0); std::vector<int64> limit_indices(input_shape.dims()); + std::vector<int64> strides(input_shape.dims(), 1); for (int i = 0; i < input_shape.dims(); ++i) { limit_indices[i] = input_shape.dim_size(i); } @@ -73,7 +74,8 @@ class UnpackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { start_indices[axis] = i; limit_indices[axis] = i + 1; - auto slice = ctx->builder()->Slice(input, start_indices, limit_indices); + auto slice = ctx->builder()->Slice(input, start_indices, limit_indices, + strides); // Reshape to drop the 'axis' dimension. auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes()); ctx->SetOutput(i, result); |