path: root/tensorflow/python/training/learning_rate_decay.py
Diffstat (limited to 'tensorflow/python/training/learning_rate_decay.py')
-rw-r--r--  tensorflow/python/training/learning_rate_decay.py  384
1 file changed, 255 insertions(+), 129 deletions(-)
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 10ab4c1137..51190264e8 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import math
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -87,6 +88,12 @@ def exponential_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for exponential_decay.")
@@ -95,14 +102,22 @@ def exponential_decay(learning_rate,
[learning_rate, global_step, decay_steps, decay_rate]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
decay_rate = math_ops.cast(decay_rate, dtype)
- p = global_step / decay_steps
- if staircase:
- p = math_ops.floor(p)
- return math_ops.multiply(
- learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ return math_ops.multiply(
+ learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
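
For context, a minimal usage sketch of the eager path introduced above (not part of the diff; it assumes eager execution is enabled and that tf.train.get_or_create_global_step behaves as in graph mode, with the caller invoking the returned function whenever a fresh rate is needed):

import tensorflow as tf

tf.enable_eager_execution()
# Assumption: get_or_create_global_step yields a resource variable under eager.
global_step = tf.train.get_or_create_global_step()
lr_fn = tf.train.exponential_decay(0.1, global_step, decay_steps=1000,
                                   decay_rate=0.96, staircase=True)
print(lr_fn())  # 0.1, recomputed from the current value of global_step
global_step.assign_add(1000)
print(lr_fn())  # 0.096, i.e. 0.1 * 0.96 after one full decay period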
@tf_export("train.piecewise_constant")
@@ -141,48 +156,62 @@ def piecewise_constant(x, boundaries, values, name=None):
ValueError: if types of `x` and `boundaries` do not match, or types of all
`values` do not match or
the number of elements in the lists does not match.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if len(boundaries) != len(values) - 1:
raise ValueError(
"The length of boundaries should be 1 less than the length of values")
with ops.name_scope(name, "PiecewiseConstant",
[x, boundaries, values, name]) as name:
- x = ops.convert_to_tensor(x)
- # Avoid explicit conversion to x's dtype. This could result in faulty
- # comparisons, for example if floats are converted to integers.
boundaries = ops.convert_n_to_tensor(boundaries)
- for i, b in enumerate(boundaries):
- if b.dtype.base_dtype != x.dtype.base_dtype:
- # We can promote int32 boundaries to int64 without loss of precision.
- # This covers the most common case where the user passes in boundaries
- # as an array of Python integers.
- if (b.dtype.base_dtype == dtypes.int32 and
- x.dtype.base_dtype == dtypes.int64):
- b = math_ops.cast(b, x.dtype.base_dtype)
- boundaries[i] = b
- else:
- raise ValueError(
- "Boundaries (%s) must have the same dtype as x (%s)." %
- (b.dtype.base_dtype, x.dtype.base_dtype))
- # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
values = ops.convert_n_to_tensor(values)
- for v in values[1:]:
- if v.dtype.base_dtype != values[0].dtype.base_dtype:
- raise ValueError(
- "Values must have elements all with the same dtype (%s vs %s)." %
- (values[0].dtype.base_dtype, v.dtype.base_dtype))
- pred_fn_pairs = []
- pred_fn_pairs.append((x <= boundaries[0], lambda: values[0]))
- pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1]))
- for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
- # Need to bind v here; can do this with lambda v=v: ...
- pred = (x > low) & (x <= high)
- pred_fn_pairs.append((pred, lambda v=v: v))
-
- # The default isn't needed here because our conditions are mutually
- # exclusive and exhaustive, but tf.case requires it.
- default = lambda: values[0]
- return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ x_recomp = ops.convert_to_tensor(x)
+ # Avoid explicit conversion to x's dtype. This could result in faulty
+ # comparisons, for example if floats are converted to integers.
+ for i, b in enumerate(boundaries):
+ if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
+ # We can promote int32 boundaries to int64 without loss of precision.
+ # This covers the most common case where the user passes in boundaries
+ # as an array of Python integers.
+ if (b.dtype.base_dtype == dtypes.int32 and
+ x_recomp.dtype.base_dtype == dtypes.int64):
+ b = math_ops.cast(b, x_recomp.dtype.base_dtype)
+ boundaries[i] = b
+ else:
+ raise ValueError(
+ "Boundaries (%s) must have the same dtype as x (%s)." %
+ (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
+      # TODO(rdipietro): Ensure boundaries' elements are strictly increasing.
+ for v in values[1:]:
+ if v.dtype.base_dtype != values[0].dtype.base_dtype:
+ raise ValueError(
+ "Values must have elements all with the same dtype (%s vs %s)." %
+ (values[0].dtype.base_dtype, v.dtype.base_dtype))
+ pred_fn_pairs = []
+ pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
+ pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
+ for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
+ # Need to bind v here; can do this with lambda v=v: ...
+ pred = (x_recomp > low) & (x_recomp <= high)
+ pred_fn_pairs.append((pred, lambda v=v: v))
+
+ # The default isn't needed here because our conditions are mutually
+ # exclusive and exhaustive, but tf.case requires it.
+ default = lambda: values[0]
+ return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
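
For reference, the schedule computed by the rewritten code is unchanged from the function's documented example (a sketch, not part of the diff): a rate of 1.0 for the first 100000 steps, 0.5 for the next 10000 steps, and 0.1 thereafter.

import tensorflow as tf

global_step = tf.Variable(0, trainable=False)
boundaries = [100000, 110000]
values = [1.0, 0.5, 0.1]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)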
@tf_export("train.polynomial_decay")
@@ -263,6 +292,12 @@ def polynomial_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for polynomial_decay.")
@@ -272,27 +307,35 @@ def polynomial_decay(learning_rate,
]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
- decay_steps = math_ops.cast(decay_steps, dtype)
end_learning_rate = math_ops.cast(end_learning_rate, dtype)
power = math_ops.cast(power, dtype)
- if cycle:
- # Find the first multiple of decay_steps that is bigger than global_step.
- # If global_step is zero set the multiplier to 1
- multiplier = control_flow_ops.cond(
- math_ops.equal(global_step, 0), lambda: 1.0,
- lambda: math_ops.ceil(global_step / decay_steps))
- decay_steps = math_ops.multiply(decay_steps, multiplier)
- else:
- # Make sure that the global_step used is not bigger than decay_steps.
- global_step = math_ops.minimum(global_step, decay_steps)
-
- p = math_ops.div(global_step, decay_steps)
- return math_ops.add(
- math_ops.multiply(learning_rate - end_learning_rate,
- math_ops.pow(1 - p, power)),
- end_learning_rate,
- name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ decay_steps_recomp = math_ops.cast(decay_steps, dtype)
+ if cycle:
+ # Find the first multiple of decay_steps that is bigger than
+        # global_step. If global_step is zero, set the multiplier to 1.
+ multiplier = control_flow_ops.cond(
+ math_ops.equal(global_step_recomp, 0), lambda: 1.0,
+ lambda: math_ops.ceil(global_step_recomp / decay_steps))
+ decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
+ else:
+ # Make sure that the global_step used is not bigger than decay_steps.
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+
+ p = math_ops.div(global_step_recomp, decay_steps_recomp)
+ return math_ops.add(
+ math_ops.multiply(learning_rate - end_learning_rate,
+ math_ops.pow(1 - p, power)),
+ end_learning_rate,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
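
A plain-Python restatement of what decayed_lr computes for the default cycle=False case (a sketch for illustration; the default values shown are assumptions, not taken from this hunk):

def polynomial_decay_value(lr, global_step, decay_steps,
                           end_lr=0.0001, power=1.0):
  # Mirrors the graph code: clamp the step, then interpolate between
  # lr and end_lr with the given power.
  step = min(global_step, decay_steps)
  p = step / decay_steps
  return (lr - end_lr) * (1 - p) ** power + end_lr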
@tf_export("train.natural_exp_decay")
@@ -350,6 +393,12 @@ def natural_exp_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for natural_exp_decay.")
@@ -357,14 +406,23 @@ def natural_exp_decay(learning_rate,
[learning_rate, global_step, decay_rate]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
decay_rate = math_ops.cast(decay_rate, dtype)
- p = global_step / decay_steps
- if staircase:
- p = math_ops.floor(p)
- exponent = math_ops.exp(math_ops.multiply(math_ops.negative(decay_rate), p))
- return math_ops.multiply(learning_rate, exponent, name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ exponent = math_ops.exp(
+ math_ops.multiply(math_ops.negative(decay_rate), p))
+ return math_ops.multiply(learning_rate, exponent, name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
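
The schedule itself is unchanged by this refactor; in plain Python the helper above computes (sketch only):

import math

def natural_exp_decay_value(lr, global_step, decay_steps, decay_rate,
                            staircase=False):
  p = global_step / decay_steps
  if staircase:
    p = math.floor(p)
  return lr * math.exp(-decay_rate * p)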
@tf_export("train.inverse_time_decay")
@@ -432,6 +490,12 @@ def inverse_time_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for inverse_time_decay.")
@@ -439,15 +503,23 @@ def inverse_time_decay(learning_rate,
[learning_rate, global_step, decay_rate]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
decay_rate = math_ops.cast(decay_rate, dtype)
- p = global_step / decay_steps
- if staircase:
- p = math_ops.floor(p)
- const = math_ops.cast(constant_op.constant(1), learning_rate.dtype)
- denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
- return math_ops.div(learning_rate, denom, name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ const = math_ops.cast(constant_op.constant(1), dtype)
+ denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
+ return math_ops.div(learning_rate, denom, name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
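
Equivalently, in plain Python the helper above evaluates (sketch only):

import math

def inverse_time_decay_value(lr, global_step, decay_steps, decay_rate,
                             staircase=False):
  p = global_step / decay_steps
  if staircase:
    p = math.floor(p)
  return lr / (1.0 + decay_rate * p)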
@tf_export("train.cosine_decay")
@@ -492,6 +564,12 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("cosine decay requires global_step")
@@ -499,15 +577,23 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
[learning_rate, global_step]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
- global_step = math_ops.minimum(global_step, decay_steps)
- completed_fraction = global_step / decay_steps
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction))
- decayed = (1 - alpha) * cosine_decayed + alpha
- return math_ops.multiply(learning_rate, decayed)
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ completed_fraction = global_step_recomp / decay_steps
+ cosine_decayed = 0.5 * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+
+ decayed = (1 - alpha) * cosine_decayed + alpha
+ return math_ops.multiply(learning_rate, decayed)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
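
In plain Python, the cosine schedule computed by the helper above is (sketch only):

import math

def cosine_decay_value(lr, global_step, decay_steps, alpha=0.0):
  step = min(global_step, decay_steps)
  completed_fraction = step / decay_steps
  cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
  return lr * ((1 - alpha) * cosine_decayed + alpha)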
@tf_export("train.cosine_decay_restarts")
@@ -561,6 +647,12 @@ def cosine_decay_restarts(learning_rate,
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("cosine decay restarts requires global_step")
@@ -568,40 +660,48 @@ def cosine_decay_restarts(learning_rate,
learning_rate = ops.convert_to_tensor(
learning_rate, name="initial_learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
first_decay_steps = math_ops.cast(first_decay_steps, dtype)
alpha = math_ops.cast(alpha, dtype)
t_mul = math_ops.cast(t_mul, dtype)
m_mul = math_ops.cast(m_mul, dtype)
- completed_fraction = global_step / first_decay_steps
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ completed_fraction = global_step_recomp / first_decay_steps
- def compute_step(completed_fraction, geometric=False):
- if geometric:
- i_restart = math_ops.floor(
- math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
- math_ops.log(t_mul))
+ def compute_step(completed_fraction, geometric=False):
+ """Helper for `cond` operation."""
+ if geometric:
+ i_restart = math_ops.floor(
+ math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
+ math_ops.log(t_mul))
- sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
- completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
+ sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
+ completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
- else:
- i_restart = math_ops.floor(completed_fraction)
- completed_fraction = completed_fraction - i_restart
+ else:
+ i_restart = math_ops.floor(completed_fraction)
+ completed_fraction -= i_restart
+
+ return i_restart, completed_fraction
- return i_restart, completed_fraction
+ i_restart, completed_fraction = control_flow_ops.cond(
+ math_ops.equal(t_mul, 1.0),
+ lambda: compute_step(completed_fraction, geometric=False),
+ lambda: compute_step(completed_fraction, geometric=True))
- i_restart, completed_fraction = control_flow_ops.cond(
- math_ops.equal(t_mul, 1.0),
- lambda: compute_step(completed_fraction, geometric=False),
- lambda: compute_step(completed_fraction, geometric=True))
+ m_fac = m_mul**i_restart
+ cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+ decayed = (1 - alpha) * cosine_decayed + alpha
- m_fac = m_mul**i_restart
- cosine_decayed = 0.5 * m_fac * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction))
- decayed = (1 - alpha) * cosine_decayed + alpha
+ return math_ops.multiply(learning_rate, decayed, name=name)
- return math_ops.multiply(learning_rate, decayed, name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
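
The restart arithmetic in compute_step is unchanged by this refactor; for t_mul > 1 it implies warm restarts at geometrically growing intervals (illustrative sketch, values assumed):

first_decay_steps, t_mul = 1000, 2.0
period_lengths = [first_decay_steps * t_mul**i for i in range(4)]
# period_lengths == [1000.0, 2000.0, 4000.0, 8000.0]: each restart period is
# t_mul times longer than the previous one, and m_mul scales its peak rate.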
@tf_export("train.linear_cosine_decay")
@@ -664,6 +764,12 @@ def linear_cosine_decay(learning_rate,
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("linear cosine decay requires global_step")
@@ -671,21 +777,28 @@ def linear_cosine_decay(learning_rate,
[learning_rate, global_step]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
num_periods = math_ops.cast(num_periods, dtype)
- global_step = math_ops.minimum(global_step, decay_steps)
alpha = math_ops.cast(alpha, dtype)
beta = math_ops.cast(beta, dtype)
- linear_decayed = (decay_steps - global_step) / decay_steps
- completed_fraction = global_step / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+
+ linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
+ return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
- linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
- return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
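
As a plain-Python sketch, the helper above combines a linear ramp with a cosine term (the default values here are assumptions for illustration):

import math

def linear_cosine_decay_value(lr, global_step, decay_steps,
                              num_periods=0.5, alpha=0.0, beta=0.001):
  step = min(global_step, decay_steps)
  linear_decayed = (decay_steps - step) / decay_steps
  fraction = 2.0 * num_periods * (step / decay_steps)
  cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction))
  return lr * ((alpha + linear_decayed) * cosine_decayed + beta)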
@tf_export("train.noisy_linear_cosine_decay")
@@ -756,6 +869,12 @@ def noisy_linear_cosine_decay(learning_rate,
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("noisy linear cosine decay requires global_step")
@@ -763,29 +882,36 @@ def noisy_linear_cosine_decay(learning_rate,
[learning_rate, global_step]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
- global_step = math_ops.minimum(global_step, decay_steps)
initial_variance = math_ops.cast(initial_variance, dtype)
variance_decay = math_ops.cast(variance_decay, dtype)
num_periods = math_ops.cast(num_periods, dtype)
alpha = math_ops.cast(alpha, dtype)
beta = math_ops.cast(beta, dtype)
- linear_decayed = (decay_steps - global_step) / decay_steps
- variance = initial_variance / (
- math_ops.pow(1.0 + global_step, variance_decay))
- std = math_ops.sqrt(variance)
- noisy_linear_decayed = (
- linear_decayed +
- random_ops.random_normal(linear_decayed.shape, stddev=std))
-
- completed_fraction = global_step / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
- noisy_linear_cosine_decayed = (
- (alpha + noisy_linear_decayed) * cosine_decayed + beta)
-
- return math_ops.multiply(
- learning_rate, noisy_linear_cosine_decayed, name=name)
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
+ variance = initial_variance / (
+ math_ops.pow(1.0 + global_step_recomp, variance_decay))
+ std = math_ops.sqrt(variance)
+ noisy_linear_decayed = (
+ linear_decayed + random_ops.random_normal(
+ linear_decayed.shape, stddev=std))
+
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+ noisy_linear_cosine_decayed = (
+ (alpha + noisy_linear_decayed) * cosine_decayed + beta)
+
+ return math_ops.multiply(
+ learning_rate, noisy_linear_cosine_decayed, name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
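
Finally, a plain-Python sketch of the noisy variant above, using random.gauss in place of random_ops.random_normal (the default values are assumptions for illustration):

import math
import random

def noisy_linear_cosine_decay_value(lr, global_step, decay_steps,
                                    initial_variance=1.0, variance_decay=0.55,
                                    num_periods=0.5, alpha=0.0, beta=0.001):
  step = min(global_step, decay_steps)
  linear_decayed = (decay_steps - step) / decay_steps
  std = math.sqrt(initial_variance / (1.0 + step) ** variance_decay)
  noisy_linear_decayed = linear_decayed + random.gauss(0.0, std)
  fraction = 2.0 * num_periods * (step / decay_steps)
  cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction))
  return lr * ((alpha + noisy_linear_decayed) * cosine_decayed + beta)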