aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/unpack_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unpack_op.cc4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
index a5ce78e520..f87586ba57 100644
--- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
@@ -66,6 +66,7 @@ class UnpackOp : public XlaOpKernel {
std::vector<int64> start_indices(input_shape.dims(), 0);
std::vector<int64> limit_indices(input_shape.dims());
+ std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < input_shape.dims(); ++i) {
limit_indices[i] = input_shape.dim_size(i);
}
@@ -73,7 +74,8 @@ class UnpackOp : public XlaOpKernel {
for (int i = 0; i < num; ++i) {
start_indices[axis] = i;
limit_indices[axis] = i + 1;
- auto slice = ctx->builder()->Slice(input, start_indices, limit_indices);
+ auto slice = ctx->builder()->Slice(input, start_indices, limit_indices,
+ strides);
// Reshape to drop the 'axis' dimension.
auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes());
ctx->SetOutput(i, result);