aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt/python/training/powersign.py
blob: b4aa19264de4b1e1b8e9ecd3c2cb4637f5a06e25 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of PowerSign."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops


class PowerSignOptimizer(optimizer.Optimizer):
  """Optimizer that implements the PowerSign update.

  See [Bello et al., ICML2017],
  [Neural Optimizer Search with RL](https://arxiv.org/abs/1709.07417).
  """

  def __init__(self,
               learning_rate=0.1,
               base=math.e,
               beta=0.9,
               sign_decay_fn=None,
               use_locking=False,
               name='PowerSignOptimizer'):
    """Constructs a new PowerSignOptimizer object.

    Initialization:

    ```
    m_0 <- 0 (Initialize initial 1st moment vector)
    t <- 0 (Initialize timestep)
    ```

    Update:

    ```
    t <- t + 1
    m_t <- beta * m_{t-1} + (1 - beta) * g
    sign_decay <- sign_decay_fn(t)
    update <- base ** (sign_decay * sign(g) * sign(m)) * g
    variable <- variable - lr_t * update
    ```

    Example usage for PowerSign-cd (PowerSign with cosine sign decay)
    ```
    decay_steps = 1000
    cosine_decay_fn = sign_decay.get_cosine_decay_fn(decay_steps)
    opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=cosine_decay_fn)
    ```

    Args:
      learning_rate: learning_rate used when taking a step.
      base: base used in optimizer. Its natural logarithm is computed with
        `math.log` at construction time, so it must be a positive Python
        number.
      beta: decay used for computing the moving average m.
      sign_decay_fn: decay function applied to the sign(g) sign(m) quantity.
        Takes global_step as an argument. See sign_decay.py for some examples.
        If None, a constant sign decay of 1.0 is used.
      use_locking: If True, use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "PowerSignOptimizer".
    """
    super(PowerSignOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta = beta
    # Only log(base) is needed by the update, so it is computed once here.
    self._logbase = math.log(base)

    self._sign_decay_fn = sign_decay_fn

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta_t = None
    self._logbase_t = None
    # Set in apply_gradients() when a sign_decay_fn was given, otherwise in
    # _prepare().
    self._sign_decay_t = None

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Applies gradients after evaluating the sign decay for `global_step`.

    When a `sign_decay_fn` was supplied, it is called with `global_step` (passed
    through unchanged, even if it is None) and the result is stored as the
    sign-decay tensor used by the update methods. The call is then delegated
    to the base class.

    Args:
      grads_and_vars: Forwarded unchanged to the base class.
      global_step: Forwarded to the base class and to `sign_decay_fn`.
      name: Forwarded unchanged to the base class.

    Returns:
      Whatever the base class `apply_gradients` returns.
    """
    if self._sign_decay_fn is not None:
      self._sign_decay_t = ops.convert_to_tensor(
          self._sign_decay_fn(global_step), name='sign_decay')
    return super(PowerSignOptimizer, self).apply_gradients(
        grads_and_vars, global_step=global_step, name=name)

  def _create_slots(self, var_list):
    """Creates a zero-initialized 'm' slot for every variable in `var_list`."""
    # Create slots for the first moment.
    for v in var_list:
      self._zeros_slot(v, 'm', self._name)

  def _prepare(self):
    """Converts the stored hyperparameters to tensors."""
    self._lr_t = ops.convert_to_tensor(self._lr, name='learning_rate')
    self._beta_t = ops.convert_to_tensor(self._beta, name='beta')
    self._logbase_t = ops.convert_to_tensor(self._logbase, name='logbase')
    # With a sign_decay_fn the tensor was already built in apply_gradients();
    # without one the sign decay is the constant 1.0.
    if self._sign_decay_fn is None:
      self._sign_decay_t = ops.convert_to_tensor(1.0, name='sign_decay')

  def _apply_dense(self, grad, var):
    """Updates `var` and its 'm' slot with `training_ops.apply_power_sign`.

    All hyperparameter tensors are cast to the base dtype of `var`.

    Args:
      grad: Gradient passed to the kernel.
      var: Variable to update.

    Returns:
      The `.op` of the value returned by `apply_power_sign`.
    """
    m = self.get_slot(var, 'm')
    return training_ops.apply_power_sign(
        var,
        m,
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._logbase_t, var.dtype.base_dtype),
        math_ops.cast(self._sign_decay_t, var.dtype.base_dtype),
        math_ops.cast(self._beta_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    """Updates `var` with `training_ops.resource_apply_power_sign`.

    Same as `_apply_dense`, but the handles of `var` and of its 'm' slot are
    passed to the kernel and the kernel's return value is returned as is.

    Args:
      grad: Gradient passed to the kernel.
      var: Variable to update; must expose `.handle`.

    Returns:
      The value returned by `resource_apply_power_sign`.
    """
    m = self.get_slot(var, 'm')
    return training_ops.resource_apply_power_sign(
        var.handle,
        m.handle,
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._logbase_t, var.dtype.base_dtype),
        math_ops.cast(self._sign_decay_t, var.dtype.base_dtype),
        math_ops.cast(self._beta_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    """Applies the PowerSign update for a gradient given as indexed slices.

    Args:
      grad: Gradient exposing `.values` and `.indices`.
      var: Variable to update.

    Returns:
      An op grouping the update of `var` and the update of its 'm' slot.
    """
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)
    logbase_t = math_ops.cast(self._logbase_t, var.dtype.base_dtype)
    e_t = math_ops.cast(math.e, var.dtype.base_dtype)

    # The moving average is updated for the whole slot, not only for the rows
    # listed in grad.indices.
    # NOTE(review): `grad` is used directly in arithmetic here, which
    # presumably relies on it being converted to a dense tensor -- confirm.
    m = self.get_slot(var, 'm')
    m_t = state_ops.assign(
        m, (m * beta_t) + (grad * (1 - beta_t)), use_locking=self._use_locking)

    # sign(g) * sign(m_t), restricted to the rows present in the gradient.
    sign_g = math_ops.sign(grad.values)
    sign_gm = array_ops.gather(math_ops.sign(m_t), grad.indices) * sign_g

    sign_decayed = math_ops.cast(
        self._sign_decay_t, var.dtype.base_dtype)
    # base ** x is computed as e ** (log(base) * x).
    multiplier = math_ops.pow(e_t, logbase_t * sign_decayed * sign_gm)

    final_update = lr_t * multiplier * grad.values

    var_update = state_ops.scatter_sub(
        var,
        grad.indices,
        final_update,
        use_locking=self._use_locking)

    return control_flow_ops.group(var_update, m_t)