aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework/python/ops/variables.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/variables.py')
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py97
1 files changed, 75 insertions, 22 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index e8e3180019..322d5c335e 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
@@ -199,10 +200,20 @@ def global_variable(initial_value,
@contrib_add_arg_scope
-def variable(name, shape=None, dtype=None, initializer=None,
- regularizer=None, trainable=True, collections=None,
- caching_device=None, device=None,
- partitioner=None, custom_getter=None, use_resource=None):
+def variable(name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ device=None,
+ partitioner=None,
+ custom_getter=None,
+ use_resource=None,
+ synchronization=variables.VariableSynchronization.AUTO,
+ aggregation=variables.VariableAggregation.NONE):
"""Gets an existing variable with these parameters or creates a new one.
Args:
@@ -228,6 +239,15 @@ def variable(name, shape=None, dtype=None, initializer=None,
custom_getter: Callable that allows overwriting the internal
get_variable method and has to have the same signature.
use_resource: If `True` use a ResourceVariable instead of a Variable.
+    synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Returns:
The created or existing variable.
@@ -242,21 +262,36 @@ def variable(name, shape=None, dtype=None, initializer=None,
getter = functools.partial(custom_getter,
reuse=variable_scope.get_variable_scope().reuse)
with ops.device(device or ''):
- return getter(name, shape=shape, dtype=dtype,
- initializer=initializer,
- regularizer=regularizer,
- trainable=trainable,
- collections=collections,
- caching_device=caching_device,
- partitioner=partitioner,
- use_resource=use_resource)
+ return getter(
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
@contrib_add_arg_scope
-def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
- regularizer=None, trainable=True, collections=None,
- caching_device=None, device=None, partitioner=None,
- custom_getter=None, use_resource=None):
+def model_variable(name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ device=None,
+ partitioner=None,
+ custom_getter=None,
+ use_resource=None,
+ synchronization=variables.VariableSynchronization.AUTO,
+ aggregation=variables.VariableAggregation.NONE):
"""Gets an existing model variable with these parameters or creates a new one.
Args:
@@ -283,18 +318,36 @@ def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
custom_getter: Callable that allows overwriting the internal
get_variable method and has to have the same signature.
use_resource: If `True` use a ResourceVariable instead of a Variable.
+    synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Returns:
The created or existing variable.
"""
collections = list(collections or [])
collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
- var = variable(name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer,
- trainable=trainable, collections=collections,
- caching_device=caching_device, device=device,
- partitioner=partitioner, custom_getter=custom_getter,
- use_resource=use_resource)
+ var = variable(
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ device=device,
+ partitioner=partitioner,
+ custom_getter=custom_getter,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
return var