aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/numeric.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/numeric.h')
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.h14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h
index 79707007b2..212f658313 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.h
+++ b/tensorflow/compiler/xla/client/lib/numeric.h
@@ -29,6 +29,20 @@ XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
// else.
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
+// Get the diagonals of the last two dimensions. If 'x' has shape
+// [..., M, N], then the output has shape [..., min(M, N)], containing the
+// diagonal elements (i.e., with indices [..., i, i]).
+XlaOp GetMatrixDiagonal(XlaOp x);
+
+// Get the upper or lower triangle part of the last two dimensions
+XlaOp Triangle(XlaOp x, bool lower);
+
+// Get the upper triangle part of the last two dimensions
+XlaOp UpperTriangle(XlaOp x);
+
+// Get the lower triangle part of the last two dimensions
+XlaOp LowerTriangle(XlaOp x);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_