author A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-28 11:39:35 -0700
committer Gunhan Gulsoy <gunan@google.com> 2018-06-28 21:37:43 -0700
commit 96d92a1ffb6de1687849dd3895d6d13783404db0 (patch)
tree 3fa4bb0f211221f671daa888436b51fedae2d69a
parent 93dc1dfe1303e4a33e53c66ef84ad15f4953568c (diff)
Replace Keras clip-by-value and clip-by-norm in Keras optimizers with the native TF clip_ops; also add a user input check that clipnorm and clipvalue are >= 0 if set.
PiperOrigin-RevId: 202516320
-rw-r--r-- tensorflow/python/keras/optimizers.py 53
-rw-r--r-- tensorflow/python/keras/optimizers_test.py 6
2 files changed, 17 insertions, 42 deletions
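For context, the core of the change is in the gradient-clipping path below: the hand-rolled clip_norm helper and K.clip are replaced by the native TF clip ops. A minimal standalone sketch of the new behavior, using the public tf.clip_by_norm / tf.clip_by_value aliases of the internal clip_ops module (the _clip_grads helper name is hypothetical, for illustration only):

import tensorflow as tf

def _clip_grads(grads, clipnorm=None, clipvalue=None):
    # Per-tensor norm clipping, as the optimizer now does via clip_ops.clip_by_norm.
    if clipnorm is not None:
        grads = [tf.clip_by_norm(g, clipnorm) for g in grads]
    # Element-wise value clipping via clip_ops.clip_by_value.
    if clipvalue is not None:
        grads = [tf.clip_by_value(g, -clipvalue, clipvalue) for g in grads]
    return grads
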
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index 34951791b5..b02cafcf61 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -19,17 +19,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import copy
-
import six
from six.moves import zip # pylint: disable=redefined-builtin
-from tensorflow.python.framework import dtypes as dtypes_module
-from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import distribute as distribute_lib
@@ -39,37 +35,6 @@ from tensorflow.python.training.checkpointable import tracking as checkpointable
from tensorflow.python.util.tf_export import tf_export
-def clip_norm(g, c, n):
- """Clip a tensor by norm.
-
- Arguments:
- g: gradient tensor to clip.
- c: clipping threshold.
- n: norm of gradient tensor.
-
- Returns:
- Clipped gradient tensor.
- """
- if c > 0:
- condition = n >= c
- then_expression = lambda: math_ops.scalar_mul(c / n, g)
- else_expression = lambda: g
-
- # saving the shape to avoid converting sparse tensor to dense
- if isinstance(g, ops.Tensor):
- g_shape = copy.copy(g.get_shape())
- elif isinstance(g, ops.IndexedSlices):
- g_shape = copy.copy(g.dense_shape)
- if condition.dtype != dtypes_module.bool:
- condition = math_ops.cast(condition, 'bool')
- g = control_flow_ops.cond(condition, then_expression, else_expression)
- if isinstance(g, ops.Tensor):
- g.set_shape(g_shape)
- elif isinstance(g, ops.IndexedSlices):
- g._dense_shape = g_shape # pylint: disable=protected-access
- return g
-
-
@tf_export('keras.optimizers.Optimizer')
class Optimizer(object):
"""Abstract optimizer base class.
@@ -91,6 +56,9 @@ class Optimizer(object):
if k not in allowed_kwargs:
raise TypeError('Unexpected keyword argument '
'passed to optimizer: ' + str(k))
+ # checks that clipnorm >= 0 and clipvalue >= 0
+ if kwargs[k] < 0:
+ raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k]))
self.__dict__.update(kwargs)
self.updates = []
self.weights = []
@@ -119,12 +87,13 @@ class Optimizer(object):
'gradient defined (i.e. are differentiable). '
'Common ops without gradient: '
'K.argmax, K.round, K.eval.')
- if hasattr(self, 'clipnorm') and self.clipnorm > 0:
- norm = K.sqrt(
- sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads]))
- grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
- if hasattr(self, 'clipvalue') and self.clipvalue > 0:
- grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
+ if hasattr(self, 'clipnorm'):
+ grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+ if hasattr(self, 'clipvalue'):
+ grads = [
+ clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+ for g in grads
+ ]
return grads
def set_weights(self, weights):
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 92b0cf3261..55fc3fdcf4 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -145,6 +145,12 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
+ def test_negative_clipvalue_or_clipnorm(self):
+ with self.assertRaises(ValueError):
+ _ = keras.optimizers.SGD(lr=0.01, clipvalue=-0.5)
+ with self.assertRaises(ValueError):
+ _ = keras.optimizers.Adam(clipnorm=-2.0)
+
if __name__ == '__main__':
test.main()
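
As the new test above shows, negative clipping arguments now fail at construction time instead of being silently ignored. A hedged usage sketch, assuming the public tf.keras import path:

from tensorflow import keras

# Non-negative thresholds are accepted and applied later when gradients are computed.
opt = keras.optimizers.SGD(lr=0.01, clipnorm=1.0, clipvalue=0.5)

# Negative thresholds now raise ValueError from the kwargs check in Optimizer.__init__.
try:
    keras.optimizers.SGD(lr=0.01, clipvalue=-0.5)
except ValueError as e:
    print(e)  # e.g. "Expected clipvalue >= 0, received: -0.5"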