aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index faa7ef0ef9..0330e34c98 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -156,6 +156,8 @@ class DynamicStitchOp : public XlaOpKernel {
indices0_shape.dims());
std::vector<int64> slice_limit(1 + data0_shape.dims() -
indices0_shape.dims());
+ std::vector<int64> stride(1 + data0_shape.dims() -
+ indices0_shape.dims(), 1);
for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
}
@@ -168,7 +170,7 @@ class DynamicStitchOp : public XlaOpKernel {
// And place it in the concat list in the place indicated by
// the index.
to_concat[index_num] =
- ctx->builder()->Slice(expression, slice_start, slice_limit);
+ ctx->builder()->Slice(expression, slice_start, slice_limit, stride);
}
ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0));