diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/diag_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/diag_op.cc | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 74994d8961..ec5017f6ab 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -125,7 +125,7 @@ class DiagPartOp : public XlaOpKernel { diag = builder->Reshape(diag, {new_size, new_size + 1}); // Slices out the first column and reshapes to the final shape. - diag = builder->Slice(diag, {0, 0}, {new_size, 1}); + diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); diag = builder->Reshape(diag, new_dims); ctx->SetOutput(0, diag); @@ -224,8 +224,9 @@ class MatrixDiagPartOp : public XlaOpKernel { } else if (actual_size > target_size) { std::vector<int64> start(flattened_dims.size(), 0); std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end()); + std::vector<int64> strides(flattened_dims.size(), 1); limits[flattened_dims.size() - 1] = target_size; - diag = builder->Slice(diag, start, limits); + diag = builder->Slice(diag, start, limits, strides); } // Reshape so the target values are in the first position of the last @@ -238,8 +239,9 @@ class MatrixDiagPartOp : public XlaOpKernel { // Slices out the first column and reshapes to the final shape. std::vector<int64> start(dims.size(), 0); std::vector<int64> limits(dims.begin(), dims.end()); + std::vector<int64> strides(dims.size(), 1); limits[last_dim] = 1; - diag = builder->Slice(diag, start, limits); + diag = builder->Slice(diag, start, limits, strides); // Collapses away the last dimension. dims.pop_back(); |