aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/linalg_ops.py
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-04-23 21:19:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 21:21:38 -0700
commit22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
treed16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/python/ops/linalg_ops.py
parent24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/python/ops/linalg_ops.py')
-rw-r--r--tensorflow/python/ops/linalg_ops.py77
1 files changed, 37 insertions, 40 deletions
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 170861b43f..a0dfa543f9 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -24,12 +24,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import
-from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -159,36 +160,11 @@ def eye(num_rows,
Returns:
A `Tensor` of shape `batch_shape + [num_rows, num_columns]`
"""
- with ops.name_scope(
- name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
- is_square = num_columns is None
- batch_shape = [] if batch_shape is None else batch_shape
- num_columns = num_rows if num_columns is None else num_columns
- if isinstance(num_rows, ops.Tensor) or isinstance(
- num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
- batch_shape = ops.convert_to_tensor(
- batch_shape, name='shape', dtype=dtypes.int32)
- diag_size = math_ops.minimum(num_rows, num_columns)
- diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
- if not is_square:
- shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
- else:
- if not isinstance(num_rows, compat.integral_types) or not isinstance(
- num_columns, compat.integral_types):
- raise TypeError(
- 'num_rows and num_columns must be positive integer values.')
- batch_shape = [dim for dim in batch_shape]
- is_square = num_rows == num_columns
- diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
- if not is_square:
- shape = batch_shape + [num_rows, num_columns]
-
- diag_ones = array_ops.ones(diag_shape, dtype=dtype)
- if is_square:
- return array_ops.matrix_diag(diag_ones)
- else:
- zero_matrix = array_ops.zeros(shape, dtype=dtype)
- return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+ return linalg_ops_impl.eye(num_rows,
+ num_columns=num_columns,
+ batch_shape=batch_shape,
+ dtype=dtype,
+ name=name)
@tf_export('matrix_solve_ls', 'linalg.lstsq')
@@ -454,7 +430,7 @@ def norm(tensor,
This function can compute several different vector norms (the 1-norm, the
Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
- matrix norms (Frobenius, 1-norm, and inf-norm).
+ matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).
Args:
tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
@@ -465,7 +441,7 @@ def norm(tensor,
Some restrictions apply:
a) The Frobenius norm `fro` is not defined for vectors,
b) If axis is a 2-tuple (matrix norm), only 'euclidean', 'fro', `1`,
- `np.inf` are supported.
+ `2`, `np.inf` are supported.
See the description of `axis` on how to compute norms for a batch of
vectors or matrices stored in a tensor.
axis: If `axis` is `None` (the default), the input is considered a vector
@@ -521,8 +497,7 @@ def norm(tensor,
axis[0] == axis[1]):
raise ValueError(
"'axis' must be None, an integer, or a tuple of 2 unique integers")
- # TODO(rmlarsen): Implement matrix 2-norm using tf.svd().
- supported_matrix_norms = ['euclidean', 'fro', 1, np.inf]
+ supported_matrix_norms = ['euclidean', 'fro', 1, 2, np.inf]
if ord not in supported_matrix_norms:
raise ValueError("'ord' must be a supported matrix norm in %s, got %s" %
(supported_matrix_norms, ord))
@@ -539,12 +514,34 @@ def norm(tensor,
with ops.name_scope(name, 'norm', [tensor]):
tensor = ops.convert_to_tensor(tensor)
+
if ord in ['fro', 'euclidean', 2, 2.0]:
- # TODO(rmlarsen): Move 2-norm to a separate clause once we support it for
- # matrices.
- result = math_ops.sqrt(
- math_ops.reduce_sum(
- tensor * math_ops.conj(tensor), axis, keepdims=True))
+ if is_matrix_norm and ord in [2, 2.0]:
+ rank = array_ops.rank(tensor)
+ positive_axis = functional_ops.map_fn(
+ lambda i: control_flow_ops.cond(i >= 0, lambda: i, lambda: i + rank),
+ ops.convert_to_tensor(axis))
+ axes = math_ops.range(rank)
+ perm_before = array_ops.concat(
+ [array_ops.setdiff1d(axes, positive_axis)[0], positive_axis],
+ axis=0)
+ perm_after = functional_ops.map_fn(
+ lambda i: math_ops.cast(
+ array_ops.squeeze(
+ array_ops.where(math_ops.equal(perm_before, i))),
+ dtype=dtypes.int32), axes)
+ permed = array_ops.transpose(tensor, perm=perm_before)
+ matrix_2_norm = array_ops.expand_dims(
+ math_ops.reduce_max(
+ math_ops.abs(gen_linalg_ops.svd(permed, compute_uv=False)[0]),
+ axis=-1,
+ keepdims=True),
+ axis=-1)
+ result = array_ops.transpose(matrix_2_norm, perm=perm_after)
+ else:
+ result = math_ops.sqrt(
+ math_ops.reduce_sum(
+ tensor * math_ops.conj(tensor), axis, keepdims=True))
else:
result = math_ops.abs(tensor)
if ord == 1: