diff options
author | Mingxing Tan <tanmingxing@google.com> | 2018-06-28 19:13:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-28 19:16:41 -0700 |
commit | 1e7b0e4ad6d0f57f3241fe0b80a65f2c2a7f11b0 (patch) | |
tree | af92d172cedfc41e544c01a349c1d3b30bc3ff85 /tensorflow/contrib/opt | |
parent | 3cee10e61c1c90734317c62ea3388ec44acc8d08 (diff) |
Merge changes from github.
PiperOrigin-RevId: 202585094
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/BUILD | 20 | ||||
-rw-r--r-- | tensorflow/contrib/opt/__init__.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/opt/python/training/weight_decay_optimizers.py | 362 | ||||
-rw-r--r-- | tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py | 188 |
4 files changed, 578 insertions, 3 deletions
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 4f35de4e5d..bbdf962d04 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -29,6 +29,7 @@ py_library( "python/training/reg_adagrad_optimizer.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", + "python/training/weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ @@ -198,6 +199,25 @@ py_test( ], ) +py_test( + name = "weight_decay_optimizers_test", + srcs = ["python/training/weight_decay_optimizers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index b41148329d..65777b1323 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,16 +22,17 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * +from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * -from tensorflow.contrib.opt.python.training.model_average_optimizer import * -from tensorflow.contrib.opt.python.training.ggt import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -47,6 +48,10 @@ _allowed_symbols = [ 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', + 'MomentumWOptimizer', + 'AdamWOptimizer', + 'DecoupledWeightDecayExtension', + 'extend_with_decoupled_weight_decay', 'ScipyOptimizerInterface', 'VariableClippingOptimizer', 'MultitaskOptimizerWrapper', diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py new file mode 100644 index 0000000000..b9cf40eb7b --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -0,0 +1,362 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import adam +from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.training import optimizer +from tensorflow.python.util.tf_export import tf_export + + +class DecoupledWeightDecayExtension(object): + """This class allows to extend optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two examples + used in the above paper (SGDW and AdamW), but in general this can extend + any OptimizerX by using + `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). + ``` + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + """ + + def __init__(self, weight_decay, **kwargs): + """Construct the extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by which + a variable is decayed in the update step. + **kwargs: Optional list or tuple or set of `Variable` objects to + decay. + """ + self._decay_var_list = None # is set in minimize or apply_gradients + self._weight_decay = weight_decay + # The tensors are initialized in call to _prepare + self._weight_decay_tensor = None + super(DecoupledWeightDecayExtension, self).__init__(**kwargs) + + def minimize(self, loss, global_step=None, var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, colocate_gradients_with_ops=False, + name=None, grad_loss=None, decay_var_list=None): + """Add operations to minimize `loss` by updating `var_list` with decay. + + This function is the same as Optimizer.minimize except that it allows to + specify the variables that should be decayed using decay_var_list. + If decay_var_list is None, all variables in var_list are decayed. + + For more information see the documentation of Optimizer.minimize. + + Args: + loss: A `Tensor` containing the value to minimize. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + var_list: Optional list or tuple of `Variable` objects to update to + minimize `loss`. Defaults to the list of variables collected in + the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with + the corresponding op. + name: Optional name for the returned operation. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + decay_var_list: Optional list of decay variables. + + Returns: + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. + + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, global_step=global_step, var_list=var_list, + gate_gradients=gate_gradients, aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, + grad_loss=grad_loss) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None, + decay_var_list=None): + """Apply gradients to variables and decay the variables. + + This function is the same as Optimizer.apply_gradients except that it + allows to specify the variables that should be decayed using + decay_var_list. If decay_var_list is None, all variables in var_list + are decayed. + + For more information see the documentation of Optimizer.apply_gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of decay variables. + + Returns: + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, global_step=global_step, name=name) + + def _prepare(self): + weight_decay = self._weight_decay + if callable(weight_decay): + weight_decay = weight_decay() + self._weight_decay_tensor = ops.convert_to_tensor( + weight_decay, name="weight_decay") + # Call the optimizers _prepare function. + super(DecoupledWeightDecayExtension, self)._prepare() + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub(self._weight_decay * var, self._use_locking) + return control_flow_ops.no_op() + + def _decay_weights_sparse_op(self, var, indices, scatter_add): + if not self._decay_var_list or var in self._decay_var_list: + return scatter_add(var, indices, -self._weight_decay * var, + self._use_locking) + return control_flow_ops.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + def _apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var) + + def _resource_apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_dense( + grad, var) + + def _apply_sparse(self, grad, var): + scatter_add = state_ops.scatter_add + decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._apply_sparse( + grad, var) + + def _resource_scatter_add(self, x, i, v, _=None): + # last argument allows for one overflow argument, to have the same function + # signature as state_ops.scatter_add + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + scatter_add = self._resource_scatter_add + decay_op = self._decay_weights_sparse_op(var, indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse( + grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. + E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to + `tf.contrib.opt.AdamWOptimizer`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + ``` + + Args: + base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, + base_optimizer): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being decoupled from + the optimization steps w.r.t. to the loss function, as described by + Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf). + For SGD variants, this simplifies hyperparameter search since + it decouples the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + """ + + def __init__(self, weight_decay, *args, **kwargs): + # super delegation is necessary here + # pylint: disable=useless-super-delegation + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) + # pylint: enable=useless-super-delegation + + return OptimizerWithDecoupledWeightDecay + + +@tf_export("contrib.opt.MomentumWOptimizer") +class MomentumWOptimizer(DecoupledWeightDecayExtension, + momentum_opt.MomentumOptimizer): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `train.MomentumOptimizer` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the Momentum Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.MomentumOptimizer, + weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate, momentum, + use_locking=False, name="MomentumW", use_nesterov=False): + """Construct a new MomentumW optimizer. + + For further information see the documentation of the Momentum Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A `Tensor` or a floating point value. The learning rate. + momentum: A `Tensor` or a floating point value. The momentum. + use_locking: If `True` use locks for update operations. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "Momentum". + use_nesterov: If `True` use Nesterov Momentum. + See [Sutskever et al., 2013]( + http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). + This implementation always computes gradients at the value of the + variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. + + @compatibility(eager) + When eager execution is enabled, learning_rate, weight_decay and momentum + can each be a callable that takes no arguments and returns the actual value + to use. This can be useful for changing these values across different + invocations of optimizer functions. + @end_compatibility + """ + super(MomentumWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, momentum=momentum, + use_locking=use_locking, name=name, use_nesterov=use_nesterov) + + +@tf_export("contrib.opt.AdamWOptimizer") +class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + """Optimizer that implements the Adam algorithm with weight decay. + + This is an implementation of the AdamW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `train.AdamOptimizer` and additionally decays + the variable. Note that this is different from adding L2 regularization on + the variables to the loss: it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999, + epsilon=1e-8, use_locking=False, name="AdamW"): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + """ + super(AdamWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2, + epsilon=epsilon, use_locking=use_locking, name=name) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py new file mode 100644 index 0000000000..76d8a5697a --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -0,0 +1,188 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optimizers with weight decay.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import weight_decay_optimizers +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam + +WEIGHT_DECAY = 0.01 + + +def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9, + beta2=0.999, epsilon=1e-8): + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) - + (param * WEIGHT_DECAY)) + return param_t, m_t, v_t + + +def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_): + # v, t are not needed for momentum optimizer + m = momentum * m + g_t + param_t = param - lr * m - param * WEIGHT_DECAY + return param_t, m, None + + +class WeightDecayOptimizerTest(test.TestCase): + + def doTest(self, optimizer, update_fn, optimizer_name, slot_name, + use_resource=False, do_sparse=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices(constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), + constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices(constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), + constant_op.constant([2])) + else: + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = optimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of the optimizer + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0) + var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/%s:0" % (i, optimizer_name), + opt.get_slot(var=var0, name=slot_name).name) + + +class AdamWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY) + + def testSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True) + + +class MomentumWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9) + + def testSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True) + + +class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adam.AdamOptimizer) + return adamw(WEIGHT_DECAY) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=True) + + +if __name__ == "__main__": + test.main() |