aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2018-02-02 13:08:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-02 13:15:17 -0800
commitf85f23bea87ea856ac839806b5d62f4257fb684b (patch)
tree7a451d9974675cb41fc181117803dc12756a8fc2
parentaf9afcfe44ed97dc13422445b1f1c91eaa98d583 (diff)
Support `float16` `dtype` in `tf.linalg.*`.
Note: not all `LinearOperator` functions will support `float16`. This change merely enables constructing the `LinearOperator` object(s) using this `dtype`. PiperOrigin-RevId: 184323477
-rw-r--r--tensorflow/python/ops/linalg/linalg_impl.py8
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_diag.py11
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_full_matrix.py10
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_low_rank_update.py10
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_lower_triangular.py9
5 files changed, 34 insertions, 14 deletions
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index db33a08137..a5096ffdd9 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -65,8 +65,8 @@ def logdet(matrix, name=None):
```
Args:
- matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or
- `complex128` with shape `[..., M, M]`.
+ matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
+ or `complex128` with shape `[..., M, M]`.
name: A name to give this `Op`. Defaults to `logdet`.
Returns:
@@ -99,8 +99,8 @@ def adjoint(matrix, name=None):
# [3 - 3j, 6 - 6j]]
Args:
- matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or
- `complex128` with shape `[..., M, M]`.
+ matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
+ or `complex128` with shape `[..., M, M]`.
name: A name to give this `Op` (optional).
Returns:
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py
index a4724d030f..2217bfd545 100644
--- a/tensorflow/python/ops/linalg/linear_operator_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_diag.py
@@ -121,8 +121,8 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
Args:
diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
- The diagonal of the operator. Allowed dtypes: `float32`, `float64`,
- `complex64`, `complex128`.
+ The diagonal of the operator. Allowed dtypes: `float16`, `float32`,
+ `float64`, `complex64`, `complex128`.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. If `diag.dtype` is real, this is auto-set to `True`.
@@ -167,7 +167,12 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
def _check_diag(self, diag):
"""Static check of diag."""
allowed_dtypes = [
- dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
+ dtypes.float16,
+ dtypes.float32,
+ dtypes.float64,
+ dtypes.complex64,
+ dtypes.complex128,
+ ]
dtype = diag.dtype
if dtype not in allowed_dtypes:
diff --git a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
index dd4c7cb041..8fb59ca1a7 100644
--- a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
+++ b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
@@ -114,7 +114,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
Args:
matrix: Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
- Allowed dtypes: `float32`, `float64`, `complex64`, `complex128`.
+ Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
+ `complex128`.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
@@ -147,7 +148,12 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
def _check_matrix(self, matrix):
"""Static check of the `matrix` argument."""
allowed_dtypes = [
- dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
+ dtypes.float16,
+ dtypes.float32,
+ dtypes.float64,
+ dtypes.complex64,
+ dtypes.complex128,
+ ]
matrix = ops.convert_to_tensor(matrix, name="matrix")
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index ad3bb2efa9..36eed89db6 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -150,8 +150,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
`is_X` matrix property hints, which will trigger the appropriate code path.
Args:
- base_operator: Shape `[B1,...,Bb, M, N]` real `float32` or `float64`
- `LinearOperator`. This is `L` above.
+ base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or
+ `float64` `LinearOperator`. This is `L` above.
u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
This is `U` above.
diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
@@ -188,7 +188,11 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
# because if diag has non-zero imaginary part, it will not be
# self-adjoint positive definite.
dtype = base_operator.dtype
- allowed_dtypes = [dtypes.float32, dtypes.float64]
+ allowed_dtypes = [
+ dtypes.float16,
+ dtypes.float32,
+ dtypes.float64,
+ ]
if dtype not in allowed_dtypes:
raise TypeError(
"Argument matrix must have dtype in %s. Found: %s"
diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
index 6ea55f0367..6419030755 100644
--- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
+++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
@@ -118,7 +118,8 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
Args:
tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
The lower triangular part of `tril` defines this operator. The strictly
- upper triangle is ignored. Allowed dtypes: `float32`, `float64`.
+ upper triangle is ignored. Allowed dtypes: `float16`, `float32`,
+ `float64`.
is_non_singular: Expect that this operator is non-singular.
This operator is non-singular if and only if its diagonal elements are
all non-zero.
@@ -164,7 +165,11 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
"""Static check of the `tril` argument."""
# TODO(langmore) Add complex types once matrix_triangular_solve works for
# them.
- allowed_dtypes = [dtypes.float32, dtypes.float64]
+ allowed_dtypes = [
+ dtypes.float16,
+ dtypes.float32,
+ dtypes.float64,
+ ]
dtype = tril.dtype
if dtype not in allowed_dtypes:
raise TypeError(