path: root/tensorflow/python/layers
author     Yifei Feng <yifeif@google.com>  2018-05-24 19:12:26 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-05-24 19:15:01 -0700
commit     b59833c3fd91511b33255369016868e4ae6cda2e (patch)
tree       ecbd70cfd3abb5d934f6eb4b7280a35e8589f5cf /tensorflow/python/layers
parent     2b99d9cbc7166efedaff9eee11744348da30fc8a (diff)
Merge changes from github.
Revert #18413. Too many internal test failures due to the name scope change caused by this change.
Revert #18192. Cannot use re2::StringPiece internally. Need alternative for set call. Will pull and clean this up in a separate change.

PiperOrigin-RevId: 197991247
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--  tensorflow/python/layers/base.py       | 14
-rw-r--r--  tensorflow/python/layers/base_test.py  | 16
2 files changed, 28 insertions, 2 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 340c34fc5e..eda036ece4 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -191,6 +191,16 @@ class Layer(base_layer.Layer):
RuntimeError: If called with partitioned variable regularization and
eager execution is enabled.
"""
+
+ def _should_add_regularizer(variable, existing_variable_set):
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for var in variable:
+ if var in existing_variable_set:
+ return False
+ return True
+ else:
+ return variable not in existing_variable_set
+
init_graph = None
if not context.executing_eagerly():
default_graph = ops.get_default_graph()
@@ -233,7 +243,8 @@ class Layer(base_layer.Layer):
getter=vs.get_variable)
if regularizer:
- if context.executing_eagerly() or variable not in existing_variables:
+ if context.executing_eagerly() or _should_add_regularizer(
+ variable, existing_variables):
self._handle_weight_regularization(name, variable, regularizer)
if init_graph is not None:
@@ -353,4 +364,3 @@ def _add_elements_to_collection(elements, collection_list):
for element in elements:
if element not in collection_set:
collection.append(element)
-
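Why the helper is needed: under a variable partitioner, vs.get_variable returns a PartitionedVariable wrapper rather than a plain variable, so the old check `variable not in existing_variables` compared the wrapper object against the set of concrete variables that already existed in the graph and never matched; on scope reuse the regularization loss was therefore added a second time. The sketch below is a minimal, self-contained illustration of that membership logic; FakePartitionedVariable and the shard objects are hypothetical stand-ins, not tf_variables.PartitionedVariable itself.

# Minimal sketch of the membership check introduced above. The
# FakePartitionedVariable class and the shard objects are hypothetical
# stand-ins for tf_variables.PartitionedVariable and real variable shards.

class FakePartitionedVariable(object):
    """Stand-in that, like PartitionedVariable, iterates over its shards."""

    def __init__(self, shards):
        self._shards = list(shards)

    def __iter__(self):
        return iter(self._shards)


def should_add_regularizer(variable, existing_variable_set):
    # Mirrors _should_add_regularizer: for a partitioned variable, skip the
    # regularizer if any of its shards was already created before this call.
    if isinstance(variable, FakePartitionedVariable):
        return all(shard not in existing_variable_set for shard in variable)
    return variable not in existing_variable_set


shard_a, shard_b, shard_c = object(), object(), object()
partitioned = FakePartitionedVariable([shard_a, shard_b, shard_c])

print(should_add_regularizer(partitioned, set()))      # True: first creation
print(should_add_regularizer(partitioned, {shard_a}))  # False: shards already exist

Checking the shards rather than the wrapper is what makes the reuse pass return False: the set of pre-existing variables collected from the graph contains the individual shards, never the PartitionedVariable object itself.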
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index f08b552840..ab49e37b90 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
@@ -95,6 +96,21 @@ class BaseLayerTest(test.TestCase):
regularizer=regularizer)
self.assertEqual(len(layer.losses), 1)
+ def testReusePartitionedVaraiblesAndRegularizers(self):
+ regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
+ partitioner = partitioned_variables.fixed_size_partitioner(3)
+ for reuse in [False, True]:
+ with variable_scope.variable_scope(variable_scope.get_variable_scope(),
+ partitioner=partitioner,
+ reuse=reuse):
+ layer = base_layers.Layer(name='my_layer')
+ variable = layer.add_variable(
+ 'reg_part_var', [4, 4],
+ initializer=init_ops.zeros_initializer(),
+ regularizer=regularizer)
+ self.assertEqual(
+ len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 3)
+
def testNoEagerActivityRegularizer(self):
with context.eager_mode():
with self.assertRaisesRegexp(ValueError, 'activity_regularizer'):
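For reference, the scenario the new test covers can be written against the public TF 1.x API, where tf.layers.Layer is the exported alias of the Layer class patched above. This is a hedged sketch assuming TensorFlow 1.x graph mode, not a copy of the test itself; with the fix, the first pass creates three regularization losses (one per shard) and the reuse pass adds none.

# Hypothetical end-to-end sketch of the reuse scenario, assuming TF 1.x
# graph mode; tf.layers.Layer is the public alias of the Layer class above.
import tensorflow as tf

regularizer = lambda x: tf.reduce_sum(x) * 1e-3
partitioner = tf.fixed_size_partitioner(3)

for reuse in [False, True]:
    with tf.variable_scope(tf.get_variable_scope(),
                           partitioner=partitioner,
                           reuse=reuse):
        layer = tf.layers.Layer(name='my_layer')
        # add_variable returns a PartitionedVariable wrapping three shards.
        layer.add_variable('reg_part_var', [4, 4],
                           initializer=tf.zeros_initializer(),
                           regularizer=regularizer)

# One regularization loss per shard; the reuse pass must not add more.
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
assert len(reg_losses) == 3

Before this change, the second iteration would have appended three more losses to REGULARIZATION_LOSSES, doubling the regularization applied to the reused variable.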