author | 2018-01-26 16:53:59 -0800
---|---
committer | 2018-01-26 16:59:01 -0800
commit | aee7f95a027accc94f1f9130f0cfaecd9399bc1d (patch)
tree | 6b8484915bf631f18b2fa0561a73549d9bf19fad /tensorflow/python/training/learning_rate_decay.py
parent | e95537708f070a98607393a8f60bc61f1611a77b (diff)
Add C0301 line-too-long error to pylint sanity check.
PiperOrigin-RevId: 183467186
Diffstat (limited to 'tensorflow/python/training/learning_rate_decay.py')
-rw-r--r-- | tensorflow/python/training/learning_rate_decay.py | 145
1 file changed, 89 insertions, 56 deletions
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 3ee49650e0..343a49cded 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Various learning rate decay functions."""
 from __future__ import absolute_import
 from __future__ import division
@@ -28,8 +27,12 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops


-def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
-                      staircase=False, name=None):
+def exponential_decay(learning_rate,
+                      global_step,
+                      decay_steps,
+                      decay_rate,
+                      staircase=False,
+                      name=None):
   """Applies exponential decay to the learning rate.

   When training a model, it is often recommended to lower the learning rate as
@@ -85,9 +88,9 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
   """
   if global_step is None:
     raise ValueError("global_step is required for exponential_decay.")
-  with ops.name_scope(name, "ExponentialDecay",
-                      [learning_rate, global_step,
-                       decay_steps, decay_rate]) as name:
+  with ops.name_scope(
+      name, "ExponentialDecay",
+      [learning_rate, global_step, decay_steps, decay_rate]) as name:
     learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
     dtype = learning_rate.dtype
     global_step = math_ops.cast(global_step, dtype)
@@ -96,8 +99,8 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
     p = global_step / decay_steps
     if staircase:
       p = math_ops.floor(p)
-    return math_ops.multiply(learning_rate, math_ops.pow(decay_rate, p),
-                             name=name)
+    return math_ops.multiply(
+        learning_rate, math_ops.pow(decay_rate, p), name=name)


 def piecewise_constant(x, boundaries, values, name=None):
@@ -156,15 +159,15 @@ def piecewise_constant(x, boundaries, values, name=None):
           boundaries[i] = b
         else:
           raise ValueError(
-              "Boundaries (%s) must have the same dtype as x (%s)." % (
-                  b.dtype.base_dtype, x.dtype.base_dtype))
+              "Boundaries (%s) must have the same dtype as x (%s)." %
+              (b.dtype.base_dtype, x.dtype.base_dtype))
     # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
     values = ops.convert_n_to_tensor(values)
     for v in values[1:]:
       if v.dtype.base_dtype != values[0].dtype.base_dtype:
         raise ValueError(
-            "Values must have elements all with the same dtype (%s vs %s)." % (
-                values[0].dtype.base_dtype, v.dtype.base_dtype))
+            "Values must have elements all with the same dtype (%s vs %s)." %
+            (values[0].dtype.base_dtype, v.dtype.base_dtype))
     pred_fn_pairs = []
     pred_fn_pairs.append((x <= boundaries[0], lambda: values[0]))
     pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1]))
@@ -179,9 +182,13 @@ def piecewise_constant(x, boundaries, values, name=None):
     return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)


-def polynomial_decay(learning_rate, global_step, decay_steps,
-                     end_learning_rate=0.0001, power=1.0,
-                     cycle=False, name=None):
+def polynomial_decay(learning_rate,
+                     global_step,
+                     decay_steps,
+                     end_learning_rate=0.0001,
+                     power=1.0,
+                     cycle=False,
+                     name=None):
   """Applies a polynomial decay to the learning rate.

   It is commonly observed that a monotonically decreasing learning rate, whose
@@ -255,9 +262,10 @@ def polynomial_decay(learning_rate, global_step, decay_steps,
   """
   if global_step is None:
     raise ValueError("global_step is required for polynomial_decay.")
-  with ops.name_scope(name, "PolynomialDecay",
-                      [learning_rate, global_step,
-                       decay_steps, end_learning_rate, power]) as name:
+  with ops.name_scope(
+      name, "PolynomialDecay",
+      [learning_rate, global_step, decay_steps, end_learning_rate, power
+      ]) as name:
     learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
     dtype = learning_rate.dtype
     global_step = math_ops.cast(global_step, dtype)
@@ -267,23 +275,28 @@ def polynomial_decay(learning_rate, global_step, decay_steps,
     if cycle:
       # Find the first multiple of decay_steps that is bigger than global_step.
       # If global_step is zero set the multiplier to 1
-      multiplier = control_flow_ops.cond(math_ops.equal(global_step, 0),
-                                         lambda: 1.0,
-                                         lambda: math_ops.ceil(
-                                             global_step / decay_steps))
+      multiplier = control_flow_ops.cond(
+          math_ops.equal(global_step, 0), lambda: 1.0,
+          lambda: math_ops.ceil(global_step / decay_steps))
       decay_steps = math_ops.multiply(decay_steps, multiplier)
     else:
       # Make sure that the global_step used is not bigger than decay_steps.
       global_step = math_ops.minimum(global_step, decay_steps)

     p = math_ops.div(global_step, decay_steps)
-    return math_ops.add(math_ops.multiply(learning_rate - end_learning_rate,
-                                          math_ops.pow(1 - p, power)),
-                        end_learning_rate, name=name)
-
-
-def natural_exp_decay(learning_rate, global_step, decay_steps, decay_rate,
-                      staircase=False, name=None):
+    return math_ops.add(
+        math_ops.multiply(learning_rate - end_learning_rate,
+                          math_ops.pow(1 - p, power)),
+        end_learning_rate,
+        name=name)
+
+
+def natural_exp_decay(learning_rate,
+                      global_step,
+                      decay_steps,
+                      decay_rate,
+                      staircase=False,
+                      name=None):
   """Applies natural exponential decay to the initial learning rate.

   When training a model, it is often recommended to lower the learning rate as
@@ -349,8 +362,12 @@ def natural_exp_decay(learning_rate, global_step, decay_steps, decay_rate,
     return math_ops.multiply(learning_rate, exponent, name=name)


-def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate,
-                       staircase=False, name=None):
+def inverse_time_decay(learning_rate,
+                       global_step,
+                       decay_steps,
+                       decay_rate,
+                       staircase=False,
+                       name=None):
   """Applies inverse time decay to the initial learning rate.

   When training a model, it is often recommended to lower the learning rate as
@@ -362,13 +379,15 @@ def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate,
   The function returns the decayed learning rate.  It is computed as:

   ```python
-  decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
+  decayed_learning_rate = learning_rate / (1 + decay_rate * global_step /
+  decay_step)
   ```

   or, if `staircase` is `True`, as:

   ```python
-  decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
+  decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step /
+  decay_step))
   ```

   Example: decay 1/t with a rate of 0.5:
@@ -379,7 +398,8 @@ def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate,
   learning_rate = 0.1
   decay_steps = 1.0
   decay_rate = 0.5
-  learning_rate = tf.train.inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate)
+  learning_rate = tf.train.inverse_time_decay(learning_rate, global_step,
+  decay_steps, decay_rate)

   # Passing global_step to minimize() will increment it at each step.
   learning_step = (
@@ -424,8 +444,7 @@ def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate,
     return math_ops.div(learning_rate, denom, name=name)


-def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
-                 name=None):
+def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
   """Applies cosine decay to the learning rate.

   See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
@@ -484,8 +503,13 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
     return math_ops.multiply(learning_rate, decayed)


-def cosine_decay_restarts(learning_rate, global_step, first_decay_steps,
-                          t_mul=2.0, m_mul=1.0, alpha=0.0, name=None):
+def cosine_decay_restarts(learning_rate,
+                          global_step,
+                          first_decay_steps,
+                          t_mul=2.0,
+                          m_mul=1.0,
+                          alpha=0.0,
+                          name=None):
   """Applies cosine decay with restarts to the learning rate.

   See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
@@ -532,10 +556,9 @@ def cosine_decay_restarts(learning_rate, global_step, first_decay_steps,
   """
   if global_step is None:
     raise ValueError("cosine decay restarts requires global_step")
-  with ops.name_scope(name, "SGDRDecay",
-                      [learning_rate, global_step]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate,
-                                          name="initial_learning_rate")
+  with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]) as name:
+    learning_rate = ops.convert_to_tensor(
+        learning_rate, name="initial_learning_rate")
     dtype = learning_rate.dtype
     global_step = math_ops.cast(global_step, dtype)
     first_decay_steps = math_ops.cast(first_decay_steps, dtype)
@@ -547,11 +570,12 @@ def cosine_decay_restarts(learning_rate, global_step, first_decay_steps,

     def compute_step(completed_fraction, geometric=False):
       if geometric:
-        i_restart = math_ops.floor(math_ops.log(1.0 - completed_fraction * (
-            1.0 - t_mul)) / math_ops.log(t_mul))
+        i_restart = math_ops.floor(
+            math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
+            math_ops.log(t_mul))

-        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
-        completed_fraction = (completed_fraction - sum_r) / t_mul ** i_restart
+        sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
+        completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart

       else:
         i_restart = math_ops.floor(completed_fraction)
@@ -564,16 +588,20 @@ def cosine_decay_restarts(learning_rate, global_step, first_decay_steps,
         lambda: compute_step(completed_fraction, geometric=False),
         lambda: compute_step(completed_fraction, geometric=True))

-    m_fac = m_mul ** i_restart
-    cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
-        constant_op.constant(math.pi) * completed_fraction))
+    m_fac = m_mul**i_restart
+    cosine_decayed = 0.5 * m_fac * (
+        1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction))
     decayed = (1 - alpha) * cosine_decayed + alpha
     return math_ops.multiply(learning_rate, decayed, name=name)


-def linear_cosine_decay(learning_rate, global_step, decay_steps,
-                        num_periods=0.5, alpha=0.0, beta=0.001,
+def linear_cosine_decay(learning_rate,
+                        global_step,
+                        decay_steps,
+                        num_periods=0.5,
+                        alpha=0.0,
+                        beta=0.001,
                         name=None):
   """Applies linear cosine decay to the learning rate.

@@ -651,9 +679,14 @@ def linear_cosine_decay(learning_rate, global_step, decay_steps,
     return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)


-def noisy_linear_cosine_decay(learning_rate, global_step, decay_steps,
-                              initial_variance=1.0, variance_decay=0.55,
-                              num_periods=0.5, alpha=0.0, beta=0.001,
+def noisy_linear_cosine_decay(learning_rate,
+                              global_step,
+                              decay_steps,
+                              initial_variance=1.0,
+                              variance_decay=0.55,
+                              num_periods=0.5,
+                              alpha=0.0,
+                              beta=0.001,
                               name=None):
   """Applies noisy linear cosine decay to the learning rate.

@@ -734,8 +767,8 @@ def noisy_linear_cosine_decay(learning_rate, global_step, decay_steps,
                              math_ops.pow(1.0 + global_step, variance_decay))
     std = math_ops.sqrt(variance)
     noisy_linear_decayed = (
-        linear_decayed + random_ops.random_normal(
-            linear_decayed.shape, stddev=std))
+        linear_decayed +
+        random_ops.random_normal(linear_decayed.shape, stddev=std))

     completed_fraction = global_step / decay_steps
     fraction = 2.0 * num_periods * completed_fraction
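The re-wrapping above is behavior-preserving, so the schedules themselves are unchanged. As a minimal plain-NumPy sketch of what `exponential_decay` computes (the `exponential_decay_np` name and the 0.96-per-100000-steps numbers are illustrative, not part of this commit):

```python
import numpy as np

def exponential_decay_np(learning_rate, global_step, decay_steps, decay_rate,
                         staircase=False):
  # decayed_lr = learning_rate * decay_rate ** (global_step / decay_steps)
  p = global_step / decay_steps
  if staircase:
    p = np.floor(p)  # hold the rate constant within each decay interval
  return learning_rate * decay_rate**p

# 0.1 decayed by 0.96 every 100000 steps: 0.1, 0.096, 0.09216, ...
print([round(exponential_decay_np(0.1, s, 100000, 0.96), 5)
       for s in (0, 100000, 200000)])
```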
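Likewise, the re-wrapped docstring formula in `inverse_time_decay` can be checked against its own example (decay 1/t with a rate of 0.5). The `inverse_time_decay_np` helper below is an illustrative sketch, not TensorFlow API:

```python
import numpy as np

def inverse_time_decay_np(learning_rate, global_step, decay_steps, decay_rate,
                          staircase=False):
  # decayed_lr = learning_rate / (1 + decay_rate * global_step / decay_steps)
  p = global_step / decay_steps
  if staircase:
    p = np.floor(p)
  return learning_rate / (1 + decay_rate * p)

# learning_rate=0.1, decay_steps=1.0, decay_rate=0.5, as in the docstring:
# steps 0..3 give 0.1, 0.0667, 0.05, 0.04.
print([round(inverse_time_decay_np(0.1, s, 1.0, 0.5), 4) for s in range(4)])
```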
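The subtlest reformatted expression is in `cosine_decay_restarts`: when `t_mul != 1`, `compute_step` inverts the geometric series of period lengths (`first_decay_steps`, `t_mul * first_decay_steps`, `t_mul**2 * first_decay_steps`, ...) to find which restart interval `global_step` falls in. A rough pure-Python sketch of that logic; the `t_mul == 1.0` branch is filled in from context the diff does not show, so treat it as an assumption:

```python
import math

def sgdr_decay_factor_np(global_step, first_decay_steps,
                         t_mul=2.0, m_mul=1.0, alpha=0.0):
  """Returns the multiplicative decay factor (times the initial rate)."""
  completed_fraction = global_step / first_decay_steps
  if t_mul == 1.0:
    # Equal-length restart periods (assumed branch, not visible in the diff).
    i_restart = math.floor(completed_fraction)
    completed_fraction -= i_restart
  else:
    # Total length of the first i periods is sum_r = (1 - t_mul**i)/(1 - t_mul);
    # solving sum_r <= completed_fraction for i gives the current restart index.
    i_restart = math.floor(
        math.log(1.0 - completed_fraction * (1.0 - t_mul)) / math.log(t_mul))
    sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
    # Rescale to the fraction completed within the current period.
    completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
  m_fac = m_mul**i_restart  # each restart begins m_mul times lower
  cosine_decayed = 0.5 * m_fac * (1.0 + math.cos(math.pi * completed_fraction))
  return (1 - alpha) * cosine_decayed + alpha

# With first_decay_steps=1000 and t_mul=2, restarts land at steps 1000, 3000,
# 7000, ...: the factor returns to 1.0 at each of those steps.
print([round(sgdr_decay_factor_np(s, 1000), 3) for s in (0, 999, 1000, 2999, 3000)])
```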