diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/split_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/split_op.cc | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 017f3a110e..44ee81461e 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -77,14 +77,14 @@ class SplitOp : public XlaOpKernel { // The vectors we will use to define the slice. The entry for the // split dimensions varies for each output. - std::vector<int64> begin; - std::vector<int64> limits; + std::vector<int64> begin(input_shape.dims(), 0); + std::vector<int64> limits(input_shape.dims()); + std::vector<int64> strides(input_shape.dims(), 1); for (int i = 0; i < input_shape.dims(); ++i) { // Initially set up the limits to be the full size of the input: // the split dimension is filled in below. int64 dim = input_shape.dim_size(i); - begin.push_back(0); - limits.push_back(dim); + limits[i] = dim; } auto input = ctx->Input(1); @@ -94,7 +94,7 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension. begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); } } }; @@ -188,7 +188,7 @@ class SplitVOp : public XlaOpKernel { std::vector<int64> begin(input_shape.dims(), 0); auto dim_sizes = input_shape.dim_sizes(); std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end()); - + std::vector<int64> strides(input_shape.dims(), 1); for (int i = 0; i < num_split; ++i) { TensorShape output_shape(input_shape); int slice_size = split_sizes_vec[i]; @@ -196,7 +196,7 @@ class SplitVOp : public XlaOpKernel { // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); begin[split_dim] = limits[split_dim]; } } |