diff options
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 87 |
1 files changed, 18 insertions, 69 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 77f67c18ee..aca44bcd44 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -191,36 +191,9 @@ class _ReuseMode(enum.Enum): # REUSE_TRUE = 3 -@tf_export("VariableSynchronization") -class VariableSynchronization(enum.Enum): - """Indicates when a distributed variable will be synced.""" - - # Indicates that the synchronization will be determined by the current - # `DistributionStrategy` (eg. With `MirroredStrategy` this would be - # `ON_WRITE`). - AUTO = 0 - - # Indicates that there will only be one copy of the variable, so there is no - # need to sync. - NONE = 1 - - # Indicates that the variable will be aggregated across devices - # every time it is updated. - ON_WRITE = 2 - - # Indicates that the variable will be aggregated across devices - # when it is read (eg. when checkpointing or when evaluating an op that uses - # the variable). - ON_READ = 3 - - -@tf_export("VariableAggregation") -class VariableAggregation(enum.Enum): - """Indicates how a distributed variable will be aggregated.""" - NONE = 0 - SUM = 1 - MEAN = 2 - +# TODO(apassos) remove these forwarding symbols. +VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name +VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name AUTO_REUSE = _ReuseMode.AUTO_REUSE tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE") @@ -2376,7 +2349,10 @@ def default_variable_creator(next_creator=None, **kwargs): validate_shape = kwargs.get("validate_shape", True) caching_device = kwargs.get("caching_device", None) name = kwargs.get("name", None) + variable_def = kwargs.get("variable_def", None) dtype = kwargs.get("dtype", None) + expected_shape = kwargs.get("expected_shape", None) + import_scope = kwargs.get("import_scope", None) constraint = kwargs.get("constraint", None) use_resource = kwargs.get("use_resource", None) @@ -2387,23 +2363,24 @@ def default_variable_creator(next_creator=None, **kwargs): if use_resource is None: use_resource = get_variable_scope().use_resource - if use_resource or (use_resource is None and context.executing_eagerly()): + use_resource = use_resource or context.executing_eagerly() + if use_resource: return resource_variable_ops.ResourceVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, caching_device=caching_device, name=name, dtype=dtype, - constraint=constraint) - elif not use_resource and context.executing_eagerly(): - raise RuntimeError( - "VariableScope should use resource variable when eager execution is" - " enabled, but use_resource is False." - ) + constraint=constraint, variable_def=variable_def, + import_scope=import_scope) else: - return variables.Variable( + return variables.RefVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, caching_device=caching_device, name=name, dtype=dtype, - constraint=constraint) + constraint=constraint, variable_def=variable_def, + expected_shape=expected_shape, import_scope=import_scope) + + +variables.default_variable_creator = default_variable_creator def _make_getter(captured_getter, captured_previous): @@ -2411,36 +2388,8 @@ def _make_getter(captured_getter, captured_previous): return lambda **kwargs: captured_getter(captured_previous, **kwargs) -def variable(initial_value=None, - trainable=None, - collections=None, - validate_shape=True, - caching_device=None, - name=None, - dtype=None, - constraint=None, - use_resource=None, - synchronization=VariableSynchronization.AUTO, - aggregation=VariableAggregation.NONE): - previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) - for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access - previous_getter = _make_getter(getter, previous_getter) - - # Reset `aggregation` that is explicitly set as `None` to the enum None value. - if aggregation is None: - aggregation = VariableAggregation.NONE - return previous_getter( - initial_value=initial_value, - trainable=trainable, - collections=collections, - validate_shape=validate_shape, - caching_device=caching_device, - name=name, - dtype=dtype, - constraint=constraint, - use_resource=use_resource, - synchronization=synchronization, - aggregation=aggregation) +# TODO(apassos) remove forwarding symbol +variable = variables.Variable @tf_contextlib.contextmanager |