aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/linalg_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/linalg_ops.py')
-rw-r--r--tensorflow/python/ops/linalg_ops.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
new file mode 100644
index 0000000000..76fd83fb3d
--- /dev/null
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -0,0 +1,62 @@
+"""Operations for linear algebra."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_linalg_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_linalg_ops import *
+# pylint: enable=wildcard-import
+
+
+@ops.RegisterShape("Cholesky")
+def _CholeskyShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchCholesky")
+def _BatchCholeskyShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ return [input_shape]
+
+
+@ops.RegisterShape("MatrixDeterminant")
+def _MatrixDeterminantShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ if input_shape.ndims is not None:
+ return [tensor_shape.scalar()]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("BatchMatrixDeterminant")
+def _BatchMatrixDeterminantShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ if input_shape.ndims is not None:
+ return [input_shape[:-2]]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("MatrixInverse")
+def _MatrixInverseShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchMatrixInverse")
+def _BatchMatrixInverseShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ return [input_shape]