From 2c766bbfabf477432b738051e836ece4062962e0 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 8 Dec 2016 00:08:26 -0800 Subject: Make base core layer match the behavior of tf.get_variable with regard to regularization. Change: 141405902 --- .../contrib/layers/python/layers/layers_test.py | 30 +++++++++++++++++ tensorflow/python/layers/base.py | 38 ++++++++++++---------- tensorflow/python/layers/core_test.py | 12 +++++++ 3 files changed, 63 insertions(+), 17 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 84082070d3..009fbdc485 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1467,6 +1467,36 @@ class FCTest(tf.test.TestCase): self.assertEqual( len(tf.contrib.framework.get_variables('fully_connected')), 4) + def testReuseWithRegularizer(self): + height, width = 3, 3 + regularizer = lambda x: tf.reduce_sum(x) * 1e-3 + inputs = tf.random_uniform((5, height * width * 3), seed=1) + + tf.contrib.layers.fully_connected(inputs, 32, scope='fc1', + weights_regularizer=regularizer) + self.assertEqual( + len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1) + self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 1) + tf.contrib.layers.fully_connected(inputs, 32, scope='fc1', + weights_regularizer=regularizer, + reuse=True) + self.assertEqual( + len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1) + self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 1) + + with tf.variable_scope('outer', reuse=False): + tf.contrib.layers.fully_connected(inputs, 32, + weights_regularizer=regularizer) + self.assertEqual( + len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 2) + self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 2) + with tf.variable_scope('outer', reuse=True): + tf.contrib.layers.fully_connected(inputs, 32, + weights_regularizer=regularizer) + self.assertEqual( + len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 2) + self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 2) + def testCreateFCWithoutActivation(self): height, width = 3, 3 with self.test_session(): diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 17e999b736..232763c758 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -199,34 +199,38 @@ class _Layer(object): """ if dtype is None: dtype = self._dtype + existing_variables = set(tf_variables.global_variables()) variable = variable_getter(name, shape=shape, initializer=initializer, dtype=dtype, trainable=trainable and self.trainable) # TODO(sguada) fix name = variable.op.name - if trainable: - self._trainable_variables.append(variable) - else: - self._non_trainable_variables.append(variable) - if regularizer and not self._reuse: - if isinstance(variable, tf_variables.PartitionedVariable): - for v in variable: - with ops.colocate_with(v.op): + if regularizer: + if not self._reuse and variable not in existing_variables: + # To match the behavior of tf.get_variable(), we only + # apply regularization if the variable is newly created. + if isinstance(variable, tf_variables.PartitionedVariable): + for v in variable: + with ops.colocate_with(v.op): + with ops.name_scope(name + '/Regularizer'): + regularization = regularizer(v) + if regularization is not None: + self._losses.append(regularization) + _add_elements_to_collection( + regularization, ops.GraphKeys.REGULARIZATION_LOSSES) + else: + with ops.colocate_with(variable.op): with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(v) + regularization = regularizer(variable) if regularization is not None: self._losses.append(regularization) _add_elements_to_collection( regularization, ops.GraphKeys.REGULARIZATION_LOSSES) - else: - with ops.colocate_with(variable.op): - with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(variable) - if regularization is not None: - self._losses.append(regularization) - _add_elements_to_collection( - regularization, ops.GraphKeys.REGULARIZATION_LOSSES) + if trainable: + self._trainable_variables.append(variable) + else: + self._non_trainable_variables.append(variable) return variable def __call__(self, inputs, **kwargs): diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index d3b0ee1550..dd60b437b6 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -151,6 +151,18 @@ class DenseTest(tf.test.TestCase): self.assertEqual(len(loss_keys), 1) self.assertListEqual(dense.losses, loss_keys) + def testWeightsRegularizerWithReuse(self): + regularizer = lambda x: tf.reduce_sum(x) * 1e-3 + inputs = tf.random_uniform((5, 3), seed=1) + _ = core_layers.dense(inputs, 2, name='my_dense', + weights_regularizer=regularizer) + self.assertEqual( + len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1) + _ = core_layers.dense(inputs, 2, name='my_dense', + weights_regularizer=regularizer, reuse=True) + self.assertEqual( + len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1) + def testBiasRegularizer(self): regularizer = lambda x: tf.reduce_sum(x) * 1e-3 dense = core_layers.Dense(2, name='my_dense', -- cgit v1.2.3