diff options
author | Yifei Feng <yifeif@google.com> | 2018-05-24 19:12:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-24 19:15:01 -0700 |
commit | b59833c3fd91511b33255369016868e4ae6cda2e (patch) | |
tree | ecbd70cfd3abb5d934f6eb4b7280a35e8589f5cf /tensorflow/python/layers | |
parent | 2b99d9cbc7166efedaff9eee11744348da30fc8a (diff) |
Merge changes from github.
Revert #18413. Too many internal test failures due to the name scope change caused by this change.
Revert #18192. Cannot use re2::StringPiece internally. Need alternative for set call. Will pull and clean this up in a separate change.
PiperOrigin-RevId: 197991247
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/base.py | 14 | ||||
-rw-r--r-- | tensorflow/python/layers/base_test.py | 16 |
2 files changed, 28 insertions, 2 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 340c34fc5e..eda036ece4 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -191,6 +191,16 @@ class Layer(base_layer.Layer): RuntimeError: If called with partioned variable regularization and eager execution is enabled. """ + + def _should_add_regularizer(variable, existing_variable_set): + if isinstance(variable, tf_variables.PartitionedVariable): + for var in variable: + if var in existing_variable_set: + return False + return True + else: + return variable not in existing_variable_set + init_graph = None if not context.executing_eagerly(): default_graph = ops.get_default_graph() @@ -233,7 +243,8 @@ class Layer(base_layer.Layer): getter=vs.get_variable) if regularizer: - if context.executing_eagerly() or variable not in existing_variables: + if context.executing_eagerly() or _should_add_regularizer( + variable, existing_variables): self._handle_weight_regularization(name, variable, regularizer) if init_graph is not None: @@ -353,4 +364,3 @@ def _add_elements_to_collection(elements, collection_list): for element in elements: if element not in collection_set: collection.append(element) - diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index f08b552840..ab49e37b90 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -30,6 +30,7 @@ from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -95,6 +96,21 @@ class BaseLayerTest(test.TestCase): regularizer=regularizer) self.assertEqual(len(layer.losses), 1) + def testReusePartitionedVaraiblesAndRegularizers(self): + regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 + partitioner = partitioned_variables.fixed_size_partitioner(3) + for reuse in [False, True]: + with variable_scope.variable_scope(variable_scope.get_variable_scope(), + partitioner=partitioner, + reuse=reuse): + layer = base_layers.Layer(name='my_layer') + variable = layer.add_variable( + 'reg_part_var', [4, 4], + initializer=init_ops.zeros_initializer(), + regularizer=regularizer) + self.assertEqual( + len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 3) + def testNoEagerActivityRegularizer(self): with context.eager_mode(): with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): |