aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py
blob: c26037935d9756d56b6778cbabffebda4c274a47 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""An optimizer wrapper that keeps all-zero gradients (e.g. from masked-out
components of a multitask loss) from updating stateful optimizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import types
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer

# Public API of this module.
__all__ = [
    "MultitaskOptimizerWrapper",
    "clip_gradients_by_global_norm",
]

def _is_all_zeros(grad):
  """Checks whether a gradient contains only zeros.

  Args:
    grad: the gradient to inspect.

  Returns:
    A scalar boolean Tensor that is True iff `grad` has no non-zero entries.
  """
  nonzero_count = math_ops.count_nonzero(grad)
  return math_ops.equal(nonzero_count, 0)

def _get_wrapper(fn, opt):
  def wrapper(self, grad, *args, **kwargs):  # pylint: disable=unused-argument
    all_zeros = _is_all_zeros(grad)
    return control_flow_ops.cond(
        all_zeros,
        control_flow_ops.no_op,
        lambda: fn(grad, *args, **kwargs))
  wrapper = types.MethodType(wrapper, opt)
  return wrapper

class MultitaskOptimizerWrapper(object):
  """Optimizer wrapper that keeps all-zero gradients from changing its state.

  This might be useful when a multi-task loss is used and some components of
  the loss may be absent (e.g. masked out) in some training batches. The
  gradients of those components are then all zeros, which would normally
  still affect the optimizer state (e.g. push a running average towards
  zero). This is usually not the desired behaviour, since a missing loss
  component should be treated as unknown rather than as zero.

  This wrapper skips the update of a variable whenever its gradient tensor is
  all zeros, therefore preserving the optimizer state for that variable.

  Gradient calculation and application are delegated to the underlying
  optimizer; only the application of all-zero gradient tensors is altered.
  Note that the underlying optimizer is modified in place: its
  `_apply_dense`, `_resource_apply_dense`, `_apply_sparse` and
  `_resource_apply_sparse` methods are replaced by wrapped versions.

  If gradient clipping by global norm is used, clip the output of
  `compute_gradients` (e.g. with clip_gradients_by_global_norm, as in the
  example below) before passing it to `apply_gradients`. All-zero gradient
  tensors contribute nothing to the global norm.

  Example:
  ```python
  momentum_optimizer = tf.train.MomentumOptimizer(
    learning_rate, momentum=0.9)
  multitask_momentum_optimizer = tf.contrib.opt.MultitaskOptimizerWrapper(
    momentum_optimizer)
  gradvars = multitask_momentum_optimizer.compute_gradients(
    loss)
  gradvars_clipped, _ = tf.contrib.opt.clip_gradients_by_global_norm(
    gradvars, 15.0)
  train_op = multitask_momentum_optimizer.apply_gradients(
    gradvars_clipped, global_step=batch)
  ```
  """

  def __init__(self, opt):
    """Wraps `opt`, patching its gradient-application methods in place.

    Args:
      opt: an instance of a class that implements tf.train.Optimizer.

    Raises:
      TypeError: if `opt` is not a tf.train.Optimizer instance.
    """
    if not isinstance(opt, optimizer.Optimizer):
      raise TypeError(
          "Supplied optimizer must be an instance of tf.train.Optimizer")
    self._opt = opt
    overridden_methods = ('_apply_dense',
                          '_resource_apply_dense',
                          '_apply_sparse',
                          '_resource_apply_sparse')
    for name in overridden_methods:
      fn = getattr(self._opt, name)
      wrapper = _get_wrapper(fn, self._opt)
      setattr(self._opt, name, wrapper)

  def __getattr__(self, name):
    """Delegates lookups of attributes missing on the wrapper to `self._opt`."""
    # __getattr__ only runs for attributes not found normally. If `_opt`
    # itself is missing (e.g. an instance created without __init__, as copy
    # and pickle do), evaluating `self._opt` here would re-enter __getattr__
    # forever; raise AttributeError so hasattr()/getattr() behave normally.
    if name == '_opt':
      raise AttributeError(
          "'%s' object has no attribute '%s'" % (type(self).__name__, name))
    return getattr(self._opt, name)


def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.):
  """Clips gradients of a multitask loss by their global norm.

  All-zero gradient tensors (e.g. from loss components that are masked out in
  the current batch) add nothing to the sum of squared norms, so they never
  change the global norm; they are effectively ignored without any special
  handling. `None` gradients are ignored as well.

  Args:
    gradients_variables: an iterable of (gradient, variable) pairs, e.g. the
      output of `Optimizer.compute_gradients`.
    clip_norm: a float or 0-D float Tensor, the global norm to clip on.
      Default is 20.0.

  Returns:
    list: A list of (clipped_gradient, variable) pairs in the same order as
      `gradients_variables`.
    global_norm: A 0-D (scalar) Tensor representing the global norm of the
      unclipped gradients.

  Raises:
    ValueError: if `gradients_variables` is empty.
  """
  gradients_variables = list(gradients_variables)
  if not gradients_variables:
    raise ValueError("gradients_variables must contain at least one "
                     "(gradient, variable) pair.")
  gradients = [grad for grad, _ in gradients_variables]
  variables = [var for _, var in gradients_variables]
  # No per-tensor "is all zeros" check is needed: an all-zero tensor has norm
  # 0 and so cannot change the global norm, and clip_by_global_norm already
  # skips None entries and accepts IndexedSlices gradients.
  clipped_gradients, global_norm = clip_ops.clip_by_global_norm(
      gradients, clip_norm)
  return list(zip(clipped_gradients, variables)), global_norm