about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: Francois Chollet <fchollet@google.com> 2016-12-08 00:08:26 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-12-08 00:23:29 -0800
commit 2c766bbfabf477432b738051e836ece4062962e0 (patch)
tree b78cc38c6f6c97d2008643b50d73116e9674db84
parent 6bfc29d4fb28cfcae2c0a49dbb451a075863fa68 (diff)
Make base core layer match the behavior of tf.get_variable with regard to regularization.
Change: 141405902
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py30
-rw-r--r--tensorflow/python/layers/base.py38
-rw-r--r--tensorflow/python/layers/core_test.py12
3 files changed, 63 insertions, 17 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 84082070d3..009fbdc485 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1467,6 +1467,36 @@ class FCTest(tf.test.TestCase):
self.assertEqual(
len(tf.contrib.framework.get_variables('fully_connected')), 4)
+ def testReuseWithRegularizer(self):
+ height, width = 3, 3
+ regularizer = lambda x: tf.reduce_sum(x) * 1e-3
+ inputs = tf.random_uniform((5, height * width * 3), seed=1)
+
+ tf.contrib.layers.fully_connected(inputs, 32, scope='fc1',
+ weights_regularizer=regularizer)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+ self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 1)
+ tf.contrib.layers.fully_connected(inputs, 32, scope='fc1',
+ weights_regularizer=regularizer,
+ reuse=True)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+ self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 1)
+
+ with tf.variable_scope('outer', reuse=False):
+ tf.contrib.layers.fully_connected(inputs, 32,
+ weights_regularizer=regularizer)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 2)
+ self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 2)
+ with tf.variable_scope('outer', reuse=True):
+ tf.contrib.layers.fully_connected(inputs, 32,
+ weights_regularizer=regularizer)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 2)
+ self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 2)
+
def testCreateFCWithoutActivation(self):
height, width = 3, 3
with self.test_session():
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 17e999b736..232763c758 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -199,34 +199,38 @@ class _Layer(object):
"""
if dtype is None:
dtype = self._dtype
+ existing_variables = set(tf_variables.global_variables())
variable = variable_getter(name,
shape=shape,
initializer=initializer,
dtype=dtype,
trainable=trainable and self.trainable)
# TODO(sguada) fix name = variable.op.name
- if trainable:
- self._trainable_variables.append(variable)
- else:
- self._non_trainable_variables.append(variable)
- if regularizer and not self._reuse:
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
+ if regularizer:
+ if not self._reuse and variable not in existing_variables:
+ # To match the behavior of tf.get_variable(), we only
+ # apply regularization if the variable is newly created.
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ with ops.colocate_with(v.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ if regularization is not None:
+ self._losses.append(regularization)
+ _add_elements_to_collection(
+ regularization, ops.GraphKeys.REGULARIZATION_LOSSES)
+ else:
+ with ops.colocate_with(variable.op):
with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
+ regularization = regularizer(variable)
if regularization is not None:
self._losses.append(regularization)
_add_elements_to_collection(
regularization, ops.GraphKeys.REGULARIZATION_LOSSES)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self._losses.append(regularization)
- _add_elements_to_collection(
- regularization, ops.GraphKeys.REGULARIZATION_LOSSES)
+ if trainable:
+ self._trainable_variables.append(variable)
+ else:
+ self._non_trainable_variables.append(variable)
return variable
def __call__(self, inputs, **kwargs):
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index d3b0ee1550..dd60b437b6 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -151,6 +151,18 @@ class DenseTest(tf.test.TestCase):
self.assertEqual(len(loss_keys), 1)
self.assertListEqual(dense.losses, loss_keys)
+ def testWeightsRegularizerWithReuse(self):
+ regularizer = lambda x: tf.reduce_sum(x) * 1e-3
+ inputs = tf.random_uniform((5, 3), seed=1)
+ _ = core_layers.dense(inputs, 2, name='my_dense',
+ weights_regularizer=regularizer)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+ _ = core_layers.dense(inputs, 2, name='my_dense',
+ weights_regularizer=regularizer, reuse=True)
+ self.assertEqual(
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+
def testBiasRegularizer(self):
regularizer = lambda x: tf.reduce_sum(x) * 1e-3
dense = core_layers.Dense(2, name='my_dense',