diff options
Diffstat (limited to 'tensorflow/contrib/linalg/python/ops/linear_operator_addition.py')
-rw-r--r-- | tensorflow/contrib/linalg/python/ops/linear_operator_addition.py | 432 |
1 files changed, 0 insertions, 432 deletions
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py deleted file mode 100644 index 86130a2c07..0000000000 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py +++ /dev/null @@ -1,432 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Add one or more `LinearOperators` efficiently.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc - -import six - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops.linalg import linear_operator -from tensorflow.python.ops.linalg import linear_operator_diag -from tensorflow.python.ops.linalg import linear_operator_full_matrix -from tensorflow.python.ops.linalg import linear_operator_identity -from tensorflow.python.ops.linalg import linear_operator_lower_triangular - -__all__ = [] - - -def add_operators(operators, - operator_name=None, - addition_tiers=None, - name=None): - """Efficiently add one or more linear operators. - - Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of - operators `[B1, B2,...]` such that - - ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).``` - - The operators `Bk` result by adding some of the `Ak`, as allowed by - `addition_tiers`. - - Example of efficient adding of diagonal operators. - - ```python - A1 = LinearOperatorDiag(diag=[1., 1.], name="A1") - A2 = LinearOperatorDiag(diag=[2., 2.], name="A2") - - # Use two tiers, the first contains an Adder that returns Diag. Since both - # A1 and A2 are Diag, they can use this Adder. The second tier will not be - # used. - addition_tiers = [ - [_AddAndReturnDiag()], - [_AddAndReturnMatrix()]] - B_list = add_operators([A1, A2], addition_tiers=addition_tiers) - - len(B_list) - ==> 1 - - B_list[0].__class__.__name__ - ==> 'LinearOperatorDiag' - - B_list[0].to_dense() - ==> [[3., 0.], - [0., 3.]] - - B_list[0].name - ==> 'Add/A1__A2/' - ``` - - Args: - operators: Iterable of `LinearOperator` objects with same `dtype`, domain - and range dimensions, and broadcastable batch shapes. - operator_name: String name for returned `LinearOperator`. Defaults to - concatenation of "Add/A__B/" that indicates the order of addition steps. - addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i` - is a list of `Adder` objects. This function attempts to do all additions - in tier `i` before trying tier `i + 1`. - name: A name for this `Op`. Defaults to `add_operators`. - - Returns: - Subclass of `LinearOperator`. Class and order of addition may change as new - (and better) addition strategies emerge. - - Raises: - ValueError: If `operators` argument is empty. - ValueError: If shapes are incompatible. - """ - # Default setting - if addition_tiers is None: - addition_tiers = _DEFAULT_ADDITION_TIERS - - # Argument checking. - check_ops.assert_proper_iterable(operators) - operators = list(reversed(operators)) - if len(operators) < 1: - raise ValueError( - "Argument 'operators' must contain at least one operator. " - "Found: %s" % operators) - if not all( - isinstance(op, linear_operator.LinearOperator) for op in operators): - raise TypeError( - "Argument 'operators' must contain only LinearOperator instances. " - "Found: %s" % operators) - _static_check_for_same_dimensions(operators) - _static_check_for_broadcastable_batch_shape(operators) - - graph_parents = [] - for operator in operators: - graph_parents.extend(operator.graph_parents) - - with ops.name_scope(name or "add_operators", values=graph_parents): - - # Additions done in one of the tiers. Try tier 0, 1,... - ops_to_try_at_next_tier = list(operators) - for tier in addition_tiers: - ops_to_try_at_this_tier = ops_to_try_at_next_tier - ops_to_try_at_next_tier = [] - while ops_to_try_at_this_tier: - op1 = ops_to_try_at_this_tier.pop() - op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier) - if op2 is not None: - # Will try to add the result of this again at this same tier. - new_operator = adder.add(op1, op2, operator_name) - ops_to_try_at_this_tier.append(new_operator) - else: - ops_to_try_at_next_tier.append(op1) - - return ops_to_try_at_next_tier - - -def _pop_a_match_at_tier(op1, operator_list, tier): - # Search from the back of list to the front in order to create nice default - # order of operations. - for i in range(1, len(operator_list) + 1): - op2 = operator_list[-i] - for adder in tier: - if adder.can_add(op1, op2): - return operator_list.pop(-i), adder - return None, None - - -def _infer_hints_allowing_override(op1, op2, hints): - """Infer hints from op1 and op2. hints argument is an override. - - Args: - op1: LinearOperator - op2: LinearOperator - hints: _Hints object holding "is_X" boolean hints to use for returned - operator. - If some hint is None, try to set using op1 and op2. If the - hint is provided, ignore op1 and op2 hints. This allows an override - of previous hints, but does not allow forbidden hints (e.g. you still - cannot say a real diagonal operator is not self-adjoint. - - Returns: - _Hints object. - """ - hints = hints or _Hints() - # If A, B are self-adjoint, then so is A + B. - if hints.is_self_adjoint is None: - is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint - else: - is_self_adjoint = hints.is_self_adjoint - - # If A, B are positive definite, then so is A + B. - if hints.is_positive_definite is None: - is_positive_definite = op1.is_positive_definite and op2.is_positive_definite - else: - is_positive_definite = hints.is_positive_definite - - # A positive definite operator is always non-singular. - if is_positive_definite and hints.is_positive_definite is None: - is_non_singular = True - else: - is_non_singular = hints.is_non_singular - - return _Hints( - is_non_singular=is_non_singular, - is_self_adjoint=is_self_adjoint, - is_positive_definite=is_positive_definite) - - -def _static_check_for_same_dimensions(operators): - """ValueError if operators determined to have different dimensions.""" - if len(operators) < 2: - return - - domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators - if op.domain_dimension.value is not None] - if len(set(value for name, value in domain_dimensions)) > 1: - raise ValueError("Operators must have the same domain dimension. Found: %s" - % domain_dimensions) - - range_dimensions = [(op.name, op.range_dimension.value) for op in operators - if op.range_dimension.value is not None] - if len(set(value for name, value in range_dimensions)) > 1: - raise ValueError("Operators must have the same range dimension. Found: %s" % - range_dimensions) - - -def _static_check_for_broadcastable_batch_shape(operators): - """ValueError if operators determined to have non-broadcastable shapes.""" - if len(operators) < 2: - return - - # This will fail if they cannot be broadcast together. - batch_shape = operators[0].batch_shape - for op in operators[1:]: - batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape) - - -class _Hints(object): - """Holds 'is_X' flags that every LinearOperator is initialized with.""" - - def __init__(self, - is_non_singular=None, - is_positive_definite=None, - is_self_adjoint=None): - self.is_non_singular = is_non_singular - self.is_positive_definite = is_positive_definite - self.is_self_adjoint = is_self_adjoint - - -################################################################################ -# Classes to add two linear operators. -################################################################################ - - -@six.add_metaclass(abc.ABCMeta) -class _Adder(object): - """Abstract base class to add two operators. - - Each `Adder` acts independently, adding everything it can, paying no attention - as to whether another `Adder` could have done the addition more efficiently. - """ - - @property - def name(self): - return self.__class__.__name__ - - @abc.abstractmethod - def can_add(self, op1, op2): - """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`.""" - pass - - @abc.abstractmethod - def _add(self, op1, op2, operator_name, hints): - # Derived classes can assume op1 and op2 have been validated, e.g. they have - # the same dtype, and their domain/range dimensions match. - pass - - def add(self, op1, op2, operator_name, hints=None): - """Return new `LinearOperator` acting like `op1 + op2`. - - Args: - op1: `LinearOperator` - op2: `LinearOperator`, with `shape` and `dtype` such that adding to - `op1` is allowed. - operator_name: `String` name to give to returned `LinearOperator` - hints: `_Hints` object. Returned `LinearOperator` will be created with - these hints. - - Returns: - `LinearOperator` - """ - updated_hints = _infer_hints_allowing_override(op1, op2, hints) - - if operator_name is None: - operator_name = "Add/" + op1.name + "__" + op2.name + "/" - - values = op1.graph_parents + op2.graph_parents - scope_name = self.name - if scope_name.startswith("_"): - scope_name = scope_name[1:] - with ops.name_scope(scope_name, values=values): - return self._add(op1, op2, operator_name, updated_hints) - - -class _AddAndReturnScaledIdentity(_Adder): - """Handles additions resulting in an Identity family member. - - The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family - is closed under addition. This `Adder` respects that, and returns an Identity - """ - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_IDENTITY_FAMILY) - - def _add(self, op1, op2, operator_name, hints): - # Will build a LinearOperatorScaledIdentity. - - if _type(op1) == _SCALED_IDENTITY: - multiplier_1 = op1.multiplier - else: - multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype) - - if _type(op2) == _SCALED_IDENTITY: - multiplier_2 = op2.multiplier - else: - multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype) - - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=op1.range_dimension_tensor(), - multiplier=multiplier_1 + multiplier_2, - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnDiag(_Adder): - """Handles additions resulting in a Diag operator.""" - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_DIAG_LIKE) - - def _add(self, op1, op2, operator_name, hints): - return linear_operator_diag.LinearOperatorDiag( - diag=op1.diag_part() + op2.diag_part(), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnTriL(_Adder): - """Handles additions resulting in a TriL operator.""" - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_DIAG_LIKE.union({_TRIL})) - - def _add(self, op1, op2, operator_name, hints): - if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: - op_add_to_tensor, op_other = op1, op2 - else: - op_add_to_tensor, op_other = op2, op1 - - return linear_operator_lower_triangular.LinearOperatorLowerTriangular( - tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnMatrix(_Adder): - """"Handles additions resulting in a `LinearOperatorFullMatrix`.""" - - def can_add(self, op1, op2): # pylint: disable=unused-argument - return isinstance(op1, linear_operator.LinearOperator) and isinstance( - op2, linear_operator.LinearOperator) - - def _add(self, op1, op2, operator_name, hints): - if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: - op_add_to_tensor, op_other = op1, op2 - else: - op_add_to_tensor, op_other = op2, op1 - return linear_operator_full_matrix.LinearOperatorFullMatrix( - matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -################################################################################ -# Constants designating types of LinearOperators -################################################################################ - -# Type name constants for LinearOperator classes. -_IDENTITY = "identity" -_SCALED_IDENTITY = "scaled_identity" -_DIAG = "diag" -_TRIL = "tril" -_MATRIX = "matrix" - -# Groups of operators. -_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY} -_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY} -# operators with an efficient .add_to_tensor() method. -_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE - - -def _type(operator): - """Returns the type name constant (e.g. _TRIL) for operator.""" - if isinstance(operator, linear_operator_diag.LinearOperatorDiag): - return _DIAG - if isinstance(operator, - linear_operator_lower_triangular.LinearOperatorLowerTriangular): - return _TRIL - if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): - return _MATRIX - if isinstance(operator, linear_operator_identity.LinearOperatorIdentity): - return _IDENTITY - if isinstance(operator, - linear_operator_identity.LinearOperatorScaledIdentity): - return _SCALED_IDENTITY - raise TypeError("Operator type unknown: %s" % operator) - - -################################################################################ -# Addition tiers: -# We attempt to use Adders in tier K before K+1. -# -# Organize tiers to -# (i) reduce O(..) complexity of forming final operator, and -# (ii) produce the "most efficient" final operator. -# Dev notes: -# * Results of addition at tier K will be added at tier K or higher. -# * Tiers may change, and we warn the user that it may change. -################################################################################ - -# Note that the final tier, _AddAndReturnMatrix, will convert everything to a -# dense matrix. So it is sometimes very inefficient. -_DEFAULT_ADDITION_TIERS = [ - [_AddAndReturnScaledIdentity()], - [_AddAndReturnDiag()], - [_AddAndReturnTriL()], - [_AddAndReturnMatrix()], -] |