aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/split_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/split_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/split_op.cc14
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 017f3a110e..44ee81461e 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -77,14 +77,14 @@ class SplitOp : public XlaOpKernel {
// The vectors we will use to define the slice. The entry for the
// split dimensions varies for each output.
- std::vector<int64> begin;
- std::vector<int64> limits;
+ std::vector<int64> begin(input_shape.dims(), 0);
+ std::vector<int64> limits(input_shape.dims());
+ std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < input_shape.dims(); ++i) {
// Initially set up the limits to be the full size of the input:
// the split dimension is filled in below.
int64 dim = input_shape.dim_size(i);
- begin.push_back(0);
- limits.push_back(dim);
+ limits[i] = dim;
}
auto input = ctx->Input(1);
@@ -94,7 +94,7 @@ class SplitOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
begin[split_dim] = i * slice_size;
limits[split_dim] = (i + 1) * slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
}
}
};
@@ -188,7 +188,7 @@ class SplitVOp : public XlaOpKernel {
std::vector<int64> begin(input_shape.dims(), 0);
auto dim_sizes = input_shape.dim_sizes();
std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
-
+ std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < num_split; ++i) {
TensorShape output_shape(input_shape);
int slice_size = split_sizes_vec[i];
@@ -196,7 +196,7 @@ class SplitVOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
limits[split_dim] = begin[split_dim] + slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
begin[split_dim] = limits[split_dim];
}
}