aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variable_scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r--tensorflow/python/ops/variable_scope.py87
1 files changed, 18 insertions, 69 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 77f67c18ee..aca44bcd44 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -191,36 +191,9 @@ class _ReuseMode(enum.Enum):
# REUSE_TRUE = 3
-@tf_export("VariableSynchronization")
-class VariableSynchronization(enum.Enum):
- """Indicates when a distributed variable will be synced."""
-
- # Indicates that the synchronization will be determined by the current
- # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
- # `ON_WRITE`).
- AUTO = 0
-
- # Indicates that there will only be one copy of the variable, so there is no
- # need to sync.
- NONE = 1
-
- # Indicates that the variable will be aggregated across devices
- # every time it is updated.
- ON_WRITE = 2
-
- # Indicates that the variable will be aggregated across devices
- # when it is read (eg. when checkpointing or when evaluating an op that uses
- # the variable).
- ON_READ = 3
-
-
-@tf_export("VariableAggregation")
-class VariableAggregation(enum.Enum):
- """Indicates how a distributed variable will be aggregated."""
- NONE = 0
- SUM = 1
- MEAN = 2
-
+# TODO(apassos) remove these forwarding symbols.
+VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name
+VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
@@ -2376,7 +2349,10 @@ def default_variable_creator(next_creator=None, **kwargs):
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
name = kwargs.get("name", None)
+ variable_def = kwargs.get("variable_def", None)
dtype = kwargs.get("dtype", None)
+ expected_shape = kwargs.get("expected_shape", None)
+ import_scope = kwargs.get("import_scope", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
@@ -2387,23 +2363,24 @@ def default_variable_creator(next_creator=None, **kwargs):
if use_resource is None:
use_resource = get_variable_scope().use_resource
- if use_resource or (use_resource is None and context.executing_eagerly()):
+ use_resource = use_resource or context.executing_eagerly()
+ if use_resource:
return resource_variable_ops.ResourceVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
- constraint=constraint)
- elif not use_resource and context.executing_eagerly():
- raise RuntimeError(
- "VariableScope should use resource variable when eager execution is"
- " enabled, but use_resource is False."
- )
+ constraint=constraint, variable_def=variable_def,
+ import_scope=import_scope)
else:
- return variables.Variable(
+ return variables.RefVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
- constraint=constraint)
+ constraint=constraint, variable_def=variable_def,
+ expected_shape=expected_shape, import_scope=import_scope)
+
+
+variables.default_variable_creator = default_variable_creator
def _make_getter(captured_getter, captured_previous):
@@ -2411,36 +2388,8 @@ def _make_getter(captured_getter, captured_previous):
return lambda **kwargs: captured_getter(captured_previous, **kwargs)
-def variable(initial_value=None,
- trainable=None,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- dtype=None,
- constraint=None,
- use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE):
- previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
- for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
- previous_getter = _make_getter(getter, previous_getter)
-
- # Reset `aggregation` that is explicitly set as `None` to the enum None value.
- if aggregation is None:
- aggregation = VariableAggregation.NONE
- return previous_getter(
- initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name,
- dtype=dtype,
- constraint=constraint,
- use_resource=use_resource,
- synchronization=synchronization,
- aggregation=aggregation)
+# TODO(apassos) remove forwarding symbol
+variable = variables.Variable
@tf_contextlib.contextmanager