Diffstat (limited to 'tensorflow/python/layers/base_test.py')
-rw-r--r--  tensorflow/python/layers/base_test.py  26
1 file changed, 24 insertions(+), 2 deletions(-)
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 298e96e711..d2443db665 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase):
     # regularizers only supported in GRAPH mode.
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
-    variable = layer.add_variable(
+    _ = layer.add_variable(
         'reg_var', [2, 2],
         initializer=init_ops.zeros_initializer(),
         regularizer=regularizer)
     self.assertEqual(len(layer.losses), 1)
+    # Test that sync `ON_READ` variables are defaulted to be non-trainable.
+    variable_3 = layer.add_variable(
+        'sync_on_read_var', [2, 2],
+        initializer=init_ops.zeros_initializer(),
+        synchronization=variable_scope.VariableSynchronization.ON_READ,
+        aggregation=variable_scope.VariableAggregation.SUM)
+    self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3])
+
+  def testInvalidTrainableSynchronizationCombination(self):
+    layer = base_layers.Layer(name='my_layer')
+
+    with self.assertRaisesRegexp(
+        ValueError, 'Synchronization value can be set to '
+        'VariableSynchronization.ON_READ only for non-trainable variables. '
+        'You have specified trainable=True and '
+        'synchronization=VariableSynchronization.ON_READ.'):
+      _ = layer.add_variable(
+          'v', [2, 2],
+          initializer=init_ops.zeros_initializer(),
+          synchronization=variable_scope.VariableSynchronization.ON_READ,
+          trainable=True)
+
   def testReusePartitionedVaraiblesAndRegularizers(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
     partitioner = partitioned_variables.fixed_size_partitioner(3)
@@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase):
           partitioner=partitioner,
           reuse=reuse):
         layer = base_layers.Layer(name='my_layer')
-        variable = layer.add_variable(
+        _ = layer.add_variable(
             'reg_part_var', [4, 4],
             initializer=init_ops.zeros_initializer(),
             regularizer=regularizer)
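
For reference, the two behaviors exercised by the new tests can be reproduced through the public API. The sketch below is illustrative only and is not part of the change; it assumes the tf.compat.v1 aliases (tf.layers.Layer, tf.VariableSynchronization, tf.VariableAggregation), which mirror the internal base_layers and variable_scope modules imported by this test.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # mirror the test's GRAPH-mode setting

layer = tf.layers.Layer(name='demo_layer')

# An ON_READ-synchronized variable defaults to trainable=False.
v = layer.add_variable(
    'sync_on_read_var', [2, 2],
    initializer=tf.zeros_initializer(),
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.SUM)
assert v in layer.non_trainable_variables

# Explicitly combining trainable=True with ON_READ raises ValueError.
try:
  layer.add_variable(
      'bad_var', [2, 2],
      initializer=tf.zeros_initializer(),
      synchronization=tf.VariableSynchronization.ON_READ,
      trainable=True)
except ValueError as err:
  print(err)

Rejecting the trainable=True / ON_READ combination is consistent with how ON_READ variables are used: they hold per-replica state that is merged only when read (for example, metric accumulators under a distribution strategy), so they should not receive gradient updates.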