path: root/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
Diffstat (limited to 'tensorflow/contrib/opt/python/training/weight_decay_optimizers.py')
-rw-r--r--  tensorflow/contrib/opt/python/training/weight_decay_optimizers.py | 44
1 file changed, 40 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
index 8aa40aeb45..b9cf40eb7b 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
-from tensorflow.python.training import optimizer
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
from tensorflow.python.training import momentum as momentum_opt
+from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import resource_variable_ops
class DecoupledWeightDecayExtension(object):
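
The class named in this hunk's context implements decoupled weight decay: the decay term is applied to each variable directly, in addition to whatever update the wrapped optimizer computes, instead of being folded into the gradient. Below is a minimal sketch of one dense update step, with plain SGD standing in for the wrapped optimizer; the helper name and its structure are illustrative assumptions, not the class's actual private methods:

    import tensorflow as tf

    def decoupled_sgd_step(var, grad, lr, weight_decay):
      # Apply the decay term to the variable directly; it does not pass
      # through the gradient, which is what "decoupled" means here.
      decay_op = var.assign_sub(weight_decay * var)
      with tf.control_dependencies([decay_op]):
        # The wrapped optimizer's own update; plain SGD stands in here.
        return var.assign_sub(lr * grad)
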
@@ -65,7 +65,7 @@ class DecoupledWeightDecayExtension(object):
Args:
weight_decay: A `Tensor` or a floating point value, the factor by which
a variable is decayed in the update step.
-      decay_var_list: Optional list or tuple or set of `Variable` objects to
-        decay.
+      **kwargs: Optional keyword arguments passed on to the constructor of
+        the base optimizer class.
"""
self._decay_var_list = None # is set in minimize or apply_gradients
@@ -85,6 +85,28 @@ class DecoupledWeightDecayExtension(object):
If decay_var_list is None, all variables in var_list are decayed.
For more information see the documentation of Optimizer.minimize.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ var_list: Optional list or tuple of `Variable` objects to update to
+ minimize `loss`. Defaults to the list of variables collected in
+ the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+      decay_var_list: Optional list of variables to decay. If `None`, all
+        variables in `var_list` are decayed.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).minimize(
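
The new `decay_var_list` argument lets callers decay only a subset of the variables being trained. A hypothetical usage sketch follows; `loss`, `var1`, and `var2` are assumed placeholders, and `AdamWOptimizer` is the decorated Adam variant defined later in this module:

    opt = AdamWOptimizer(weight_decay=1e-4, learning_rate=1e-3)
    # Update both variables, but apply weight decay only to var1.
    train_op = opt.minimize(loss, var_list=[var1, var2],
                            decay_var_list=[var1])
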
@@ -103,6 +125,19 @@ class DecoupledWeightDecayExtension(object):
are decayed.
For more information see the documentation of Optimizer.apply_gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+      name: Optional name for the returned operation. Defaults to the
+        name passed to the `Optimizer` constructor.
+      decay_var_list: Optional list of variables to decay. If `None`, all
+        variables in `grads_and_vars` are decayed.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).apply_gradients(
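
The same `decay_var_list` switch is exposed on the lower-level `apply_gradients` path. A hypothetical sketch, reusing the names from the previous example plus an assumed `global_step` variable:

    grads_and_vars = opt.compute_gradients(loss, var_list=[var1, var2])
    # Decay only var1 while applying gradients to every listed variable.
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step,
                                   decay_var_list=[var1])
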
@@ -197,6 +232,7 @@ def extend_with_decoupled_weight_decay(base_optimizer):
A new optimizer class that inherits from DecoupledWeightDecayExtension
and base_optimizer.
"""
+
class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
base_optimizer):
"""Base_optimizer with decoupled weight decay.