aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/training/learning_rate_decay.py432
-rw-r--r--tensorflow/python/training/learning_rate_decay_v2.py898
-rw-r--r--tensorflow/python/training/learning_rate_decay_v2_test.py497
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2.py24
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2_test.py13
6 files changed, 1547 insertions, 318 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e6169e9e80..ba9c6a2320 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4393,6 +4393,7 @@ cuda_py_tests(
"training/ftrl_test.py",
"training/gradient_descent_test.py",
"training/learning_rate_decay_test.py",
+ "training/learning_rate_decay_v2_test.py",
"training/momentum_test.py",
"training/optimizer_test.py",
"training/proximal_adagrad_test.py",
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index fd195a7965..29b5465321 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -17,19 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-
from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.training import learning_rate_decay_v2
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.exponential_decay")
+@tf_export(v1=["train.exponential_decay"])
def exponential_decay(learning_rate,
global_step,
decay_steps,
@@ -95,32 +88,19 @@ def exponential_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for exponential_decay.")
- with ops.name_scope(
- name, "ExponentialDecay",
- [learning_rate, global_step, decay_steps, decay_rate]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- decay_rate = math_ops.cast(decay_rate, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- p = global_step_recomp / decay_steps
- if staircase:
- p = math_ops.floor(p)
- return math_ops.multiply(
- learning_rate, math_ops.pow(decay_rate, p), name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.piecewise_constant")
+ decayed_lr = learning_rate_decay_v2.exponential_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=staircase,
+ name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.piecewise_constant"])
def piecewise_constant(x, boundaries, values, name=None):
"""Piecewise constant from boundaries and interval values.
@@ -163,58 +143,15 @@ def piecewise_constant(x, boundaries, values, name=None):
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if len(boundaries) != len(values) - 1:
- raise ValueError(
- "The length of boundaries should be 1 less than the length of values")
- with ops.name_scope(name, "PiecewiseConstant",
- [x, boundaries, values, name]) as name:
- boundaries = ops.convert_n_to_tensor(boundaries)
- values = ops.convert_n_to_tensor(values)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- x_recomp = ops.convert_to_tensor(x)
- # Avoid explicit conversion to x's dtype. This could result in faulty
- # comparisons, for example if floats are converted to integers.
- for i, b in enumerate(boundaries):
- if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
- # We can promote int32 boundaries to int64 without loss of precision.
- # This covers the most common case where the user passes in boundaries
- # as an array of Python integers.
- if (b.dtype.base_dtype == dtypes.int32 and
- x_recomp.dtype.base_dtype == dtypes.int64):
- b = math_ops.cast(b, x_recomp.dtype.base_dtype)
- boundaries[i] = b
- else:
- raise ValueError(
- "Boundaries (%s) must have the same dtype as x (%s)." %
- (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
- # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
- for v in values[1:]:
- if v.dtype.base_dtype != values[0].dtype.base_dtype:
- raise ValueError(
- "Values must have elements all with the same dtype (%s vs %s)." %
- (values[0].dtype.base_dtype, v.dtype.base_dtype))
- pred_fn_pairs = []
- pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
- pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
- for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
- # Need to bind v here; can do this with lambda v=v: ...
- pred = (x_recomp > low) & (x_recomp <= high)
- pred_fn_pairs.append((pred, lambda v=v: v))
-
- # The default isn't needed here because our conditions are mutually
- # exclusive and exhaustive, but tf.case requires it.
- default = lambda: values[0]
- return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.polynomial_decay")
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(x, boundaries, values,
+ name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.polynomial_decay"])
def polynomial_decay(learning_rate,
global_step,
decay_steps,
@@ -299,46 +236,22 @@ def polynomial_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for polynomial_decay.")
- with ops.name_scope(
- name, "PolynomialDecay",
- [learning_rate, global_step, decay_steps, end_learning_rate, power
- ]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- end_learning_rate = math_ops.cast(end_learning_rate, dtype)
- power = math_ops.cast(power, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- decay_steps_recomp = math_ops.cast(decay_steps, dtype)
- if cycle:
- # Find the first multiple of decay_steps that is bigger than
- # global_step. If global_step is zero set the multiplier to 1
- multiplier = control_flow_ops.cond(
- math_ops.equal(global_step_recomp, 0), lambda: 1.0,
- lambda: math_ops.ceil(global_step_recomp / decay_steps))
- decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
- else:
- # Make sure that the global_step used is not bigger than decay_steps.
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
-
- p = math_ops.div(global_step_recomp, decay_steps_recomp)
- return math_ops.add(
- math_ops.multiply(learning_rate - end_learning_rate,
- math_ops.pow(1 - p, power)),
- end_learning_rate,
- name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.natural_exp_decay")
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ learning_rate,
+ global_step,
+ decay_steps,
+ end_learning_rate=end_learning_rate,
+ power=power,
+ cycle=cycle,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.natural_exp_decay"])
def natural_exp_decay(learning_rate,
global_step,
decay_steps,
@@ -410,32 +323,17 @@ def natural_exp_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for natural_exp_decay.")
- with ops.name_scope(name, "NaturalExpDecay",
- [learning_rate, global_step, decay_rate]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- decay_rate = math_ops.cast(decay_rate, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- p = global_step_recomp / decay_steps
- if staircase:
- p = math_ops.floor(p)
- exponent = math_ops.exp(
- math_ops.multiply(math_ops.negative(decay_rate), p))
- return math_ops.multiply(learning_rate, exponent, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.inverse_time_decay")
+ decayed_lr = learning_rate_decay_v2.natural_exp_decay(
+ learning_rate, global_step, decay_steps, decay_rate, staircase=staircase,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.inverse_time_decay"])
def inverse_time_decay(learning_rate,
global_step,
decay_steps,
@@ -507,32 +405,21 @@ def inverse_time_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for inverse_time_decay.")
- with ops.name_scope(name, "InverseTimeDecay",
- [learning_rate, global_step, decay_rate]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- decay_rate = math_ops.cast(decay_rate, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- p = global_step_recomp / decay_steps
- if staircase:
- p = math_ops.floor(p)
- const = math_ops.cast(constant_op.constant(1), dtype)
- denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
- return math_ops.div(learning_rate, denom, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.cosine_decay")
+ decayed_lr = learning_rate_decay_v2.inverse_time_decay(
+ learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=staircase,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.cosine_decay"])
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
"""Applies cosine decay to the learning rate.
@@ -581,32 +468,16 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("cosine decay requires global_step")
- with ops.name_scope(name, "CosineDecay",
- [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
- completed_fraction = global_step_recomp / decay_steps
- cosine_decayed = 0.5 * (1.0 + math_ops.cos(
- constant_op.constant(math.pi) * completed_fraction))
-
- decayed = (1 - alpha) * cosine_decayed + alpha
- return math_ops.multiply(learning_rate, decayed)
+ decayed_lr = learning_rate_decay_v2.cosine_decay(
+ learning_rate, global_step, decay_steps, alpha=alpha, name=name)
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
- return decayed_lr
+ return decayed_lr
-@tf_export("train.cosine_decay_restarts")
+@tf_export(v1=["train.cosine_decay_restarts"])
def cosine_decay_restarts(learning_rate,
global_step,
first_decay_steps,
@@ -664,57 +535,22 @@ def cosine_decay_restarts(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("cosine decay restarts requires global_step")
- with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(
- learning_rate, name="initial_learning_rate")
- dtype = learning_rate.dtype
- first_decay_steps = math_ops.cast(first_decay_steps, dtype)
- alpha = math_ops.cast(alpha, dtype)
- t_mul = math_ops.cast(t_mul, dtype)
- m_mul = math_ops.cast(m_mul, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- completed_fraction = global_step_recomp / first_decay_steps
-
- def compute_step(completed_fraction, geometric=False):
- """Helper for `cond` operation."""
- if geometric:
- i_restart = math_ops.floor(
- math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
- math_ops.log(t_mul))
-
- sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
- completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
-
- else:
- i_restart = math_ops.floor(completed_fraction)
- completed_fraction -= i_restart
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ learning_rate,
+ global_step,
+ first_decay_steps,
+ t_mul=t_mul,
+ m_mul=m_mul,
+ alpha=alpha,
+ name=name)
- return i_restart, completed_fraction
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
- i_restart, completed_fraction = control_flow_ops.cond(
- math_ops.equal(t_mul, 1.0),
- lambda: compute_step(completed_fraction, geometric=False),
- lambda: compute_step(completed_fraction, geometric=True))
+ return decayed_lr
- m_fac = m_mul**i_restart
- cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
- constant_op.constant(math.pi) * completed_fraction))
- decayed = (1 - alpha) * cosine_decayed + alpha
- return math_ops.multiply(learning_rate, decayed, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.linear_cosine_decay")
+@tf_export(v1=["train.linear_cosine_decay"])
def linear_cosine_decay(learning_rate,
global_step,
decay_steps,
@@ -781,37 +617,22 @@ def linear_cosine_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("linear cosine decay requires global_step")
- with ops.name_scope(name, "LinearCosineDecay",
- [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- num_periods = math_ops.cast(num_periods, dtype)
- alpha = math_ops.cast(alpha, dtype)
- beta = math_ops.cast(beta, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
- linear_decayed = (decay_steps - global_step_recomp) / decay_steps
- completed_fraction = global_step_recomp / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
-
- linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
- return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.noisy_linear_cosine_decay")
+ decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+ learning_rate,
+ global_step,
+ decay_steps,
+ num_periods=num_periods,
+ alpha=alpha,
+ beta=beta,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.noisy_linear_cosine_decay"])
def noisy_linear_cosine_decay(learning_rate,
global_step,
decay_steps,
@@ -886,42 +707,17 @@ def noisy_linear_cosine_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("noisy linear cosine decay requires global_step")
- with ops.name_scope(name, "NoisyLinearCosineDecay",
- [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- initial_variance = math_ops.cast(initial_variance, dtype)
- variance_decay = math_ops.cast(variance_decay, dtype)
- num_periods = math_ops.cast(num_periods, dtype)
- alpha = math_ops.cast(alpha, dtype)
- beta = math_ops.cast(beta, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
- linear_decayed = (decay_steps - global_step_recomp) / decay_steps
- variance = initial_variance / (
- math_ops.pow(1.0 + global_step_recomp, variance_decay))
- std = math_ops.sqrt(variance)
- noisy_linear_decayed = (
- linear_decayed + random_ops.random_normal(
- linear_decayed.shape, stddev=std))
-
- completed_fraction = global_step_recomp / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
- noisy_linear_cosine_decayed = (
- (alpha + noisy_linear_decayed) * cosine_decayed + beta)
-
- return math_ops.multiply(
- learning_rate, noisy_linear_cosine_decayed, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
+ decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
+ learning_rate, global_step,
+ decay_steps,
+ initial_variance=initial_variance,
+ variance_decay=variance_decay,
+ num_periods=num_periods,
+ alpha=alpha,
+ beta=beta,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
diff --git a/tensorflow/python/training/learning_rate_decay_v2.py b/tensorflow/python/training/learning_rate_decay_v2.py
new file mode 100644
index 0000000000..9c5e144be6
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_v2.py
@@ -0,0 +1,898 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various learning rate decay functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import math
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("train.exponential_decay", v1=[])
+def exponential_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=False,
+ name=None):
+ """Applies exponential decay to the learning rate.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies an exponential decay function
+ to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg function that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions.
+ It is computed as:
+
+ ```python
+ decayed_learning_rate = learning_rate *
+ decay_rate ^ (global_step / decay_steps)
+ ```
+
+ If the argument `staircase` is `True`, then `global_step / decay_steps` is an
+ integer division and the decayed learning rate follows a staircase function.
+
+ Example: decay every 100000 steps with a base of 0.96:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ starter_learning_rate = 0.1
+ learning_rate_fn = tf.train.exponential_decay(starter_learning_rate,
+ global_step, 100000, 0.96,
+ staircase=True)
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Must be positive. See the decay computation above.
+ decay_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The decay rate.
+ staircase: Boolean. If `True` decay the learning rate at discrete intervals
+ name: String. Optional name of the operation. Defaults to
+ 'ExponentialDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for exponential_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, decay_rate,
+ staircase, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(
+ name, "ExponentialDecay",
+ [learning_rate, global_step, decay_steps, decay_rate]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ decay_rate = math_ops.cast(decay_rate, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ return math_ops.multiply(
+ learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ decay_rate, staircase, name)
+
+
+@tf_export("train.piecewise_constant", v1=[])
+def piecewise_constant(x, boundaries, values, name=None):
+ """Piecewise constant from boundaries and interval values.
+
+ This function returns a no-arg callable to compute the piecewise constant.
+ This can be useful for changing the learning rate value across
+ different invocations of optimizer functions.
+
+ Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
+ for the next 10000 steps, and 0.1 for any additional steps.
+
+ ```python
+ global_step = tf.Variable(0, trainable=False)
+ boundaries = [100000, 110000]
+ values = [1.0, 0.5, 0.1]
+ learning_rate_fn = tf.train.piecewise_constant(global_step, boundaries,
+ values)
+ learning_rate = learning_rate_fn()
+
+ # Later, whenever we perform an optimization step, we increment global_step.
+ ```
+
+ Args:
+ x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
+ boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
+ increasing entries, and with all elements having the same type as `x`.
+ values: A list of `Tensor`s or `float`s or `int`s that specifies the values
+ for the intervals defined by `boundaries`. It should have one more element
+ than `boundaries`, and all elements should have the same type.
+ name: A string. Optional name of the operation. Defaults to
+ 'PiecewiseConstant'.
+
+ Returns:
+ A no-arg function that outputs a 0-D Tensor. The output of the no-arg
+ function is `values[0]` when `x <= boundaries[0]`,
+ `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
+ and values[-1] when `x > boundaries[-1]`.
+
+ Raises:
+ ValueError: if types of `x` and `boundaries` do not match, or types of all
+ `values` do not match or
+ the number of elements in the lists does not match.
+ """
+ if len(boundaries) != len(values) - 1:
+ raise ValueError(
+ "The length of boundaries should be 1 less than the length of values")
+ def decayed_lr(x, boundaries, values, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "PiecewiseConstant",
+ [x, boundaries, values, name]) as name:
+ boundaries = ops.convert_n_to_tensor(boundaries)
+ values = ops.convert_n_to_tensor(values)
+ x_recomp = ops.convert_to_tensor(x)
+ # Avoid explicit conversion to x's dtype. This could result in faulty
+ # comparisons, for example if floats are converted to integers.
+ for i, b in enumerate(boundaries):
+ if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
+ # We can promote int32 boundaries to int64 without loss of precision.
+ # This covers the most common case where the user passes in boundaries
+ # as an array of Python integers.
+ if (b.dtype.base_dtype == dtypes.int32 and
+ x_recomp.dtype.base_dtype == dtypes.int64):
+ b = math_ops.cast(b, x_recomp.dtype.base_dtype)
+ boundaries[i] = b
+ else:
+ raise ValueError(
+ "Boundaries (%s) must have the same dtype as x (%s)." %
+ (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
+ # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
+ for v in values[1:]:
+ if v.dtype.base_dtype != values[0].dtype.base_dtype:
+ raise ValueError(
+ "Values must have elements all with the same dtype (%s vs %s)." %
+ (values[0].dtype.base_dtype, v.dtype.base_dtype))
+ pred_fn_pairs = []
+ pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
+ pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
+ for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
+ # Need to bind v here; can do this with lambda v=v: ...
+ pred = (x_recomp > low) & (x_recomp <= high)
+ pred_fn_pairs.append((pred, lambda v=v: v))
+
+ # The default isn't needed here because our conditions are mutually
+ # exclusive and exhaustive, but tf.case requires it.
+ default = lambda: values[0]
+ return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+
+ return functools.partial(decayed_lr, x, boundaries, values, name)
+
+
+@tf_export("train.polynomial_decay", v1=[])
+def polynomial_decay(learning_rate,
+ global_step,
+ decay_steps,
+ end_learning_rate=0.0001,
+ power=1.0,
+ cycle=False,
+ name=None):
+ """Applies a polynomial decay to the learning rate.
+
+ It is commonly observed that a monotonically decreasing learning rate, whose
+ degree of change is carefully chosen, results in a better performing model.
+ This function applies a polynomial decay function to a provided initial
+ `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`.
+
+ It requires a `global_step` value to compute the decayed learning rate. You
+ can just pass a TensorFlow variable that you increment at each training step.
+
+ The function returns a no-arg callable that outputs the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ decayed_learning_rate = (learning_rate - end_learning_rate) *
+ (1 - global_step / decay_steps) ^ (power) +
+ end_learning_rate
+
+ ```
+
+ If `cycle` is True then a multiple of `decay_steps` is used, the first one
+ that is bigger than `global_steps`.
+
+ ```python
+ decay_steps = decay_steps * ceil(global_step / decay_steps)
+ decayed_learning_rate_fn = (learning_rate - end_learning_rate) *
+ (1 - global_step / decay_steps) ^ (power) +
+ end_learning_rate
+ decayed_learning_rate = decayed_learning_rate_fn()
+
+ ```
+
+ Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5):
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ starter_learning_rate = 0.1
+ end_learning_rate = 0.01
+ decay_steps = 10000
+ learning_rate_fn = tf.train.polynomial_decay(starter_learning_rate,
+ global_step, decay_steps,
+ end_learning_rate,
+ power=0.5)
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Must be positive. See the decay computation above.
+ end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The minimal end learning rate.
+ power: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The power of the polynomial. Defaults to linear, 1.0.
+ cycle: A boolean, whether or not it should cycle beyond decay_steps.
+ name: String. Optional name of the operation. Defaults to
+ 'PolynomialDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for polynomial_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, end_learning_rate,
+ power, cycle, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(
+ name, "PolynomialDecay",
+ [learning_rate, global_step, decay_steps, end_learning_rate, power]
+ ) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ end_learning_rate = math_ops.cast(end_learning_rate, dtype)
+ power = math_ops.cast(power, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ decay_steps_recomp = math_ops.cast(decay_steps, dtype)
+ if cycle:
+ # Find the first multiple of decay_steps that is bigger than
+ # global_step. If global_step is zero set the multiplier to 1
+ multiplier = control_flow_ops.cond(
+ math_ops.equal(global_step_recomp, 0), lambda: 1.0,
+ lambda: math_ops.ceil(global_step_recomp / decay_steps))
+ decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
+ else:
+ # Make sure that the global_step used is not bigger than decay_steps.
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+
+ p = math_ops.div(global_step_recomp, decay_steps_recomp)
+ return math_ops.add(
+ math_ops.multiply(learning_rate - end_learning_rate,
+ math_ops.pow(1 - p, power)),
+ end_learning_rate,
+ name=name)
+
+ return functools.partial(
+ decayed_lr, learning_rate, global_step, decay_steps, end_learning_rate,
+ power, cycle, name)
+
+
+@tf_export("train.natural_exp_decay", v1=[])
+def natural_exp_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=False,
+ name=None):
+ """Applies natural exponential decay to the initial learning rate.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies an exponential decay function
+ to a provided initial learning rate. It requires an `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
+ decay_step)
+ ```
+
+ or, if `staircase` is `True`, as:
+
+ ```python
+ decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step /
+ decay_step))
+ ```
+
+ Example: decay exponentially with a base of 0.96:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ learning_rate = 0.1
+ decay_steps = 5
+ k = 0.5
+ learning_rate_fn = tf.train.natural_exp_decay(learning_rate, global_step,
+ decay_steps, k)
+
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: How often to apply decay.
+ decay_rate: A Python number. The decay rate.
+ staircase: Whether to apply decay in a discrete staircase, as opposed to
+ continuous, fashion.
+ name: String. Optional name of the operation. Defaults to
+ 'ExponentialTimeDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for natural_exp_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
+ name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "NaturalExpDecay",
+ [learning_rate, global_step, decay_rate]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ decay_rate = math_ops.cast(decay_rate, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ exponent = math_ops.exp(
+ math_ops.multiply(math_ops.negative(decay_rate), p))
+ return math_ops.multiply(learning_rate, exponent, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ decay_rate, staircase, name)
+
+
+@tf_export("train.inverse_time_decay", v1=[])
+def inverse_time_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=False,
+ name=None):
+ """Applies inverse time decay to the initial learning rate.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies an inverse decay function
+ to a provided initial learning rate. It requires an `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ decayed_learning_rate = learning_rate / (1 + decay_rate * global_step /
+ decay_step)
+ ```
+
+ or, if `staircase` is `True`, as:
+
+ ```python
+ decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step /
+ decay_step))
+ ```
+
+ Example: decay 1/t with a rate of 0.5:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ learning_rate = 0.1
+ decay_steps = 1.0
+ decay_rate = 0.5
+ learning_rate_fn = tf.train.inverse_time_decay(learning_rate, global_step,
+ decay_steps, decay_rate)
+
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: How often to apply decay.
+ decay_rate: A Python number. The decay rate.
+ staircase: Whether to apply decay in a discrete staircase, as opposed to
+ continuous, fashion.
+ name: String. Optional name of the operation. Defaults to
+ 'InverseTimeDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for inverse_time_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
+ name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "InverseTimeDecay",
+ [learning_rate, global_step, decay_rate]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ decay_rate = math_ops.cast(decay_rate, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ const = math_ops.cast(constant_op.constant(1), dtype)
+ denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
+ return math_ops.div(learning_rate, denom, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ decay_rate, staircase, name)
+
+
+@tf_export("train.cosine_decay", v1=[])
+def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
+ name=None):
+ """Applies cosine decay to the learning rate.
+
+ See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a cosine decay function
+ to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
+ decayed = (1 - alpha) * cosine_decay + alpha
+ decayed_learning_rate = learning_rate * decayed
+ ```
+
+ Example usage:
+ ```python
+ decay_steps = 1000
+ lr_decayed_fn = tf.train.cosine_decay(learning_rate, global_step, decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ alpha: A scalar `float32` or `float64` Tensor or a Python number.
+ Minimum learning rate value as a fraction of learning_rate.
+ name: String. Optional name of the operation. Defaults to 'CosineDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("cosine decay requires global_step")
+ def decayed_lr(learning_rate, global_step, decay_steps, alpha, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "CosineDecay",
+ [learning_rate, global_step]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ completed_fraction = global_step_recomp / decay_steps
+ cosine_decayed = 0.5 * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+
+ decayed = (1 - alpha) * cosine_decayed + alpha
+ return math_ops.multiply(learning_rate, decayed)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ alpha, name)
+
+
+@tf_export("train.cosine_decay_restarts", v1=[])
+def cosine_decay_restarts(learning_rate,
+ global_step,
+ first_decay_steps,
+ t_mul=2.0,
+ m_mul=1.0,
+ alpha=0.0,
+ name=None):
+ """Applies cosine decay with restarts to the learning rate.
+
+ See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a cosine decay function with
+ restarts to a provided initial learning rate. It requires a `global_step`
+ value to compute the decayed learning rate. You can just pass a TensorFlow
+ variable that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate while taking into account possible warm restarts. This can be useful for
+ changing the learning rate value across different invocations of optimizer
+ functions.
+
+ The learning rate multiplier first decays
+ from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
+ restart is performed. Each new warm restart runs for `t_mul` times more steps
+ and with `m_mul` times smaller initial learning rate.
+
+ Example usage:
+ ```python
+ first_decay_steps = 1000
+ lr_decayed_fn = tf.train.cosine_decay_restarts(learning_rate, global_step,
+ first_decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Used to derive the number of iterations in the i-th period
+ m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Used to derive the initial learning rate of the i-th period:
+ alpha: A scalar `float32` or `float64` Tensor or a Python number.
+ Minimum learning rate value as a fraction of the learning_rate.
+ name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("cosine decay restarts requires global_step")
+ def decayed_lr(learning_rate, global_step, first_decay_steps, t_mul, m_mul,
+ alpha, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]
+ ) as name:
+ learning_rate = ops.convert_to_tensor(
+ learning_rate, name="initial_learning_rate")
+ dtype = learning_rate.dtype
+ first_decay_steps = math_ops.cast(first_decay_steps, dtype)
+ alpha = math_ops.cast(alpha, dtype)
+ t_mul = math_ops.cast(t_mul, dtype)
+ m_mul = math_ops.cast(m_mul, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ completed_fraction = global_step_recomp / first_decay_steps
+
+ def compute_step(completed_fraction, geometric=False):
+ """Helper for `cond` operation."""
+ if geometric:
+ i_restart = math_ops.floor(
+ math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
+ math_ops.log(t_mul))
+
+ sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
+ completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
+
+ else:
+ i_restart = math_ops.floor(completed_fraction)
+ completed_fraction -= i_restart
+
+ return i_restart, completed_fraction
+
+ i_restart, completed_fraction = control_flow_ops.cond(
+ math_ops.equal(t_mul, 1.0),
+ lambda: compute_step(completed_fraction, geometric=False),
+ lambda: compute_step(completed_fraction, geometric=True))
+
+ m_fac = m_mul**i_restart
+ cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+ decayed = (1 - alpha) * cosine_decayed + alpha
+
+ return math_ops.multiply(learning_rate, decayed, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step,
+ first_decay_steps, t_mul, m_mul, alpha, name)
+
+
+@tf_export("train.linear_cosine_decay", v1=[])
+def linear_cosine_decay(learning_rate,
+ global_step,
+ decay_steps,
+ num_periods=0.5,
+ alpha=0.0,
+ beta=0.001,
+ name=None):
+ """Applies linear cosine decay to the learning rate.
+
+ See [Bello et al., ICML2017] Neural Optimizer Search with RL.
+ https://arxiv.org/abs/1709.07417
+
+ For the idea of warm starts here controlled by `num_periods`,
+ see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ Note that linear cosine decay is more aggressive than cosine decay and
+ larger initial learning rates can typically be used.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a linear cosine decay function
+ to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ linear_decay = (decay_steps - global_step) / decay_steps)
+ cosine_decay = 0.5 * (
+ 1 + cos(pi * 2 * num_periods * global_step / decay_steps))
+ decayed = (alpha + linear_decay) * cosine_decay + beta
+ decayed_learning_rate = learning_rate * decayed
+ ```
+
+ Example usage:
+ ```python
+ decay_steps = 1000
+ lr_decayed_fn = tf.train.linear_cosine_decay(learning_rate, global_step,
+ decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ num_periods: Number of periods in the cosine part of the decay.
+ See computation above.
+ alpha: See computation above.
+ beta: See computation above.
+ name: String. Optional name of the operation. Defaults to
+ 'LinearCosineDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("linear cosine decay requires global_step")
+ def decayed_lr(learning_rate, global_step, decay_steps, num_periods, alpha,
+ beta, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "LinearCosineDecay",
+ [learning_rate, global_step]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ num_periods = math_ops.cast(num_periods, dtype)
+ alpha = math_ops.cast(alpha, dtype)
+ beta = math_ops.cast(beta, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+
+ linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
+ return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ num_periods, alpha, beta, name)
+
+
+@tf_export("train.noisy_linear_cosine_decay", v1=[])
+def noisy_linear_cosine_decay(learning_rate,
+ global_step,
+ decay_steps,
+ initial_variance=1.0,
+ variance_decay=0.55,
+ num_periods=0.5,
+ alpha=0.0,
+ beta=0.001,
+ name=None):
+ """Applies noisy linear cosine decay to the learning rate.
+
+ See [Bello et al., ICML2017] Neural Optimizer Search with RL.
+ https://arxiv.org/abs/1709.07417
+
+ For the idea of warm starts here controlled by `num_periods`,
+ see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ Note that linear cosine decay is more aggressive than cosine decay and
+ larger initial learning rates can typically be used.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a noisy linear
+ cosine decay function to a provided initial learning rate.
+ It requires a `global_step` value to compute the decayed learning rate.
+ You can just pass a TensorFlow variable that you increment at each
+ training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ linear_decay = (decay_steps - global_step) / decay_steps)
+ cosine_decay = 0.5 * (
+ 1 + cos(pi * 2 * num_periods * global_step / decay_steps))
+ decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
+ decayed_learning_rate = learning_rate * decayed
+ ```
+ where eps_t is 0-centered gaussian noise with variance
+ initial_variance / (1 + global_step) ** variance_decay
+
+ Example usage:
+ ```python
+ decay_steps = 1000
+ lr_decayed_fn = tf.train.noisy_linear_cosine_decay(learning_rate, global_step,
+ decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ initial_variance: initial variance for the noise. See computation above.
+ variance_decay: decay for the noise's variance. See computation above.
+ num_periods: Number of periods in the cosine part of the decay.
+ See computation above.
+ alpha: See computation above.
+ beta: See computation above.
+ name: String. Optional name of the operation. Defaults to
+ 'NoisyLinearCosineDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("noisy linear cosine decay requires global_step")
+ def decayed_lr(learning_rate, global_step, decay_steps, initial_variance,
+ variance_decay, num_periods, alpha, beta, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "NoisyLinearCosineDecay",
+ [learning_rate, global_step]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ initial_variance = math_ops.cast(initial_variance, dtype)
+ variance_decay = math_ops.cast(variance_decay, dtype)
+ num_periods = math_ops.cast(num_periods, dtype)
+ alpha = math_ops.cast(alpha, dtype)
+ beta = math_ops.cast(beta, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
+ variance = initial_variance / (
+ math_ops.pow(1.0 + global_step_recomp, variance_decay))
+ std = math_ops.sqrt(variance)
+ noisy_linear_decayed = (
+ linear_decayed + random_ops.random_normal(
+ linear_decayed.shape, stddev=std))
+
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+ noisy_linear_cosine_decayed = (
+ (alpha + noisy_linear_decayed) * cosine_decayed + beta)
+
+ return math_ops.multiply(
+ learning_rate, noisy_linear_cosine_decayed, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ initial_variance, variance_decay, num_periods, alpha,
+ beta, name)
diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py
new file mode 100644
index 0000000000..0f2d60dafc
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_v2_test.py
@@ -0,0 +1,497 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
+# Import resource_variable_ops for the variables-to-tensor implicit conversion.
+from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import learning_rate_decay_v2
+
+
+class LRDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testContinuous(self):
+ self.evaluate(variables.global_variables_initializer())
+ step = 5
+ decayed_lr = learning_rate_decay_v2.exponential_decay(0.05, step, 10, 0.96)
+ expected = .05 * 0.96**(5.0 / 10.0)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testStaircase(self):
+ if context.executing_eagerly():
+ step = resource_variable_ops.ResourceVariable(0)
+ self.evaluate(variables.global_variables_initializer())
+ decayed_lr = learning_rate_decay_v2.exponential_decay(
+ .1, step, 3, 0.96, staircase=True)
+
+ # No change to learning rate due to staircase
+ expected = .1
+ self.evaluate(step.assign(1))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ expected = .1
+ self.evaluate(step.assign(2))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ # Decayed learning rate
+ expected = .1 * 0.96 ** (100 // 3)
+ self.evaluate(step.assign(100))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ def testVariables(self):
+ with self.test_session():
+ step = variables.Variable(1)
+ assign_1 = step.assign(1)
+ assign_2 = step.assign(2)
+ assign_100 = step.assign(100)
+ decayed_lr = learning_rate_decay_v2.exponential_decay(.1, step, 3, 0.96,
+ staircase=True)
+ variables.global_variables_initializer().run()
+ # No change to learning rate
+ assign_1.op.run()
+ self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
+ assign_2.op.run()
+ self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
+ # Decayed learning rate
+ assign_100.op.run()
+ expected = .1 * 0.96 ** (100 // 3)
+ self.assertAllClose(decayed_lr().eval(), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testPiecewiseConstant(self):
+ x = resource_variable_ops.ResourceVariable(-999)
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001])
+
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
+ self.evaluate(x.assign(100))
+ self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
+ self.evaluate(x.assign(105))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
+ self.evaluate(x.assign(110))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
+ self.evaluate(x.assign(120))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.01, 1e-6)
+ self.evaluate(x.assign(999))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.001, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testPiecewiseConstantEdgeCases(self):
+ x_int = resource_variable_ops.ResourceVariable(
+ 0, dtype=variables.dtypes.int32)
+ boundaries, values = [-1.0, 1.0], [1, 2, 3]
+ with self.assertRaises(ValueError):
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x_int, boundaries, values)
+ decayed_lr()
+
+ x = resource_variable_ops.ResourceVariable(0.0)
+ boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
+ with self.assertRaises(ValueError):
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x, boundaries, values)()
+ decayed_lr()
+
+ # Test that ref types are valid.
+ if not context.executing_eagerly():
+ x = variables.Variable(0.0)
+ x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
+ boundaries, values = [1.0, 2.0], [1, 2, 3]
+ learning_rate_decay_v2.piecewise_constant(x_ref, boundaries, values)
+
+ # Test casting boundaries from int32 to int64.
+ x_int64 = resource_variable_ops.ResourceVariable(
+ 0, dtype=variables.dtypes.int64)
+ boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x_int64, boundaries, values)
+
+ self.evaluate(variables.global_variables_initializer())
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
+ self.evaluate(x_int64.assign(1))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
+ self.evaluate(x_int64.assign(2))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.5, 1e-6)
+ self.evaluate(x_int64.assign(3))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.6, 1e-6)
+ self.evaluate(x_int64.assign(4))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.7, 1e-6)
+
+
+class LinearDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWay(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.0
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = lr * 0.5
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEnd(self):
+ step = 10
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWayWithEnd(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = (lr + end_lr) * 0.5
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEnd(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEndWithCycle(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, cycle=True)
+ expected = (lr - end_lr) * 0.25 + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class SqrtDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWay(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.0
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = lr * 0.5**power
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEnd(self):
+ step = 10
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWayWithEnd(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = (lr - end_lr) * 0.5**power + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEnd(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEndWithCycle(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power, cycle=True)
+ expected = (lr - end_lr) * 0.25**power + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class PolynomialDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeginWithCycle(self):
+ lr = 0.001
+ decay_steps = 10
+ step = 0
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, decay_steps, cycle=True)
+ expected = lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class ExponentialDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.natural_exp_decay(initial_lr, step, k,
+ decay_rate)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr * math.exp(-i / k * decay_rate)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testStaircase(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.natural_exp_decay(
+ initial_lr, step, k, decay_rate, staircase=True)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr * math.exp(-decay_rate * (i // k))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+
+class InverseDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.inverse_time_decay(initial_lr, step, k,
+ decay_rate)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr / (1 + i / k * decay_rate)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testStaircase(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.inverse_time_decay(
+ initial_lr, step, k, decay_rate, staircase=True)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr / (1 + decay_rate * (i // k))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+
+class CosineDecayTestV2(test_util.TensorFlowTestCase):
+
+ def np_cosine_decay(self, step, decay_steps, alpha=0.0):
+ step = min(step, decay_steps)
+ completed_fraction = step / decay_steps
+ decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
+ return (1.0 - alpha) * decay + alpha
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step,
+ num_training_steps)
+ expected = self.np_cosine_decay(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAlpha(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ alpha = 0.1
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step,
+ num_training_steps,
+ alpha)
+ expected = self.np_cosine_decay(step, num_training_steps, alpha)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase):
+
+ def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0,
+ alpha=0.0):
+ fac = 1.0
+ while step >= decay_steps:
+ step -= decay_steps
+ decay_steps *= t_mul
+ fac *= m_mul
+
+ completed_fraction = step / decay_steps
+ decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
+ return (1.0 - alpha) * decay + alpha
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps)
+ expected = self.np_cosine_decay_restarts(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAlpha(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ alpha = 0.1
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, alpha=alpha)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, alpha=alpha)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testMMul(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ m_mul = 0.9
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, m_mul=m_mul)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, m_mul=m_mul)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testTMul(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ t_mul = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, t_mul=t_mul)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, t_mul=t_mul)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class LinearCosineDecayTestV2(test_util.TensorFlowTestCase):
+
+ def np_linear_cosine_decay(self,
+ step,
+ decay_steps,
+ alpha=0.0,
+ beta=0.001,
+ num_periods=0.5):
+ step = min(step, decay_steps)
+ linear_decayed = float(decay_steps - step) / decay_steps
+ fraction = 2.0 * num_periods * step / float(decay_steps)
+ cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction))
+ return (alpha + linear_decayed) * cosine_decayed + beta
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDefaultDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+ initial_lr, step, num_training_steps)
+ expected = self.np_linear_cosine_decay(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNonDefaultDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+ initial_lr,
+ step,
+ num_training_steps,
+ alpha=0.1,
+ beta=1e-4,
+ num_periods=5)
+ expected = self.np_linear_cosine_decay(
+ step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDefaultNoisyLinearCosine(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ # No numerical check because of noise
+ decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
+ initial_lr, step, num_training_steps)
+ # Cannot be deterministically tested
+ self.evaluate(decayed_lr())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNonDefaultNoisyLinearCosine(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ # No numerical check because of noise
+ decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
+ initial_lr,
+ step,
+ num_training_steps,
+ initial_variance=0.5,
+ variance_decay=0.1,
+ alpha=0.1,
+ beta=1e-4,
+ num_periods=5)
+ # Cannot be deterministically tested
+ self.evaluate(decayed_lr())
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 9702430a12..38216ce9b1 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import functools
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import renames_v2
@@ -45,6 +46,29 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Specially handled functions.
self.function_handle = {}
+ for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+ "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+ "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+ "tf.train.cosine_decay_restarts",
+ "tf.train.linear_cosine_decay",
+ "tf.train.noisy_linear_cosine_decay"]:
+ self.function_handle[decay] = functools.partial(
+ self._learning_rate_decay_handler, decay_name=decay)
+
+ @staticmethod
+ def _learning_rate_decay_handler(file_edit_recorder, node, decay_name):
+ comment = ("ERROR: %s has been changed to return a callable instead of a "
+ "tensor when graph building, but its functionality remains "
+ "unchanged during eager execution (returns a callable like "
+ "before). The converter cannot detect and fix this reliably, so "
+ "you need to inspect this usage manually.\n") % decay_name
+ file_edit_recorder.add(
+ comment,
+ node.lineno,
+ node.col_offset,
+ decay_name,
+ decay_name,
+ error="%s requires manual check." % decay_name)
if __name__ == "__main__":
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index 57ac04de06..3886c1e8b9 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -63,6 +63,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n")
+ def testLearningRateDecay(self):
+ for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+ "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+ "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+ "tf.train.cosine_decay_restarts",
+ "tf.train.linear_cosine_decay",
+ "tf.train.noisy_linear_cosine_decay"]:
+
+ text = "%s(a, b)\n" % decay
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay])
+
class TestUpgradeFiles(test_util.TensorFlowTestCase):