diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index faa7ef0ef9..0330e34c98 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -156,6 +156,8 @@ class DynamicStitchOp : public XlaOpKernel { indices0_shape.dims()); std::vector<int64> slice_limit(1 + data0_shape.dims() - indices0_shape.dims()); + std::vector<int64> stride(1 + data0_shape.dims() - + indices0_shape.dims(), 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } @@ -168,7 +170,7 @@ class DynamicStitchOp : public XlaOpKernel { // And place it in the concat list in the place indicated by // the index. to_concat[index_num] = - ctx->builder()->Slice(expression, slice_start, slice_limit); + ctx->builder()->Slice(expression, slice_start, slice_limit, stride); } ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0)); |