diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/numeric.h')
-rw-r--r-- | tensorflow/compiler/xla/client/lib/numeric.h | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h index 79707007b2..212f658313 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -29,6 +29,20 @@ XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); +// Get the diagonals of the last two dimensions. If 'x' has shape +// [..., M, N], then the output has shape [..., min(M, N)], containing the +// diagonal elements (i.e., with indices [..., i, i]). +XlaOp GetMatrixDiagonal(XlaOp x); + +// Get the upper or lower triangle part of the last two dimensions +XlaOp Triangle(XlaOp x, bool lower); + +// Get the upper triangle part of the last two dimensions +XlaOp UpperTriangle(XlaOp x); + +// Get the lower triangle part of the last two dimensions +XlaOp LowerTriangle(XlaOp x); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ |