aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/base.py')
-rw-r--r--tensorflow/python/layers/base.py41
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index b8969a41ab..cf13b52617 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -152,10 +152,17 @@ class Layer(base_layer.Layer):
scope, default_name=self._base_name) as captured_scope:
self._scope = captured_scope
- def add_weight(self, name, shape, dtype=None,
- initializer=None, regularizer=None,
- trainable=True, constraint=None,
+ def add_weight(self,
+ name,
+ shape,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=None,
+ constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -170,9 +177,19 @@ class Layer(base_layer.Layer):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: (optional) partitioner instance (callable). If
provided, when the requested variable is created it will be split
into multiple partitions according to `partitioner`. In this case,
@@ -190,7 +207,21 @@ class Layer(base_layer.Layer):
Raises:
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
+ ValueError: When trainable has been set to True with synchronization
+ set as `ON_READ`.
"""
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
def _should_add_regularizer(variable, existing_variable_set):
if isinstance(variable, tf_variables.PartitionedVariable):
@@ -240,6 +271,8 @@ class Layer(base_layer.Layer):
constraint=constraint,
partitioner=partitioner,
use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
getter=vs.get_variable)
if regularizer: