diff options
Diffstat (limited to 'tensorflow/python/ops/linalg/linear_operator_addition.py')
-rw-r--r-- | tensorflow/python/ops/linalg/linear_operator_addition.py | 432 |
1 files changed, 432 insertions, 0 deletions
diff --git a/tensorflow/python/ops/linalg/linear_operator_addition.py b/tensorflow/python/ops/linalg/linear_operator_addition.py new file mode 100644 index 0000000000..86130a2c07 --- /dev/null +++ b/tensorflow/python/ops/linalg/linear_operator_addition.py @@ -0,0 +1,432 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Add one or more `LinearOperators` efficiently.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_diag +from tensorflow.python.ops.linalg import linear_operator_full_matrix +from tensorflow.python.ops.linalg import linear_operator_identity +from tensorflow.python.ops.linalg import linear_operator_lower_triangular + +__all__ = [] + + +def add_operators(operators, + operator_name=None, + addition_tiers=None, + name=None): + """Efficiently add one or more linear operators. + + Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of + operators `[B1, B2,...]` such that + + ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).``` + + The operators `Bk` result by adding some of the `Ak`, as allowed by + `addition_tiers`. + + Example of efficient adding of diagonal operators. + + ```python + A1 = LinearOperatorDiag(diag=[1., 1.], name="A1") + A2 = LinearOperatorDiag(diag=[2., 2.], name="A2") + + # Use two tiers, the first contains an Adder that returns Diag. Since both + # A1 and A2 are Diag, they can use this Adder. The second tier will not be + # used. + addition_tiers = [ + [_AddAndReturnDiag()], + [_AddAndReturnMatrix()]] + B_list = add_operators([A1, A2], addition_tiers=addition_tiers) + + len(B_list) + ==> 1 + + B_list[0].__class__.__name__ + ==> 'LinearOperatorDiag' + + B_list[0].to_dense() + ==> [[3., 0.], + [0., 3.]] + + B_list[0].name + ==> 'Add/A1__A2/' + ``` + + Args: + operators: Iterable of `LinearOperator` objects with same `dtype`, domain + and range dimensions, and broadcastable batch shapes. + operator_name: String name for returned `LinearOperator`. Defaults to + concatenation of "Add/A__B/" that indicates the order of addition steps. + addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i` + is a list of `Adder` objects. This function attempts to do all additions + in tier `i` before trying tier `i + 1`. + name: A name for this `Op`. Defaults to `add_operators`. + + Returns: + Subclass of `LinearOperator`. Class and order of addition may change as new + (and better) addition strategies emerge. + + Raises: + ValueError: If `operators` argument is empty. + ValueError: If shapes are incompatible. + """ + # Default setting + if addition_tiers is None: + addition_tiers = _DEFAULT_ADDITION_TIERS + + # Argument checking. + check_ops.assert_proper_iterable(operators) + operators = list(reversed(operators)) + if len(operators) < 1: + raise ValueError( + "Argument 'operators' must contain at least one operator. " + "Found: %s" % operators) + if not all( + isinstance(op, linear_operator.LinearOperator) for op in operators): + raise TypeError( + "Argument 'operators' must contain only LinearOperator instances. " + "Found: %s" % operators) + _static_check_for_same_dimensions(operators) + _static_check_for_broadcastable_batch_shape(operators) + + graph_parents = [] + for operator in operators: + graph_parents.extend(operator.graph_parents) + + with ops.name_scope(name or "add_operators", values=graph_parents): + + # Additions done in one of the tiers. Try tier 0, 1,... + ops_to_try_at_next_tier = list(operators) + for tier in addition_tiers: + ops_to_try_at_this_tier = ops_to_try_at_next_tier + ops_to_try_at_next_tier = [] + while ops_to_try_at_this_tier: + op1 = ops_to_try_at_this_tier.pop() + op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier) + if op2 is not None: + # Will try to add the result of this again at this same tier. + new_operator = adder.add(op1, op2, operator_name) + ops_to_try_at_this_tier.append(new_operator) + else: + ops_to_try_at_next_tier.append(op1) + + return ops_to_try_at_next_tier + + +def _pop_a_match_at_tier(op1, operator_list, tier): + # Search from the back of list to the front in order to create nice default + # order of operations. + for i in range(1, len(operator_list) + 1): + op2 = operator_list[-i] + for adder in tier: + if adder.can_add(op1, op2): + return operator_list.pop(-i), adder + return None, None + + +def _infer_hints_allowing_override(op1, op2, hints): + """Infer hints from op1 and op2. hints argument is an override. + + Args: + op1: LinearOperator + op2: LinearOperator + hints: _Hints object holding "is_X" boolean hints to use for returned + operator. + If some hint is None, try to set using op1 and op2. If the + hint is provided, ignore op1 and op2 hints. This allows an override + of previous hints, but does not allow forbidden hints (e.g. you still + cannot say a real diagonal operator is not self-adjoint. + + Returns: + _Hints object. + """ + hints = hints or _Hints() + # If A, B are self-adjoint, then so is A + B. + if hints.is_self_adjoint is None: + is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint + else: + is_self_adjoint = hints.is_self_adjoint + + # If A, B are positive definite, then so is A + B. + if hints.is_positive_definite is None: + is_positive_definite = op1.is_positive_definite and op2.is_positive_definite + else: + is_positive_definite = hints.is_positive_definite + + # A positive definite operator is always non-singular. + if is_positive_definite and hints.is_positive_definite is None: + is_non_singular = True + else: + is_non_singular = hints.is_non_singular + + return _Hints( + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite) + + +def _static_check_for_same_dimensions(operators): + """ValueError if operators determined to have different dimensions.""" + if len(operators) < 2: + return + + domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators + if op.domain_dimension.value is not None] + if len(set(value for name, value in domain_dimensions)) > 1: + raise ValueError("Operators must have the same domain dimension. Found: %s" + % domain_dimensions) + + range_dimensions = [(op.name, op.range_dimension.value) for op in operators + if op.range_dimension.value is not None] + if len(set(value for name, value in range_dimensions)) > 1: + raise ValueError("Operators must have the same range dimension. Found: %s" % + range_dimensions) + + +def _static_check_for_broadcastable_batch_shape(operators): + """ValueError if operators determined to have non-broadcastable shapes.""" + if len(operators) < 2: + return + + # This will fail if they cannot be broadcast together. + batch_shape = operators[0].batch_shape + for op in operators[1:]: + batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape) + + +class _Hints(object): + """Holds 'is_X' flags that every LinearOperator is initialized with.""" + + def __init__(self, + is_non_singular=None, + is_positive_definite=None, + is_self_adjoint=None): + self.is_non_singular = is_non_singular + self.is_positive_definite = is_positive_definite + self.is_self_adjoint = is_self_adjoint + + +################################################################################ +# Classes to add two linear operators. +################################################################################ + + +@six.add_metaclass(abc.ABCMeta) +class _Adder(object): + """Abstract base class to add two operators. + + Each `Adder` acts independently, adding everything it can, paying no attention + as to whether another `Adder` could have done the addition more efficiently. + """ + + @property + def name(self): + return self.__class__.__name__ + + @abc.abstractmethod + def can_add(self, op1, op2): + """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`.""" + pass + + @abc.abstractmethod + def _add(self, op1, op2, operator_name, hints): + # Derived classes can assume op1 and op2 have been validated, e.g. they have + # the same dtype, and their domain/range dimensions match. + pass + + def add(self, op1, op2, operator_name, hints=None): + """Return new `LinearOperator` acting like `op1 + op2`. + + Args: + op1: `LinearOperator` + op2: `LinearOperator`, with `shape` and `dtype` such that adding to + `op1` is allowed. + operator_name: `String` name to give to returned `LinearOperator` + hints: `_Hints` object. Returned `LinearOperator` will be created with + these hints. + + Returns: + `LinearOperator` + """ + updated_hints = _infer_hints_allowing_override(op1, op2, hints) + + if operator_name is None: + operator_name = "Add/" + op1.name + "__" + op2.name + "/" + + values = op1.graph_parents + op2.graph_parents + scope_name = self.name + if scope_name.startswith("_"): + scope_name = scope_name[1:] + with ops.name_scope(scope_name, values=values): + return self._add(op1, op2, operator_name, updated_hints) + + +class _AddAndReturnScaledIdentity(_Adder): + """Handles additions resulting in an Identity family member. + + The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family + is closed under addition. This `Adder` respects that, and returns an Identity + """ + + def can_add(self, op1, op2): + types = {_type(op1), _type(op2)} + return not types.difference(_IDENTITY_FAMILY) + + def _add(self, op1, op2, operator_name, hints): + # Will build a LinearOperatorScaledIdentity. + + if _type(op1) == _SCALED_IDENTITY: + multiplier_1 = op1.multiplier + else: + multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype) + + if _type(op2) == _SCALED_IDENTITY: + multiplier_2 = op2.multiplier + else: + multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype) + + return linear_operator_identity.LinearOperatorScaledIdentity( + num_rows=op1.range_dimension_tensor(), + multiplier=multiplier_1 + multiplier_2, + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +class _AddAndReturnDiag(_Adder): + """Handles additions resulting in a Diag operator.""" + + def can_add(self, op1, op2): + types = {_type(op1), _type(op2)} + return not types.difference(_DIAG_LIKE) + + def _add(self, op1, op2, operator_name, hints): + return linear_operator_diag.LinearOperatorDiag( + diag=op1.diag_part() + op2.diag_part(), + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +class _AddAndReturnTriL(_Adder): + """Handles additions resulting in a TriL operator.""" + + def can_add(self, op1, op2): + types = {_type(op1), _type(op2)} + return not types.difference(_DIAG_LIKE.union({_TRIL})) + + def _add(self, op1, op2, operator_name, hints): + if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: + op_add_to_tensor, op_other = op1, op2 + else: + op_add_to_tensor, op_other = op2, op1 + + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +class _AddAndReturnMatrix(_Adder): + """"Handles additions resulting in a `LinearOperatorFullMatrix`.""" + + def can_add(self, op1, op2): # pylint: disable=unused-argument + return isinstance(op1, linear_operator.LinearOperator) and isinstance( + op2, linear_operator.LinearOperator) + + def _add(self, op1, op2, operator_name, hints): + if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: + op_add_to_tensor, op_other = op1, op2 + else: + op_add_to_tensor, op_other = op2, op1 + return linear_operator_full_matrix.LinearOperatorFullMatrix( + matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()), + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +################################################################################ +# Constants designating types of LinearOperators +################################################################################ + +# Type name constants for LinearOperator classes. +_IDENTITY = "identity" +_SCALED_IDENTITY = "scaled_identity" +_DIAG = "diag" +_TRIL = "tril" +_MATRIX = "matrix" + +# Groups of operators. +_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY} +_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY} +# operators with an efficient .add_to_tensor() method. +_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE + + +def _type(operator): + """Returns the type name constant (e.g. _TRIL) for operator.""" + if isinstance(operator, linear_operator_diag.LinearOperatorDiag): + return _DIAG + if isinstance(operator, + linear_operator_lower_triangular.LinearOperatorLowerTriangular): + return _TRIL + if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): + return _MATRIX + if isinstance(operator, linear_operator_identity.LinearOperatorIdentity): + return _IDENTITY + if isinstance(operator, + linear_operator_identity.LinearOperatorScaledIdentity): + return _SCALED_IDENTITY + raise TypeError("Operator type unknown: %s" % operator) + + +################################################################################ +# Addition tiers: +# We attempt to use Adders in tier K before K+1. +# +# Organize tiers to +# (i) reduce O(..) complexity of forming final operator, and +# (ii) produce the "most efficient" final operator. +# Dev notes: +# * Results of addition at tier K will be added at tier K or higher. +# * Tiers may change, and we warn the user that it may change. +################################################################################ + +# Note that the final tier, _AddAndReturnMatrix, will convert everything to a +# dense matrix. So it is sometimes very inefficient. +_DEFAULT_ADDITION_TIERS = [ + [_AddAndReturnScaledIdentity()], + [_AddAndReturnDiag()], + [_AddAndReturnTriL()], + [_AddAndReturnMatrix()], +] |