aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/diag_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/diag_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 74994d8961..ec5017f6ab 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -125,7 +125,7 @@ class DiagPartOp : public XlaOpKernel {
diag = builder->Reshape(diag, {new_size, new_size + 1});
// Slices out the first column and reshapes to the final shape.
- diag = builder->Slice(diag, {0, 0}, {new_size, 1});
+ diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
diag = builder->Reshape(diag, new_dims);
ctx->SetOutput(0, diag);
@@ -224,8 +224,9 @@ class MatrixDiagPartOp : public XlaOpKernel {
} else if (actual_size > target_size) {
std::vector<int64> start(flattened_dims.size(), 0);
std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
+ std::vector<int64> strides(flattened_dims.size(), 1);
limits[flattened_dims.size() - 1] = target_size;
- diag = builder->Slice(diag, start, limits);
+ diag = builder->Slice(diag, start, limits, strides);
}
// Reshape so the target values are in the first position of the last
@@ -238,8 +239,9 @@ class MatrixDiagPartOp : public XlaOpKernel {
// Slices out the first column and reshapes to the final shape.
std::vector<int64> start(dims.size(), 0);
std::vector<int64> limits(dims.begin(), dims.end());
+ std::vector<int64> strides(dims.size(), 1);
limits[last_dim] = 1;
- diag = builder->Slice(diag, start, limits);
+ diag = builder->Slice(diag, start, limits, strides);
// Collapses away the last dimension.
dims.pop_back();