diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/diag_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/diag_op.cc | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ec5017f6ab..74994d8961 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -125,7 +125,7 @@ class DiagPartOp : public XlaOpKernel { diag = builder->Reshape(diag, {new_size, new_size + 1}); // Slices out the first column and reshapes to the final shape. - diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); + diag = builder->Slice(diag, {0, 0}, {new_size, 1}); diag = builder->Reshape(diag, new_dims); ctx->SetOutput(0, diag); @@ -224,9 +224,8 @@ class MatrixDiagPartOp : public XlaOpKernel { } else if (actual_size > target_size) { std::vector<int64> start(flattened_dims.size(), 0); std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end()); - std::vector<int64> strides(flattened_dims.size(), 1); limits[flattened_dims.size() - 1] = target_size; - diag = builder->Slice(diag, start, limits, strides); + diag = builder->Slice(diag, start, limits); } // Reshape so the target values are in the first position of the last @@ -239,9 +238,8 @@ class MatrixDiagPartOp : public XlaOpKernel { // Slices out the first column and reshapes to the final shape. std::vector<int64> start(dims.size(), 0); std::vector<int64> limits(dims.begin(), dims.end()); - std::vector<int64> strides(dims.size(), 1); limits[last_dim] = 1; - diag = builder->Slice(diag, start, limits, strides); + diag = builder->Slice(diag, start, limits); // Collapses away the last dimension. dims.pop_back(); |