From 96d92a1ffb6de1687849dd3895d6d13783404db0 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 28 Jun 2018 11:39:35 -0700
Subject: Replace Keras clip by value and clip by norm in Keras Optimizers with
 native TF clip_ops, also added user input check for clipnorm and clipvalue
 >= 0 if set

PiperOrigin-RevId: 202516320
---
 tensorflow/python/keras/optimizers.py      | 53 +++++++-----------------
 tensorflow/python/keras/optimizers_test.py |  6 ++++
 2 files changed, 17 insertions(+), 42 deletions(-)

diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index 34951791b5..b02cafcf61 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -19,17 +19,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import copy
-
 import six
 from six.moves import zip  # pylint: disable=redefined-builtin
 
-from tensorflow.python.framework import dtypes as dtypes_module
-from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.training import distribute as distribute_lib
@@ -39,37 +35,6 @@ from tensorflow.python.training.checkpointable import tracking as checkpointable
 from tensorflow.python.util.tf_export import tf_export
 
 
-def clip_norm(g, c, n):
-  """Clip a tensor by norm.
-
-  Arguments:
-    g: gradient tensor to clip.
-    c: clipping threshold.
-    n: norm of gradient tensor.
-
-  Returns:
-    Clipped gradient tensor.
-  """
-  if c > 0:
-    condition = n >= c
-    then_expression = lambda: math_ops.scalar_mul(c / n, g)
-    else_expression = lambda: g
-
-    # saving the shape to avoid converting sparse tensor to dense
-    if isinstance(g, ops.Tensor):
-      g_shape = copy.copy(g.get_shape())
-    elif isinstance(g, ops.IndexedSlices):
-      g_shape = copy.copy(g.dense_shape)
-    if condition.dtype != dtypes_module.bool:
-      condition = math_ops.cast(condition, 'bool')
-    g = control_flow_ops.cond(condition, then_expression, else_expression)
-    if isinstance(g, ops.Tensor):
-      g.set_shape(g_shape)
-    elif isinstance(g, ops.IndexedSlices):
-      g._dense_shape = g_shape  # pylint: disable=protected-access
-  return g
-
-
 @tf_export('keras.optimizers.Optimizer')
 class Optimizer(object):
   """Abstract optimizer base class.
@@ -91,6 +56,9 @@ class Optimizer(object):
       if k not in allowed_kwargs:
         raise TypeError('Unexpected keyword argument '
                         'passed to optimizer: ' + str(k))
+      # checks that clipnorm >= 0 and clipvalue >= 0
+      if kwargs[k] < 0:
+        raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k]))
     self.__dict__.update(kwargs)
     self.updates = []
     self.weights = []
@@ -119,12 +87,13 @@ class Optimizer(object):
                        'gradient defined (i.e. are differentiable). '
                        'Common ops without gradient: '
                        'K.argmax, K.round, K.eval.')
-    if hasattr(self, 'clipnorm') and self.clipnorm > 0:
-      norm = K.sqrt(
-          sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads]))
-      grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
-    if hasattr(self, 'clipvalue') and self.clipvalue > 0:
-      grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
+    if hasattr(self, 'clipnorm'):
+      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+    if hasattr(self, 'clipvalue'):
+      grads = [
+          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+          for g in grads
+      ]
     return grads
 
   def set_weights(self, weights):
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 92b0cf3261..55fc3fdcf4 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -145,6 +145,12 @@ class KerasOptimizersTest(test.TestCase):
     with self.assertRaises(NotImplementedError):
       optimizer.from_config(None)
 
+  def test_negative_clipvalue_or_clipnorm(self):
+    with self.assertRaises(ValueError):
+      _ = keras.optimizers.SGD(lr=0.01, clipvalue=-0.5)
+    with self.assertRaises(ValueError):
+      _ = keras.optimizers.Adam(clipnorm=-2.0)
+
 
 if __name__ == '__main__':
   test.main()
-- 
cgit v1.2.3
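
For reference, a minimal usage sketch of the behavior this patch introduces.
This is an illustration, not part of the commit: it assumes a TensorFlow
build that includes the change, and uses the TF 1.x-era learning-rate
argument name lr from the surrounding test code.

    import tensorflow as tf

    # Non-negative thresholds are accepted; clipping is now applied with the
    # native clip_ops whenever clipnorm/clipvalue is set on the optimizer.
    sgd = tf.keras.optimizers.SGD(lr=0.01, clipnorm=1.0)

    # Negative thresholds now fail fast at construction time. Under the old
    # code a negative value silently disabled clipping, since clipping only
    # ran when the threshold was > 0.
    try:
      tf.keras.optimizers.SGD(lr=0.01, clipvalue=-0.5)
    except ValueError as e:
      print(e)  # Expected clipvalue >= 0, received: -0.5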