diff options
Diffstat (limited to 'tensorflow/python/ops/linalg_ops.py')
-rw-r--r-- | tensorflow/python/ops/linalg_ops.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py new file mode 100644 index 0000000000..76fd83fb3d --- /dev/null +++ b/tensorflow/python/ops/linalg_ops.py @@ -0,0 +1,62 @@ +"""Operations for linear algebra.""" + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_linalg_ops +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_linalg_ops import * +# pylint: enable=wildcard-import + + +@ops.RegisterShape("Cholesky") +def _CholeskyShape(op): + input_shape = op.inputs[0].get_shape().with_rank(2) + # The matrix must be square. + input_shape[0].assert_is_compatible_with(input_shape[1]) + return [input_shape] + + +@ops.RegisterShape("BatchCholesky") +def _BatchCholeskyShape(op): + input_shape = op.inputs[0].get_shape().with_rank_at_least(3) + # The matrices in the batch must be square. + input_shape[-1].assert_is_compatible_with(input_shape[-2]) + return [input_shape] + + +@ops.RegisterShape("MatrixDeterminant") +def _MatrixDeterminantShape(op): + input_shape = op.inputs[0].get_shape().with_rank(2) + # The matrix must be square. + input_shape[0].assert_is_compatible_with(input_shape[1]) + if input_shape.ndims is not None: + return [tensor_shape.scalar()] + else: + return [tensor_shape.unknown_shape()] + + +@ops.RegisterShape("BatchMatrixDeterminant") +def _BatchMatrixDeterminantShape(op): + input_shape = op.inputs[0].get_shape().with_rank_at_least(3) + # The matrices in the batch must be square. + input_shape[-1].assert_is_compatible_with(input_shape[-2]) + if input_shape.ndims is not None: + return [input_shape[:-2]] + else: + return [tensor_shape.unknown_shape()] + + +@ops.RegisterShape("MatrixInverse") +def _MatrixInverseShape(op): + input_shape = op.inputs[0].get_shape().with_rank(2) + # The matrix must be square. + input_shape[0].assert_is_compatible_with(input_shape[1]) + return [input_shape] + + +@ops.RegisterShape("BatchMatrixInverse") +def _BatchMatrixInverseShape(op): + input_shape = op.inputs[0].get_shape().with_rank_at_least(3) + # The matrices in the batch must be square. + input_shape[-1].assert_is_compatible_with(input_shape[-2]) + return [input_shape] |